diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 8f8faf9..6552f9d 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -35,6 +35,7 @@ var ( type Config struct { MigrationsTable string DatabaseName string + SchemaName string } type Postgres struct { @@ -68,6 +69,18 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { config.DatabaseName = databaseName + query = `SELECT CURRENT_SCHEMA()` + var schemaName string + if err := instance.QueryRow(query).Scan(&schemaName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } + + if len(schemaName) == 0 { + return nil, ErrNoSchema + } + + config.SchemaName = schemaName + if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable } @@ -133,7 +146,7 @@ func (p *Postgres) Lock() error { return database.ErrLocked } - aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName) + aid, err := database.GenerateAdvisoryLockId(p.config.DatabaseName, p.config.SchemaName) if err != nil { return err } diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index 9da9938..07c0cd3 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -178,6 +178,56 @@ func TestWithSchema(t *testing.T) { }) } +func TestParallelSchema(t *testing.T) { + mt.ParallelTest(t, versions, isReady, + func(t *testing.T, i mt.Instance) { + p := &Postgres{} + addr := pgConnectionString(i.Host(), i.Port()) + d, err := p.Open(addr) + if err != nil { + t.Fatalf("%v", err) + } + defer d.Close() + + // create foo and bar schemas + if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA foo AUTHORIZATION postgres"))); err != nil { + t.Fatal(err) + } + if err := d.Run(bytes.NewReader([]byte("CREATE SCHEMA bar AUTHORIZATION postgres"))); err != nil { + t.Fatal(err) + } + + // re-connect using that schemas + dfoo, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=foo", i.Host(), i.Port())) + if err != nil { + t.Fatalf("%v", err) + } + defer dfoo.Close() + + dbar, err := p.Open(fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable&search_path=bar", i.Host(), i.Port())) + if err != nil { + t.Fatalf("%v", err) + } + defer dbar.Close() + + if err := dfoo.Lock(); err != nil { + t.Fatal(err) + } + + if err := dbar.Lock(); err != nil { + t.Fatal(err) + } + + if err := dbar.Unlock(); err != nil { + t.Fatal(err) + } + + if err := dfoo.Unlock(); err != nil { + t.Fatal(err) + } + }) +} + func TestWithInstance(t *testing.T) { } diff --git a/database/util.go b/database/util.go index 7de1d1b..bc62171 100644 --- a/database/util.go +++ b/database/util.go @@ -1,6 +1,7 @@ package database import ( + "bytes" "fmt" "hash/crc32" ) @@ -8,8 +9,13 @@ import ( const advisoryLockIdSalt uint = 1486364155 // GenerateAdvisoryLockId inspired by rails migrations, see https://goo.gl/8o9bCT -func GenerateAdvisoryLockId(databaseName string) (string, error) { - sum := crc32.ChecksumIEEE([]byte(databaseName)) +func GenerateAdvisoryLockId(databaseName string, additionalNames ...string) (string, error) { + buf := bytes.NewBufferString(databaseName) + for _, name := range additionalNames { + buf.WriteByte(0) + buf.WriteString(name) + } + sum := crc32.ChecksumIEEE(buf.Bytes()) sum = sum * uint32(advisoryLockIdSalt) return fmt.Sprintf("%v", sum), nil } diff --git a/database/util_test.go b/database/util_test.go index 0b66d2d..710a248 100644 --- a/database/util_test.go +++ b/database/util_test.go @@ -7,14 +7,21 @@ import ( func TestGenerateAdvisoryLockId(t *testing.T) { testcases := []struct { dbname string + schema string expectedID string // empty string signifies that an error is expected }{ {dbname: "database_name", expectedID: "1764327054"}, + {dbname: "database_name", schema: "schema_name_1", expectedID: "3244152297"}, + {dbname: "database_name", schema: "schema_name_2", expectedID: "810103531"}, } for _, tc := range testcases { t.Run(tc.dbname, func(t *testing.T) { - if id, err := GenerateAdvisoryLockId("database_name"); err == nil { + names := []string{} + if len(tc.schema) > 0 { + names = append(names, tc.schema) + } + if id, err := GenerateAdvisoryLockId(tc.dbname, names...); err == nil { if id != tc.expectedID { t.Error("Generated incorrect ID:", id, "!=", tc.expectedID) }