mirror of https://github.com/status-im/migrate.git
Add WithSession helper for Cassandra driver
Also changes method receivers names for the Cassandra struct to "c" since "p" makes no sense in this context.
This commit is contained in:
parent
22f249514d
commit
78c47074a3
|
@ -24,6 +24,7 @@ var (
|
|||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoKeyspace = fmt.Errorf("no keyspace provided")
|
||||
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
||||
ErrClosedSession = fmt.Errorf("session is closed")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
|
@ -39,7 +40,29 @@ type Cassandra struct {
|
|||
config *Config
|
||||
}
|
||||
|
||||
func (p *Cassandra) Open(url string) (database.Driver, error) {
|
||||
func WithSession(session *gocql.Session, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
} else if isClosed := session.Closed(); isClosed {
|
||||
return nil, ErrClosedSession
|
||||
} else if len(config.KeyspaceName) == 0 {
|
||||
return nil, ErrNoKeyspace
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
c := &Cassandra{
|
||||
session: session,
|
||||
config: config,
|
||||
}
|
||||
if err := c.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) Open(url string) (database.Driver, error) {
|
||||
u, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
@ -55,7 +78,7 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
|
|||
migrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
p.config = &Config{
|
||||
c.config = &Config{
|
||||
KeyspaceName: u.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
}
|
||||
|
@ -100,25 +123,25 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
|
|||
cluster.Timeout = timeout
|
||||
}
|
||||
|
||||
p.session, err = cluster.CreateSession()
|
||||
c.session, err = cluster.CreateSession()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.ensureVersionTable(); err != nil {
|
||||
if err := c.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Close() error {
|
||||
p.session.Close()
|
||||
func (c *Cassandra) Close() error {
|
||||
c.session.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Lock() error {
|
||||
func (c *Cassandra) Lock() error {
|
||||
if dbLocked {
|
||||
return database.ErrLocked
|
||||
}
|
||||
|
@ -126,19 +149,19 @@ func (p *Cassandra) Lock() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Unlock() error {
|
||||
func (c *Cassandra) Unlock() error {
|
||||
dbLocked = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Run(migration io.Reader) error {
|
||||
func (c *Cassandra) Run(migration io.Reader) error {
|
||||
migr, err := ioutil.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// run migration
|
||||
query := string(migr[:])
|
||||
if err := p.session.Query(query).Exec(); err != nil {
|
||||
if err := c.session.Query(query).Exec(); err != nil {
|
||||
// TODO: cast to Cassandra error and get line number
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
|
@ -146,14 +169,14 @@ func (p *Cassandra) Run(migration io.Reader) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) SetVersion(version int, dirty bool) error {
|
||||
query := `TRUNCATE "` + p.config.MigrationsTable + `"`
|
||||
if err := p.session.Query(query).Exec(); err != nil {
|
||||
func (c *Cassandra) SetVersion(version int, dirty bool) error {
|
||||
query := `TRUNCATE "` + c.config.MigrationsTable + `"`
|
||||
if err := c.session.Query(query).Exec(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
if version >= 0 {
|
||||
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
|
||||
if err := p.session.Query(query, version, dirty).Exec(); err != nil {
|
||||
query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
|
||||
if err := c.session.Query(query, version, dirty).Exec(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
|
@ -162,9 +185,9 @@ func (p *Cassandra) SetVersion(version int, dirty bool) error {
|
|||
}
|
||||
|
||||
// Return current keyspace version
|
||||
func (p *Cassandra) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
|
||||
err = p.session.Query(query).Scan(&version, &dirty)
|
||||
func (c *Cassandra) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
|
||||
err = c.session.Query(query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == gocql.ErrNotFound:
|
||||
return database.NilVersion, false, nil
|
||||
|
@ -180,31 +203,31 @@ func (p *Cassandra) Version() (version int, dirty bool, err error) {
|
|||
}
|
||||
}
|
||||
|
||||
func (p *Cassandra) Drop() error {
|
||||
func (c *Cassandra) Drop() error {
|
||||
// select all tables in current schema
|
||||
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character
|
||||
iter := p.session.Query(query).Iter()
|
||||
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName[1:]) // Skip '/' character
|
||||
iter := c.session.Query(query).Iter()
|
||||
var tableName string
|
||||
for iter.Scan(&tableName) {
|
||||
err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
|
||||
err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Re-create the version table
|
||||
if err := p.ensureVersionTable(); err != nil {
|
||||
if err := c.ensureVersionTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure version table exists
|
||||
func (p *Cassandra) ensureVersionTable() error {
|
||||
err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec()
|
||||
func (c *Cassandra) ensureVersionTable() error {
|
||||
err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, _, err = p.Version(); err != nil {
|
||||
if _, _, err = c.Version(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
|
Loading…
Reference in New Issue