mirror of https://github.com/status-im/migrate.git
Add support for multi-schema migrations in Postgres
There is lock conflict on parallel migrations in different postgres schemas. To avoid this conflicts function GenerateAdvisoryLockId added variadic params to change lock id with schema name. Schema name taked with postgres CURRENT_SCHEMA function. Null byte used as separator between database and schema name, because any other symbol may be used in both of it. Closes #118
This commit is contained in:
parent
9f5e1bd505
commit
16d63e3a76
|
@ -35,6 +35,7 @@ var (
|
|||
type Config struct {
|
||||
MigrationsTable string
|
||||
DatabaseName string
|
||||
SchemaName string
|
||||
}
|
||||
|
||||
type Postgres struct {
|
||||
|
@ -68,6 +69,18 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) {
|
|||
|
||||
config.DatabaseName = databaseName
|
||||
|
||||
query = `SELECT CURRENT_SCHEMA()`
|
||||
var schemaName string
|
||||
if err := instance.QueryRow(query).Scan(&schemaName); err != nil {
|
||||
return nil, &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
|
||||
if len(schemaName) == 0 {
|
||||
return nil, ErrNoSchema
|
||||
}
|
||||
|
||||
config.SchemaName = schemaName
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
@ -133,7 +146,7 @@ func (p *Postgres) Lock() error {
|
|||
return database.ErrLocked
|
||||
}
|
||||
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName)
|
||||
aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -178,6 +178,56 @@ func TestWithSchema(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestParallelSchema(t *testing.T) {
|
||||
mt.ParallelTest(t, versions, isReady,
|
||||
func(t *testing.T, i mt.Instance) {
|
||||
p := &Postgres{}
|
||||
addr := pgConnectionString(i.Host(), i.Port())
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
defer d.Close()
|
||||
|
||||
// create foo and bar schemas
|
||||
if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA foo AUTHORIZATION postgres"))); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA bar AUTHORIZATION postgres"))); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// re-connect using that schemas
|
||||
dfoo, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=foo", i.Host(), i.Port()))
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
defer dfoo.Close()
|
||||
|
||||
dbar, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=bar", i.Host(), i.Port()))
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
defer dbar.Close()
|
||||
|
||||
if err := dfoo.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Lock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dbar.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := dfoo.Unlock(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestWithInstance(t *testing.T) {
|
||||
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package database
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"hash/crc32"
|
||||
)
|
||||
|
@ -8,8 +9,13 @@ import (
|
|||
const advisoryLockIdSalt uint = 1486364155
|
||||
|
||||
// GenerateAdvisoryLockId inspired by rails migrations, see https://goo.gl/8o9bCT
|
||||
func GenerateAdvisoryLockId(databaseName string) (string, error) {
|
||||
sum := crc32.ChecksumIEEE([]byte(databaseName))
|
||||
func GenerateAdvisoryLockId(databaseName string, additionalNames ...string) (string, error) {
|
||||
buf := bytes.NewBufferString(databaseName)
|
||||
for _, name := range additionalNames {
|
||||
buf.WriteByte(0)
|
||||
buf.WriteString(name)
|
||||
}
|
||||
sum := crc32.ChecksumIEEE(buf.Bytes())
|
||||
sum = sum * uint32(advisoryLockIdSalt)
|
||||
return fmt.Sprintf("%v", sum), nil
|
||||
}
|
||||
|
|
|
@ -7,14 +7,21 @@ import (
|
|||
func TestGenerateAdvisoryLockId(t *testing.T) {
|
||||
testcases := []struct {
|
||||
dbname string
|
||||
schema string
|
||||
expectedID string // empty string signifies that an error is expected
|
||||
}{
|
||||
{dbname: "database_name", expectedID: "1764327054"},
|
||||
{dbname: "database_name", schema: "schema_name_1", expectedID: "3244152297"},
|
||||
{dbname: "database_name", schema: "schema_name_2", expectedID: "810103531"},
|
||||
}
|
||||
|
||||
for _, tc := range testcases {
|
||||
t.Run(tc.dbname, func(t *testing.T) {
|
||||
if id, err := GenerateAdvisoryLockId("database_name"); err == nil {
|
||||
names := []string{}
|
||||
if len(tc.schema) > 0 {
|
||||
names = append(names, tc.schema)
|
||||
}
|
||||
if id, err := GenerateAdvisoryLockId(tc.dbname, names...); err == nil {
|
||||
if id != tc.expectedID {
|
||||
t.Error("Generated incorrect ID:", id, "!=", tc.expectedID)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue