Add WithSession helper for Cassandra driver

Also changes method receivers names for the Cassandra struct to "c"
since "p" makes no sense in this context.
This commit is contained in:
Andrés Rodríguez 2018-03-13 15:49:08 -03:00
parent 22f249514d
commit 78c47074a3
1 changed files with 50 additions and 27 deletions

View File

@ -24,6 +24,7 @@ var (
ErrNilConfig = fmt.Errorf("no config")
ErrNoKeyspace = fmt.Errorf("no keyspace provided")
ErrDatabaseDirty = fmt.Errorf("database is dirty")
ErrClosedSession = fmt.Errorf("session is closed")
)
type Config struct {
@ -39,7 +40,29 @@ type Cassandra struct {
config *Config
}
func (p *Cassandra) Open(url string) (database.Driver, error) {
func WithSession(session *gocql.Session, config *Config) (database.Driver, error) {
if config == nil {
return nil, ErrNilConfig
} else if isClosed := session.Closed(); isClosed {
return nil, ErrClosedSession
} else if len(config.KeyspaceName) == 0 {
return nil, ErrNoKeyspace
}
if len(config.MigrationsTable) == 0 {
config.MigrationsTable = DefaultMigrationsTable
}
c := &Cassandra{
session: session,
config: config,
}
if err := c.ensureVersionTable(); err != nil {
return nil, err
}
return c, nil
}
func (c *Cassandra) Open(url string) (database.Driver, error) {
u, err := nurl.Parse(url)
if err != nil {
return nil, err
@ -55,7 +78,7 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
migrationsTable = DefaultMigrationsTable
}
p.config = &Config{
c.config = &Config{
KeyspaceName: u.Path,
MigrationsTable: migrationsTable,
}
@ -100,25 +123,25 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
cluster.Timeout = timeout
}
p.session, err = cluster.CreateSession()
c.session, err = cluster.CreateSession()
if err != nil {
return nil, err
}
if err := p.ensureVersionTable(); err != nil {
if err := c.ensureVersionTable(); err != nil {
return nil, err
}
return p, nil
return c, nil
}
func (p *Cassandra) Close() error {
p.session.Close()
func (c *Cassandra) Close() error {
c.session.Close()
return nil
}
func (p *Cassandra) Lock() error {
func (c *Cassandra) Lock() error {
if dbLocked {
return database.ErrLocked
}
@ -126,19 +149,19 @@ func (p *Cassandra) Lock() error {
return nil
}
func (p *Cassandra) Unlock() error {
func (c *Cassandra) Unlock() error {
dbLocked = false
return nil
}
func (p *Cassandra) Run(migration io.Reader) error {
func (c *Cassandra) Run(migration io.Reader) error {
migr, err := ioutil.ReadAll(migration)
if err != nil {
return err
}
// run migration
query := string(migr[:])
if err := p.session.Query(query).Exec(); err != nil {
if err := c.session.Query(query).Exec(); err != nil {
// TODO: cast to Cassandra error and get line number
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
}
@ -146,14 +169,14 @@ func (p *Cassandra) Run(migration io.Reader) error {
return nil
}
func (p *Cassandra) SetVersion(version int, dirty bool) error {
query := `TRUNCATE "` + p.config.MigrationsTable + `"`
if err := p.session.Query(query).Exec(); err != nil {
func (c *Cassandra) SetVersion(version int, dirty bool) error {
query := `TRUNCATE "` + c.config.MigrationsTable + `"`
if err := c.session.Query(query).Exec(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
if version >= 0 {
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
if err := p.session.Query(query, version, dirty).Exec(); err != nil {
query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
if err := c.session.Query(query, version, dirty).Exec(); err != nil {
return &database.Error{OrigErr: err, Query: []byte(query)}
}
}
@ -162,9 +185,9 @@ func (p *Cassandra) SetVersion(version int, dirty bool) error {
}
// Return current keyspace version
func (p *Cassandra) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
err = p.session.Query(query).Scan(&version, &dirty)
func (c *Cassandra) Version() (version int, dirty bool, err error) {
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
err = c.session.Query(query).Scan(&version, &dirty)
switch {
case err == gocql.ErrNotFound:
return database.NilVersion, false, nil
@ -180,31 +203,31 @@ func (p *Cassandra) Version() (version int, dirty bool, err error) {
}
}
func (p *Cassandra) Drop() error {
func (c *Cassandra) Drop() error {
// select all tables in current schema
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character
iter := p.session.Query(query).Iter()
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName[1:]) // Skip '/' character
iter := c.session.Query(query).Iter()
var tableName string
for iter.Scan(&tableName) {
err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
if err != nil {
return err
}
}
// Re-create the version table
if err := p.ensureVersionTable(); err != nil {
if err := c.ensureVersionTable(); err != nil {
return err
}
return nil
}
// Ensure version table exists
func (p *Cassandra) ensureVersionTable() error {
err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec()
func (c *Cassandra) ensureVersionTable() error {
err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
if err != nil {
return err
}
if _, _, err = p.Version(); err != nil {
if _, _, err = c.Version(); err != nil {
return err
}
return nil