Minor tweaks to remove duplication

Adds missing connection close for cassandra tests

Revert to default timeout of 600ms
This commit is contained in:
kenjones 2018-05-16 17:43:05 -04:00
parent 55a25c5e0e
commit 1512e41e41
2 changed files with 25 additions and 37 deletions

View File

@ -1,11 +1,13 @@
package cassandra
import (
"errors"
"fmt"
"io"
"io/ioutil"
nurl "net/url"
"strconv"
"strings"
"time"
"github.com/gocql/gocql"
@ -20,10 +22,10 @@ func init() {
var DefaultMigrationsTable = "schema_migrations"
var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoKeyspace = fmt.Errorf("no keyspace provided")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrClosedSession = fmt.Errorf("session is closed")
ErrNilConfig = errors.New("no config")
ErrNoKeyspace = errors.New("no keyspace provided")
ErrDatabaseDirty = errors.New("database is dirty")
ErrClosedSession = errors.New("session is closed")
)
type Config struct {
@ -42,22 +44,27 @@ type Cassandra struct {
func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
} else if isClosed := session.Closed(); isClosed {
return nil, ErrClosedSession
} else if len(config.KeyspaceName) == 0 {
return nil, ErrNoKeyspace
}
if session.Closed() {
return nil, ErrClosedSession
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
c := &Cassandra{
session: session,
config: config,
}
if err := c.ensureVersionTable(); err != nil {
return nil, err
}
return c, nil
}
@ -72,18 +79,8 @@ func (c *Cassandra) Open(url string) (database.Driver, error) {
return nil, ErrNoKeyspace
}
migrationsTable := u.Query().Get("x-migrations-table")
if len(migrationsTable) == 0 {
migrationsTable = DefaultMigrationsTable
}
c.config = &Config{
KeyspaceName: u.Path,
MigrationsTable: migrationsTable,
}
cluster := gocql.NewCluster(u.Host)
cluster.Keyspace = u.Path[1:len(u.Path)]
cluster.Keyspace = strings.TrimPrefix(u.Path, "/")
cluster.Consistency = gocql.All
cluster.Timeout = 1 * time.Minute
@ -122,17 +119,15 @@ func (c *Cassandra) Open(url string) (database.Driver, error) {
cluster.Timeout = timeout
}
c.session, err = cluster.CreateSession()
session, err := cluster.CreateSession()
if err != nil {
return nil, err
}
if err := c.ensureVersionTable(); err != nil {
return nil, err
}
return c, nil
return WithInstance(session, &Config{
KeyspaceName: strings.TrimPrefix(u.Path, "/"),
MigrationsTable: u.Query().Get("x-migrations-table"),
})
}
func (c *Cassandra) Close() error {
@ -204,7 +199,7 @@ func (c *Cassandra) Version() (version int, dirty bool, err error) {
func (c *Cassandra) Drop() error {
// select all tables in current schema
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName[1:]) // Skip '/' character
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName)
iter := c.session.Query(query).Iter()
var tableName string
for iter.Scan(&tableName) {
@ -214,10 +209,7 @@ func (c *Cassandra) Drop() error {
}
}
// Re-create the version table
if err := c.ensureVersionTable(); err != nil {
return err
}
return nil
return c.ensureVersionTable()
}
// Ensure version table exists

View File

@ -4,14 +4,9 @@ import (
"fmt"
"strconv"
"testing"
"time"
)
import (
"github.com/gocql/gocql"
)
import (
dt "github.com/golang-migrate/migrate/database/testing"
mt "github.com/golang-migrate/migrate/testing"
)
@ -31,16 +26,16 @@ func isReady(i mt.Instance) bool {
cluster := gocql.NewCluster(i.Host())
cluster.Port = port
//cluster.ProtoVersion = 4
cluster.Consistency = gocql.All
cluster.Timeout = 10 * time.Second
p, err := cluster.CreateSession()
if err != nil {
return false
}
defer p.Close()
// Create keyspace for tests
p.Query("CREATE KEYSPACE testks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor':1}").Exec()
if err = p.Query("CREATE KEYSPACE testks WITH REPLICATION = {'class': 'SimpleStrategy', 'replication_factor':1}").Exec(); err != nil {
return false
}
return true
}
@ -55,6 +50,7 @@ func Test(t *testing.T) {
if err != nil {
t.Fatalf("%v", err)
}
defer d.Close()
dt.Test(t, d, []byte("SELECT table_name from system_schema.tables"))
})
}