Provide WithInstance method

Also includes some refactor around package naming, treats own repo as first-class and uses alternative package names for third party imports, Google spanner libraries in this case.
This commit is contained in:
Christian Klotz 2017-06-04 21:53:33 +01:00
parent 748ae8f06a
commit 2742b9c467
1 changed files with 65 additions and 44 deletions

View File

@ -12,10 +12,10 @@ import (
"golang.org/x/net/context"
"cloud.google.com/go/spanner"
database "cloud.google.com/go/spanner/admin/database/apiv1"
sdb "cloud.google.com/go/spanner/admin/database/apiv1"
"github.com/mattes/migrate"
mdb "github.com/mattes/migrate/database"
"github.com/mattes/migrate/database"
"google.golang.org/api/iterator"
adminpb "google.golang.org/genproto/googleapis/spanner/admin/database/v1"
@ -23,7 +23,7 @@ import (
func init() {
db := Spanner{}
mdb.Register("spanner", &db)
database.Register("spanner", &db)
}
// DefaultMigrationsTable is used if no custom table is specified
@ -45,26 +45,56 @@ type Config struct {
// Spanner implements database.Driver for Google Cloud Spanner
type Spanner struct {
adminClient *database.DatabaseAdminClient
dataClient *spanner.Client
db *DB
config *Config
}
type DB struct {
admin *sdb.DatabaseAdminClient
data *spanner.Client
}
// WithInstance implements database.Driver
func WithInstance(instance *DB, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
}
if len(config.DatabaseName) == 0 {
return nil, ErrNoDatabaseName
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
sx := &Spanner{
db: instance,
config: config,
}
if err := sx.ensureVersionTable(); err != nil {
return nil, err
}
return sx, nil
}
// Open implements database.Driver
func (s *Spanner) Open(url string) (mdb.Driver, error) {
func (s *Spanner) Open(url string) (database.Driver, error) {
purl, err := nurl.Parse(url)
if err != nil {
return nil, err
}
dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
ctx := context.Background()
adminClient, err := database.NewDatabaseAdminClient(ctx)
adminClient, err := sdb.NewDatabaseAdminClient(ctx)
if err != nil {
return nil, err
}
dbname := strings.Replace(migrate.FilterCustomQuery(purl).String(), "spanner://", "", 1)
dataClient, err := spanner.NewClient(ctx, dbname)
if err != nil {
log.Fatal(err)
@ -75,26 +105,17 @@ func (s *Spanner) Open(url string) (mdb.Driver, error) {
migrationsTable = DefaultMigrationsTable
}
sx := &Spanner{
adminClient: adminClient,
dataClient: dataClient,
config: &Config{
db := &DB{admin: adminClient, data: dataClient}
return WithInstance(db, &Config{
DatabaseName: dbname,
MigrationsTable: migrationsTable,
},
}
if err := sx.ensureVersionTable(); err != nil {
return nil, err
}
return sx, nil
})
}
// Close implements database.Driver
func (s *Spanner) Close() error {
s.dataClient.Close()
return s.adminClient.Close()
s.db.data.Close()
return s.db.admin.Close()
}
// Lock implements database.Driver but doesn't do anything because Spanner only
@ -119,17 +140,17 @@ func (s *Spanner) Run(migration io.Reader) error {
stmt := string(migr[:])
ctx := context.Background()
op, err := s.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
Database: s.config.DatabaseName,
Statements: []string{stmt},
})
if err != nil {
return &mdb.Error{OrigErr: err, Err: "migration failed", Query: migr}
return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
if err := op.Wait(ctx); err != nil {
return &mdb.Error{OrigErr: err, Err: "migration failed", Query: migr}
return &database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
return nil
@ -139,7 +160,7 @@ func (s *Spanner) Run(migration io.Reader) error {
func (s *Spanner) SetVersion(version int, dirty bool) error {
ctx := context.Background()
_, err := s.dataClient.ReadWriteTransaction(ctx,
_, err := s.db.data.ReadWriteTransaction(ctx,
func(ctx context.Context, txn *spanner.ReadWriteTransaction) error {
m := []*spanner.Mutation{
spanner.Delete(s.config.MigrationsTable, spanner.AllKeys()),
@ -150,7 +171,7 @@ func (s *Spanner) SetVersion(version int, dirty bool) error {
return txn.BufferWrite(m)
})
if err != nil {
return &mdb.Error{OrigErr: err}
return &database.Error{OrigErr: err}
}
return nil
@ -163,21 +184,21 @@ func (s *Spanner) Version() (version int, dirty bool, err error) {
stmt := spanner.Statement{
SQL: `SELECT Version, Dirty FROM ` + s.config.MigrationsTable + ` LIMIT 1`,
}
iter := s.dataClient.Single().Query(ctx, stmt)
iter := s.db.data.Single().Query(ctx, stmt)
defer iter.Stop()
row, err := iter.Next()
switch err {
case iterator.Done:
return mdb.NilVersion, false, nil
return database.NilVersion, false, nil
case nil:
var v int64
if err = row.Columns(&v, &dirty); err != nil {
return 0, false, &mdb.Error{OrigErr: err, Query: []byte(stmt.SQL)}
return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
}
version = int(v)
default:
return 0, false, &mdb.Error{OrigErr: err, Query: []byte(stmt.SQL)}
return 0, false, &database.Error{OrigErr: err, Query: []byte(stmt.SQL)}
}
return version, dirty, nil
@ -191,11 +212,11 @@ func (s *Spanner) Version() (version int, dirty bool, err error) {
// opposite direction. More testing
func (s *Spanner) Drop() error {
ctx := context.Background()
res, err := s.adminClient.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
res, err := s.db.admin.GetDatabaseDdl(ctx, &adminpb.GetDatabaseDdlRequest{
Database: s.config.DatabaseName,
})
if err != nil {
return &mdb.Error{OrigErr: err, Err: "drop failed"}
return &database.Error{OrigErr: err, Err: "drop failed"}
}
if len(res.Statements) == 0 {
return nil
@ -216,18 +237,18 @@ func (s *Spanner) Drop() error {
}
}
op, err := s.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
Database: s.config.DatabaseName,
Statements: stmts,
})
if err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
}
if err := op.Wait(ctx); err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
return &database.Error{OrigErr: err, Query: []byte(strings.Join(stmts, "; "))}
}
if err := p.ensureVersionTable(); err != nil {
if err := s.ensureVersionTable(); err != nil {
return err
}
@ -237,7 +258,7 @@ func (s *Spanner) Drop() error {
func (s *Spanner) ensureVersionTable() error {
ctx := context.Background()
tbl := s.config.MigrationsTable
iter := s.dataClient.Single().Read(ctx, tbl, spanner.AllKeys(), nil)
iter := s.db.data.Single().Read(ctx, tbl, spanner.AllKeys(), nil)
if err := iter.Do(func(r *spanner.Row) error { return nil }); err == nil {
return nil
}
@ -247,16 +268,16 @@ func (s *Spanner) ensureVersionTable() error {
Dirty BOOL NOT NULL
) PRIMARY KEY(Version)`, tbl)
op, err := s.adminClient.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
op, err := s.db.admin.UpdateDatabaseDdl(ctx, &adminpb.UpdateDatabaseDdlRequest{
Database: s.config.DatabaseName,
Statements: []string{stmt},
})
if err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(stmt)}
return &database.Error{OrigErr: err, Query: []byte(stmt)}
}
if err := op.Wait(ctx); err != nil {
return &mdb.Error{OrigErr: err, Query: []byte(stmt)}
return &database.Error{OrigErr: err, Query: []byte(stmt)}
}
return nil