diff --git a/Makefile b/Makefile index baa8705..2e6817b 100644 --- a/Makefile +++ b/Makefile @@ -26,7 +26,7 @@ test-short: test: @-rm -r .coverage @mkdir .coverage - make test-with-flags TEST_FLAGS='-v -race -covermode atomic -coverprofile .coverage/_$$(RAND).txt -bench=. -benchmem' + make test-with-flags TEST_FLAGS='-v -race -covermode atomic -coverprofile .coverage/_$$(RAND).txt -bench=. -benchmem' @echo 'mode: atomic' > .coverage/combined.txt @cat .coverage/*.txt | grep -v 'mode: atomic' >> .coverage/combined.txt diff --git a/database/driver.go b/database/driver.go index 30b8051..654ad61 100644 --- a/database/driver.go +++ b/database/driver.go @@ -51,6 +51,7 @@ type Driver interface { // Lock should acquire a database lock so that only one migration process // can run at a time. Migrate will call this function before Run is called. // If the implementation can't provide this functionality, return nil. + // Return database.ErrLocked if database is already locked. Lock() error // Unlock should release the lock. Migrate will call this function after diff --git a/database/error.go b/database/error.go new file mode 100644 index 0000000..eb802c7 --- /dev/null +++ b/database/error.go @@ -0,0 +1,27 @@ +package database + +import ( + "fmt" +) + +// Error should be used for errors involving queries ran against the database +type Error struct { + // Optional: the line number + Line uint + + // Query is a query excerpt + Query []byte + + // Err is a useful/helping error message for humans + Err string + + // OrigErr is the underlying error + OrigErr error +} + +func (e Error) Error() string { + if len(e.Err) == 0 { + return fmt.Sprintf("%v in line %v: %s", e.OrigErr, e.Line, e.Query) + } + return fmt.Sprintf("%v in line %v: %s (details: %v)", e.Err, e.Line, e.Query, e.OrigErr) +} diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index ff2b2e4..0bbbb07 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -16,13 +16,16 @@ func init() { database.Register("postgres", &Postgres{}) } +var MigrationsTable = "schema_migrations" + var ( ErrNilConfig = fmt.Errorf("no config") ErrNoDatabaseName = fmt.Errorf("no database name") + ErrNoSchema = fmt.Errorf("no schema") + ErrDatabaseDirty = fmt.Errorf("database is dirty") ) type Config struct { - // DatbaseName is the name of the database DatabaseName string } @@ -31,10 +34,18 @@ func WithInstance(instance *sql.DB, config *Config) (database.Driver, error) { return nil, ErrNilConfig } - if len(config.DatabaseName) == 0 { + query := `SELECT CURRENT_DATABASE()` + var databaseName string + if err := instance.QueryRow(query).Scan(&databaseName); err != nil { + return nil, &database.Error{OrigErr: err, Query: []byte(query)} + } + + if len(databaseName) == 0 { return nil, ErrNoDatabaseName } + config.DatabaseName = databaseName + px := &Postgres{ db: instance, config: config, @@ -55,8 +66,6 @@ type Postgres struct { config *Config } -const tableName = "schema_migrations" - func (p *Postgres) Open(url string) (database.Driver, error) { purl, err := nurl.Parse(url) if err != nil { @@ -97,10 +106,12 @@ func (p *Postgres) Lock() error { return err } - // It will either obtain the lock immediately and return true, or return false if the lock cannot be acquired immediately. + // This will either obtain the lock immediately and return true, + // or return false if the lock cannot be acquired immediately. + query := `SELECT pg_try_advisory_lock($1)` var success bool - if err := p.db.QueryRow("SELECT pg_try_advisory_lock($1)", aid).Scan(&success); err != nil { - return err + if err := p.db.QueryRow(query, aid).Scan(&success); err != nil { + return &database.Error{OrigErr: err, Err: "try lock failed", Query: []byte(query)} } if success { @@ -121,104 +132,168 @@ func (p *Postgres) Unlock() error { return err } - if _, err := p.db.Exec("SELECT pg_advisory_unlock($1)", aid); err != nil { - return err + query := `SELECT pg_advisory_unlock($1)` + if _, err := p.db.Exec(query, aid); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} } p.isLocked = false return nil } func (p *Postgres) Run(version int, migration io.Reader) error { + if dirty, err := p.isDirty(); err != nil { + return err + } else if dirty { + return ErrDatabaseDirty + } + if migration == nil { // just apply version - return p.saveVersion(version) + return p.saveVersion(version, false) } - mgr, err := ioutil.ReadAll(migration) + migr, err := ioutil.ReadAll(migration) if err != nil { return err } - // it would be nice to be able to wrap the migration into the transaction, too - // unfortunately things like `CREATE INDEX CONCURRENTLY` aren't possible in a - // transaction. so if something fails between running the migration, and - // storing the latest migration version in the version table, we alert the user - // who then needs to manually fix. - // TODO: two phase commit? - if _, err := p.db.Exec(string(mgr[:])); err != nil { + // set dirty flag and set version + if err := p.saveVersion(version, true); err != nil { return err } - return p.saveVersion(version) + // run migration + query := string(migr[:]) + if _, err := p.db.Exec(query); err != nil { + // TODO: cast to postgress error and get line number + return database.Error{OrigErr: err, Err: "migration failed", Query: migr} + } + + // remove dirty flag + return p.saveVersion(version, false) } -func (p *Postgres) saveVersion(version int) error { +func (p *Postgres) saveVersion(version int, dirty bool) error { tx, err := p.db.Begin() if err != nil { - return err // TODO: warn user + return &database.Error{OrigErr: err, Err: "transaction start failed"} } - if _, err := p.db.Exec("TRUNCATE " + tableName + ""); err != nil { + query := `TRUNCATE "` + MigrationsTable + `"` + if _, err := p.db.Exec(query); err != nil { tx.Rollback() - return err // TODO: warn user + return &database.Error{OrigErr: err, Query: []byte(query)} } if version >= 0 { - if _, err := p.db.Exec("INSERT INTO "+tableName+" (version) VALUES ($1)", version); err != nil { + query = `INSERT INTO "` + MigrationsTable + `" (version, dirty) VALUES ($1, $2)` + if _, err := p.db.Exec(query, version, dirty); err != nil { tx.Rollback() - return err // TODO: warn user + return &database.Error{OrigErr: err, Query: []byte(query)} } } if err := tx.Commit(); err != nil { - return err // TODO: warn user + return &database.Error{OrigErr: err, Err: "transaction commit failed"} } return nil } +func (p *Postgres) isDirty() (bool, error) { + query := `SELECT dirty FROM "` + MigrationsTable + `" LIMIT 1` + var dirty bool + err := p.db.QueryRow(query).Scan(&dirty) + switch { + case err == sql.ErrNoRows: + return false, nil + + case err != nil: + if e, ok := err.(*pq.Error); ok { + if e.Code.Name() == "undefined_table" { + return false, nil + } + } + return false, &database.Error{OrigErr: err, Query: []byte(query)} + + default: + return dirty, nil + } +} + func (p *Postgres) Version() (int, error) { + query := `SELECT version FROM "` + MigrationsTable + `" LIMIT 1` var version uint64 - err := p.db.QueryRow("SELECT version FROM " + tableName + " ORDER BY version DESC LIMIT 1").Scan(&version) + err := p.db.QueryRow(query).Scan(&version) switch { case err == sql.ErrNoRows: return database.NilVersion, nil + case err != nil: if e, ok := err.(*pq.Error); ok { if e.Code.Name() == "undefined_table" { return database.NilVersion, nil } } - return 0, err + return 0, &database.Error{OrigErr: err, Query: []byte(query)} + default: return int(version), nil } } func (p *Postgres) Drop() error { - if _, err := p.db.Exec("DROP SCHEMA public cascade "); err != nil { - return err + // select all tables in current schema + query := `SELECT table_name FROM information_schema.tables WHERE table_schema=(SELECT current_schema())` + tables, err := p.db.Query(query) + if err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} } - if _, err := p.db.Exec("CREATE SCHEMA public"); err != nil { - return err + defer tables.Close() + + // delete one table after another + tableNames := make([]string, 0) + for tables.Next() { + var tableName string + if err := tables.Scan(&tableName); err != nil { + return err + } + if len(tableName) > 0 { + tableNames = append(tableNames, tableName) + } } - if err := p.ensureVersionTable(); err != nil { - return err + + if len(tableNames) > 0 { + // delete one by one ... + for _, t := range tableNames { + query = `DROP TABLE IF EXISTS ` + t + ` CASCADE` + if _, err := p.db.Exec(query); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} + } + } + if err := p.ensureVersionTable(); err != nil { + return err + } } + return nil } func (p *Postgres) ensureVersionTable() error { - r := p.db.QueryRow("SELECT count(*) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema())", tableName) - c := 0 - if err := r.Scan(&c); err != nil { - return err + // check if migration table exists + var count int + query := `SELECT COUNT(1) FROM information_schema.tables WHERE table_name = $1 AND table_schema = (SELECT current_schema()) LIMIT 1` + if err := p.db.QueryRow(query, MigrationsTable).Scan(&count); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} } - if c > 0 { + if count == 1 { return nil } - if _, err := p.db.Exec("CREATE TABLE IF NOT EXISTS " + tableName + " (version bigint not null primary key);"); err != nil { - return err + + // if not, create the empty migration table + query = `CREATE TABLE "` + MigrationsTable + `" (version bigint not null primary key, dirty boolean not null)` + if _, err := p.db.Exec(query); err != nil { + return &database.Error{OrigErr: err, Query: []byte(query)} } return nil }