From 1512e41e415721abad8607ede4ce491a55f59038 Mon Sep 17 00:00:00 2001 From: kenjones Date: Wed, 16 May 2018 17:43:05 -0400 Subject: [PATCH] Minor tweaks to remove duplication Adds missing connection close for cassandra tests Revert to default timeout of 600ms --- database/cassandra/cassandra.go | 50 ++++++++++++---------------- database/cassandra/cassandra_test.go | 12 +++---- 2 files changed, 25 insertions(+), 37 deletions(-) diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 16e479e..87de62c 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -1,11 +1,13 @@ package cassandra import ( + "errors" "fmt" "io" "io/ioutil" nurl "net/url" "strconv" + "strings" "time" "github.com/gocql/gocql" @@ -20,10 +22,10 @@ func init() { var DefaultMigrationsTable = "schema_migrations" var ( - ErrNilConfig = fmt.Errorf("no config") - ErrNoKeyspace = fmt.Errorf("no keyspace provided") - ErrDatabaseDirty = fmt.Errorf("database is dirty") - ErrClosedSession = fmt.Errorf("session is closed") + ErrNilConfig = errors.New("no config") + ErrNoKeyspace = errors.New("no keyspace provided") + ErrDatabaseDirty = errors.New("database is dirty") + ErrClosedSession = errors.New("session is closed") ) type Config struct { @@ -42,22 +44,27 @@ type Cassandra struct { func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) { if config == nil { return nil, ErrNilConfig - } else if isClosed := session.Closed(); isClosed { - return nil, ErrClosedSession } else if len(config.KeyspaceName) == 0 { return nil, ErrNoKeyspace } + if session.Closed() { + return nil, ErrClosedSession + } + if len(config.MigrationsTable) == 0 { config.MigrationsTable = DefaultMigrationsTable } + c := &Cassandra{ session: session, config: config, } + if err := c.ensureVersionTable(); err != nil { return nil, err } + return c, nil } @@ -72,18 +79,8 @@ func (c *Cassandra) Open(url string) (database.Driver, error) { return nil, ErrNoKeyspace } - migrationsTable := u.Query().Get("x-migrations-table") - if len(migrationsTable) == 0 { - migrationsTable = DefaultMigrationsTable - } - - c.config = &Config{ - KeyspaceName: u.Path, - MigrationsTable: migrationsTable, - } - cluster := gocql.NewCluster(u.Host) - cluster.Keyspace = u.Path[1:len(u.Path)] + cluster.Keyspace = strings.TrimPrefix(u.Path, "/") cluster.Consistency = gocql.All cluster.Timeout = 1 * time.Minute @@ -122,17 +119,15 @@ func (c *Cassandra) Open(url string) (database.Driver, error) { cluster.Timeout = timeout } - c.session, err = cluster.CreateSession() - + session, err := cluster.CreateSession() if err != nil { return nil, err } - if err := c.ensureVersionTable(); err != nil { - return nil, err - } - - return c, nil + return WithInstance(session, &Config{ + KeyspaceName: strings.TrimPrefix(u.Path, "/"), + MigrationsTable: u.Query().Get("x-migrations-table"), + }) } func (c *Cassandra) Close() error { @@ -204,7 +199,7 @@ func (c *Cassandra) Version() (version int, dirty bool, err error) { func (c *Cassandra) Drop() error { // select all tables in current schema - query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName[1:]) // Skip '/' character + query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName) iter := c.session.Query(query).Iter() var tableName string for iter.Scan(&tableName) { @@ -214,10 +209,7 @@ func (c *Cassandra) Drop() error { } } // Re-create the version table - if err := c.ensureVersionTable(); err != nil { - return err - } - return nil + return c.ensureVersionTable() } // Ensure version table exists diff --git a/database/cassandra/cassandra_test.go b/database/cassandra/cassandra_test.go index f3e5f58..47556fb 100644 --- a/database/cassandra/cassandra_test.go +++ b/database/cassandra/cassandra_test.go @@ -4,14 +4,9 @@ import ( "fmt" "strconv" "testing" - "time" -) -import ( "github.com/gocql/gocql" -) -import ( dt "github.com/golang-migrate/migrate/database/testing" mt "github.com/golang-migrate/migrate/testing" ) @@ -31,16 +26,16 @@ func isReady(i mt.Instance) bool { cluster := gocql.NewCluster(i.Host()) cluster.Port = port - //cluster.ProtoVersion = 4 cluster.Consistency = gocql.All - cluster.Timeout = 10 * time.Second p, err := cluster.CreateSession() if err != nil { return false } defer p.Close() // Create keyspace for tests - p.Query("CREATE KEYSPACE testks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor':1}").Exec() + if err = p.Query("CREATE KEYSPACE testks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor':1}").Exec(); err != nil { + return false + } return true } @@ -55,6 +50,7 @@ func Test(t *testing.T) { if err != nil { t.Fatalf("%v", err) } + defer d.Close() dt.Test(t, d, []byte("SELECT table_name from system_schema.tables")) }) }