mirror of https://github.com/status-im/migrate.git
add mysql custom TLS config
closes https://github.com/mattes/migrate/pull/117
This commit is contained in:
parent
be1ba9204a
commit
150ac7d708
|
@ -1,11 +1,14 @@
|
||||||
package mysql
|
package mysql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"crypto/x509"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
nurl "net/url"
|
nurl "net/url"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/go-sql-driver/mysql"
|
"github.com/go-sql-driver/mysql"
|
||||||
|
@ -23,6 +26,7 @@ var (
|
||||||
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
||||||
ErrNilConfig = fmt.Errorf("no config")
|
ErrNilConfig = fmt.Errorf("no config")
|
||||||
ErrNoDatabaseName = fmt.Errorf("no database name")
|
ErrNoDatabaseName = fmt.Errorf("no database name")
|
||||||
|
ErrAppendPEM = fmt.Errorf("failed to append PEM")
|
||||||
)
|
)
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
|
@ -94,6 +98,42 @@ func (m *Mysql) Open(url string) (database.Driver, error) {
|
||||||
migrationsTable = DefaultMigrationsTable
|
migrationsTable = DefaultMigrationsTable
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// use custom TLS?
|
||||||
|
ctls := purl.Query().Get("tls")
|
||||||
|
if len(ctls) > 0 {
|
||||||
|
if _, isBool := readBool(ctls); !isBool && strings.ToLower(ctls) != "skip-verify" {
|
||||||
|
rootCertPool := x509.NewCertPool()
|
||||||
|
pem, err := ioutil.ReadFile(purl.Query().Get("x-tls-ca"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
if ok := rootCertPool.AppendCertsFromPEM(pem); !ok {
|
||||||
|
return nil, ErrAppendPEM
|
||||||
|
}
|
||||||
|
|
||||||
|
certs, err := tls.LoadX509KeyPair(purl.Query().Get("x-tls-cert"), purl.Query().Get("x-tls-key"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
insecureSkipVerify := false
|
||||||
|
if len(purl.Query().Get("x-tls-insecure-skip-verify")) > 0 {
|
||||||
|
x, err := strconv.ParseBool(purl.Query().Get("x-tls-insecure-skip-verify"))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
insecureSkipVerify = x
|
||||||
|
}
|
||||||
|
|
||||||
|
mysql.RegisterTLSConfig(ctls, &tls.Config{
|
||||||
|
RootCAs: rootCertPool,
|
||||||
|
Certificates: []tls.Certificate{certs},
|
||||||
|
InsecureSkipVerify: insecureSkipVerify,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
mx, err := WithInstance(db, &Config{
|
mx, err := WithInstance(db, &Config{
|
||||||
DatabaseName: purl.Path,
|
DatabaseName: purl.Path,
|
||||||
MigrationsTable: migrationsTable,
|
MigrationsTable: migrationsTable,
|
||||||
|
@ -270,3 +310,18 @@ func (m *Mysql) ensureVersionTable() error {
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Returns the bool value of the input.
|
||||||
|
// The 2nd return value indicates if the input was a valid bool value
|
||||||
|
// See https://github.com/go-sql-driver/mysql/blob/a059889267dc7170331388008528b3b44479bffb/utils.go#L71
|
||||||
|
func readBool(input string) (value bool, valid bool) {
|
||||||
|
switch input {
|
||||||
|
case "1", "true", "TRUE", "True":
|
||||||
|
return true, true
|
||||||
|
case "0", "false", "FALSE", "False":
|
||||||
|
return false, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// Not a valid bool value
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue