diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 8f8faf9..6552f9d 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -35,6 +35,7 @@ var ( type Config struct { MigrationsTable string DatabaseName string + SchemaName string } type Postgres struct { @@ -68,6 +69,18 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.DatabaseName = databaseName + query = `SELECT CURRENT_SCHEMA()` + var schemaName string + if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } + + if len(schemaName) == 0 { + return nil, ErrNoSchema + } + + config.SchemaName = schemaName + if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable } @@ -133,7 +146,7 @@ func (p *Postgres) Lock() error { return database.ErrLocked } - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) if err != nil { return err } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 9da9938..07c0cd3 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -178,6 +178,56 @@ func TestWithSchema(t *testing.T) { }) } +func TestParallelSchema(t *testing.T) { + mt.ParallelTest(t, versions, isReady, + func(t *testing.T, i mt.Instance) { + p := &Postgres{} + addr := pgConnectionString(i.Host(), i.Port()) + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + defer d.Close() + + // create foo and bar schemas + if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA foo AUTHORIZATION postgres"))); err != nil { + t.Fatal(err) + } + if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA bar AUTHORIZATION postgres"))); err != nil { + t.Fatal(err) + } + + // re-connect using that schemas + dfoo, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=foo", i.Host(), i.Port())) + if err != nil { + t.Fatalf("%v", err) + } + defer dfoo.Close() + + dbar, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=bar", i.Host(), i.Port())) + if err != nil { + t.Fatalf("%v", err) + } + defer dbar.Close() + + if err := dfoo.Lock(); err != nil { + t.Fatal(err) + } + + if err := dbar.Lock(); err != nil { + t.Fatal(err) + } + + if err := dbar.Unlock(); err != nil { + t.Fatal(err) + } + + if err := dfoo.Unlock(); err != nil { + t.Fatal(err) + } + }) +} + func TestWithInstance(t *testing.T) { } diff --git a/database/util.go b/database/util.go index 7de1d1b..f0ffd61 100644 --- a/database/util.go +++ b/database/util.go @@ -3,12 +3,16 @@ package database import ( "fmt" "hash/crc32" + "strings" ) const advisoryLockIdSalt uint = 1486364155 // GenerateAdvisoryLockId inspired by rails migrations, see https://goo.gl/8o9bCT -func GenerateAdvisoryLockId(databaseName string) (string, error) { +func GenerateAdvisoryLockId(databaseName string, additionalNames ...string) (string, error) { + if len(additionalNames) > 0 { + databaseName = strings.Join(append(additionalNames, databaseName), "\x00") + } sum := crc32.ChecksumIEEE([]byte(databaseName)) sum = sum * uint32(advisoryLockIdSalt) return fmt.Sprintf("%v", sum), nil diff --git a/database/util_test.go b/database/util_test.go index 0b66d2d..13cba46 100644 --- a/database/util_test.go +++ b/database/util_test.go @@ -7,14 +7,33 @@ import ( func TestGenerateAdvisoryLockId(t *testing.T) { testcases := []struct { dbname string + additional []string expectedID string // empty string signifies that an error is expected }{ - {dbname: "database_name", expectedID: "1764327054"}, + { + dbname: "database_name", + expectedID: "1764327054", + }, + { + dbname: "database_name", + additional: []string{"schema_name_1"}, + expectedID: "2453313553", + }, + { + dbname: "database_name", + additional: []string{"schema_name_2"}, + expectedID: "235207038", + }, + { + dbname: "database_name", + additional: []string{"schema_name_1", "schema_name_2"}, + expectedID: "3743845847", + }, } for _, tc := range testcases { t.Run(tc.dbname, func(t *testing.T) { - if id, err := GenerateAdvisoryLockId("database_name"); err == nil { + if id, err := GenerateAdvisoryLockId(tc.dbname, tc.additional...); err == nil { if id != tc.expectedID { t.Error("Generated incorrect ID:", id, "!=", tc.expectedID) }