diff --git a/database/cassandra/cassandra.go b/database/cassandra/cassandra.go index 68d2954..6bb4f3a 100644 --- a/database/cassandra/cassandra.go +++ b/database/cassandra/cassandra.go @@ -24,6 +24,7 @@ var ( ErrNilConfig = fmt.Errorf("no config") ErrNoKeyspace = fmt.Errorf("no keyspace provided") ErrDatabaseDirty = fmt.Errorf("database is dirty") + ErrClosedSession = fmt.Errorf("session is closed") ) type Config struct { @@ -39,7 +40,29 @@ type Cassandra struct { config *Config } -func (p *Cassandra) Open(url string) (database.Driver, error) { +func WithSession(session *gocql.Session, config *Config) (database.Driver, error) { + if config == nil { + return nil, ErrNilConfig + } else if isClosed := session.Closed(); isClosed { + return nil, ErrClosedSession + } else if len(config.KeyspaceName) == 0 { + return nil, ErrNoKeyspace + } + + if len(config.MigrationsTable) == 0 { + config.MigrationsTable = DefaultMigrationsTable + } + c := &Cassandra{ + session: session, + config: config, + } + if err := c.ensureVersionTable(); err != nil { + return nil, err + } + return c, nil +} + +func (c *Cassandra) Open(url string) (database.Driver, error) { u, err := nurl.Parse(url) if err != nil { return nil, err @@ -55,7 +78,7 @@ func (p *Cassandra) Open(url string) (database.Driver, error) { migrationsTable = DefaultMigrationsTable } - p.config = &Config{ + c.config = &Config{ KeyspaceName: u.Path, MigrationsTable: migrationsTable, } @@ -100,25 +123,25 @@ func (p *Cassandra) Open(url string) (database.Driver, error) { cluster.Timeout = timeout } - p.session, err = cluster.CreateSession() + c.session, err = cluster.CreateSession() if err != nil { return nil, err } - if err := p.ensureVersionTable(); err != nil { + if err := c.ensureVersionTable(); err != nil { return nil, err } - return p, nil + return c, nil } -func (p *Cassandra) Close() error { - p.session.Close() +func (c *Cassandra) Close() error { + c.session.Close() return nil } -func (p *Cassandra) Lock() error { +func (c *Cassandra) Lock() error { if dbLocked { return database.ErrLocked } @@ -126,19 +149,19 @@ func (p *Cassandra) Lock() error { return nil } -func (p *Cassandra) Unlock() error { +func (c *Cassandra) Unlock() error { dbLocked = false return nil } -func (p *Cassandra) Run(migration io.Reader) error { +func (c *Cassandra) Run(migration io.Reader) error { migr, err := ioutil.ReadAll(migration) if err != nil { return err } // run migration query := string(migr[:]) - if err := p.session.Query(query).Exec(); err != nil { + if err := c.session.Query(query).Exec(); err != nil { // TODO: cast to Cassandra error and get line number return database.Error{OrigErr: err, Err: "migration failed", Query: migr} } @@ -146,14 +169,14 @@ func (p *Cassandra) Run(migration io.Reader) error { return nil } -func (p *Cassandra) SetVersion(version int, dirty bool) error { - query := `TRUNCATE "` + p.config.MigrationsTable + `"` - if err := p.session.Query(query).Exec(); err != nil { +func (c *Cassandra) SetVersion(version int, dirty bool) error { + query := `TRUNCATE "` + c.config.MigrationsTable + `"` + if err := c.session.Query(query).Exec(); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } if version >= 0 { - query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` - if err := p.session.Query(query, version, dirty).Exec(); err != nil { + query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)` + if err := c.session.Query(query, version, dirty).Exec(); err != nil { return &database.Error{OrigErr: err, Query: []byte(query)} } } @@ -162,9 +185,9 @@ func (p *Cassandra) SetVersion(version int, dirty bool) error { } // Return current keyspace version -func (p *Cassandra) Version() (version int, dirty bool, err error) { - query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1` - err = p.session.Query(query).Scan(&version, &dirty) +func (c *Cassandra) Version() (version int, dirty bool, err error) { + query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1` + err = c.session.Query(query).Scan(&version, &dirty) switch { case err == gocql.ErrNotFound: return database.NilVersion, false, nil @@ -180,31 +203,31 @@ func (p *Cassandra) Version() (version int, dirty bool, err error) { } } -func (p *Cassandra) Drop() error { +func (c *Cassandra) Drop() error { // select all tables in current schema - query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character - iter := p.session.Query(query).Iter() + query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName[1:]) // Skip '/' character + iter := c.session.Query(query).Iter() var tableName string for iter.Scan(&tableName) { - err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() + err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec() if err != nil { return err } } // Re-create the version table - if err := p.ensureVersionTable(); err != nil { + if err := c.ensureVersionTable(); err != nil { return err } return nil } // Ensure version table exists -func (p *Cassandra) ensureVersionTable() error { - err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec() +func (c *Cassandra) ensureVersionTable() error { + err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec() if err != nil { return err } - if _, _, err = p.Version(); err != nil { + if _, _, err = c.Version(); err != nil { return err } return nil