diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index 73eb646..71fe086 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -332,20 +332,19 @@ func (p *Postgres) Drop() error { return nil } -func (p *Postgres) ensureVersionTable() error { - // check if migration table exists - var count int - query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` - if err := p.conn.QueryRowContext(context.Background(), query, p.config.MigrationsTable).Scan(&count); err != nil { - return &database.Error{OrigErr: err, Query: []byte(query)} - } - if count == 1 { - return nil +func (p *Postgres) ensureVersionTable() (err error) { + if err = p.Lock(); err != nil { + return err } - // if not, create the empty migration table - query = `CREATE TABLE "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` - if _, err := p.conn.ExecContext(context.Background(), query); err != nil { + defer func() { + if e := p.Unlock(); err == nil { + err = e + } + }() + + query := `CREATE TABLE IF NOT EXISTS "` + p.config.MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` + if _, err = p.conn.ExecContext(context.Background(), query); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } return nil