Provide WithInstance method

Also includes some refactor around package naming, treats own repo as first-class and uses alternative package names for third party imports, Google spanner libraries in this case.
This commit is contained in:
Christian Klotz 2017-06-04 21:53:33 +01:00
parent 748ae8f06a
commit 2742b9c467
1 changed files with 65 additions and 44 deletions

View File

@ -12,10 +12,10 @@ import (
"golang.org/x/net/context" "golang.org/x/net/context"
"cloud.google.com/go/spanner" "cloud.google.com/go/spanner"
database "cloud.google.com/go/spanner/admin/database/apiv1" sdb "cloud.google.com/go/spanner/admin/database/apiv1"
"github.com/mattes/migrate" "github.com/mattes/migrate"
mdb "github.com/mattes/migrate/database" "github.com/mattes/migrate/database"
"google.golang.org/api/iterator" "google.golang.org/api/iterator"
adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1" adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
@ -23,7 +23,7 @@ import (
func init() { func init() {
db := Spanner{} db := Spanner{}
mdb.Register("spanner", &db) database.Register("spanner", &db)
} }
// DefaultMigrationsTable is used if no custom table is specified // DefaultMigrationsTable is used if no custom table is specified
@ -45,26 +45,56 @@ type Config struct {
// Spanner implements database.Driver for Google Cloud Spanner // Spanner implements database.Driver for Google Cloud Spanner
type Spanner struct { type Spanner struct {
adminClient *database.DatabaseAdminClient db *DB
dataClient *spanner.Client
config *Config config *Config
} }
type DB struct {
admin *sdb.DatabaseAdminClient
data *spanner.Client
}
// WithInstance implements database.Driver
func WithInstance(instance *DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if len(config.DatabaseName) == 0 {
return nil, ErrNoDatabaseName
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
sx := &Spanner{
db: instance,
config: config,
}
if err := sx.ensureVersionTable(); err != nil {
return nil, err
}
return sx, nil
}
// Open implements database.Driver // Open implements database.Driver
func (s *Spanner) Open(url string) (mdb.Driver, error) { func (s *Spanner) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url) purl, err := nurl.Parse(url)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
ctx := context.Background() ctx := context.Background()
adminClient, err := database.NewDatabaseAdminClient(ctx)
adminClient, err := sdb.NewDatabaseAdminClient(ctx)
if err != nil { if err != nil {
return nil, err return nil, err
} }
dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
dataClient, err := spanner.NewClient(ctx, dbname) dataClient, err := spanner.NewClient(ctx, dbname)
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
@ -75,26 +105,17 @@ func (s *Spanner) Open(url string) (mdb.Driver, error) {
migrationsTable = DefaultMigrationsTable migrationsTable = DefaultMigrationsTable
} }
sx := &Spanner{ db := &DB{admin: adminClient, data: dataClient}
adminClient: adminClient, return WithInstance(db, &Config{
dataClient: dataClient, DatabaseName: dbname,
config: &Config{ MigrationsTable: migrationsTable,
DatabaseName: dbname, })
MigrationsTable: migrationsTable,
},
}
if err := sx.ensureVersionTable(); err != nil {
return nil, err
}
return sx, nil
} }
// Close implements database.Driver // Close implements database.Driver
func (s *Spanner) Close() error { func (s *Spanner) Close() error {
s.dataClient.Close() s.db.data.Close()
return s.adminClient.Close() return s.db.admin.Close()
} }
// Lock implements database.Driver but doesn't do anything because Spanner only // Lock implements database.Driver but doesn't do anything because Spanner only
@ -119,17 +140,17 @@ func (s *Spanner) Run(migration io.Reader) error {
stmt := string(migr[:]) stmt := string(migr[:])
ctx := context.Background() ctx := context.Background()
op, err := s.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
Database: s.config.DatabaseName, Database: s.config.DatabaseName,
Statements: []string{stmt}, Statements: []string{stmt},
}) })
if err != nil { if err != nil {
return &mdb.Error{OrigErr: err, Err: "migration failed", Query: migr} return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
} }
if err := op.Wait(ctx); err != nil { if err := op.Wait(ctx); err != nil {
return &mdb.Error{OrigErr: err, Err: "migration failed", Query: migr} return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
} }
return nil return nil
@ -139,7 +160,7 @@ func (s *Spanner) Run(migration io.Reader) error {
func (s *Spanner) SetVersion(version int, dirty bool) error { func (s *Spanner) SetVersion(version int, dirty bool) error {
ctx := context.Background() ctx := context.Background()
_, err := s.dataClient.ReadWriteTransaction(ctx, _, err := s.db.data.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error { func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
m := []*spanner.Mutation{ m := []*spanner.Mutation{
spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()), spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
@ -150,7 +171,7 @@ func (s *Spanner) SetVersion(version int, dirty bool) error {
return txn.BufferWrite(m) return txn.BufferWrite(m)
}) })
if err != nil { if err != nil {
return &mdb.Error{OrigErr: err} return &database.Error{OrigErr: err}
} }
return nil return nil
@ -163,21 +184,21 @@ func (s *Spanner) Version() (version int, dirty bool, err error) {
stmt := spanner.Statement{ stmt := spanner.Statement{
SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`, SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
} }
iter := s.dataClient.Single().Query(ctx, stmt) iter := s.db.data.Single().Query(ctx, stmt)
defer iter.Stop() defer iter.Stop()
row, err := iter.Next() row, err := iter.Next()
switch err { switch err {
case iterator.Done: case iterator.Done:
return mdb.NilVersion, false, nil return database.NilVersion, false, nil
case nil: case nil:
var v int64 var v int64
if err = row.Columns(&v, &dirty); err != nil { if err = row.Columns(&v, &dirty); err != nil {
return 0, false, &mdb.Error{OrigErr: err, Query: []byte(stmt.SQL)} return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
} }
version = int(v) version = int(v)
default: default:
return 0, false, &mdb.Error{OrigErr: err, Query: []byte(stmt.SQL)} return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
} }
return version, dirty, nil return version, dirty, nil
@ -191,11 +212,11 @@ func (s *Spanner) Version() (version int, dirty bool, err error) {
// opposite direction. More testing // opposite direction. More testing
func (s *Spanner) Drop() error { func (s *Spanner) Drop() error {
ctx := context.Background() ctx := context.Background()
res, err := s.adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{ res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
Database: s.config.DatabaseName, Database: s.config.DatabaseName,
}) })
if err != nil { if err != nil {
return &mdb.Error{OrigErr: err, Err: "drop failed"} return &database.Error{OrigErr: err, Err: "drop failed"}
} }
if len(res.Statements) == 0 { if len(res.Statements) == 0 {
return nil return nil
@ -216,18 +237,18 @@ func (s *Spanner) Drop() error {
} }
} }
op, err := s.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
Database: s.config.DatabaseName, Database: s.config.DatabaseName,
Statements: stmts, Statements: stmts,
}) })
if err != nil { if err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
} }
if err := op.Wait(ctx); err != nil { if err := op.Wait(ctx); err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))} return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
} }
if err := p.ensureVersionTable(); err != nil { if err := s.ensureVersionTable(); err != nil {
return err return err
} }
@ -237,7 +258,7 @@ func (s *Spanner) Drop() error {
func (s *Spanner) ensureVersionTable() error { func (s *Spanner) ensureVersionTable() error {
ctx := context.Background() ctx := context.Background()
tbl := s.config.MigrationsTable tbl := s.config.MigrationsTable
iter := s.dataClient.Single().Read(ctx, tbl, spanner.AllKeys(), nil) iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), nil)
if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil { if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
return nil return nil
} }
@ -247,16 +268,16 @@ func (s *Spanner) ensureVersionTable() error {
Dirty BOOL NOT NULL Dirty BOOL NOT NULL
) PRIMARY KEY(Version)`, tbl) ) PRIMARY KEY(Version)`, tbl)
op, err := s.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{ op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
Database: s.config.DatabaseName, Database: s.config.DatabaseName,
Statements: []string{stmt}, Statements: []string{stmt},
}) })
if err != nil { if err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(stmt)} return &database.Error{OrigErr: err, Query: []byte(stmt)}
} }
if err := op.Wait(ctx); err != nil { if err := op.Wait(ctx); err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(stmt)} return &database.Error{OrigErr: err, Query: []byte(stmt)}
} }
return nil return nil