mirror of
https://github.com/status-im/migrate.git
synced 2025-02-23 16:28:08 +00:00
Merge pull request #30 from Decemberlabs/master
Add WithSession helper for Cassandra driver
This commit is contained in:
commit
f815731412
@ -18,12 +18,12 @@ func init() {
|
||||
}
|
||||
|
||||
var DefaultMigrationsTable = "schema_migrations"
|
||||
var dbLocked = false
|
||||
|
||||
var (
|
||||
ErrNilConfig = fmt.Errorf("no config")
|
||||
ErrNoKeyspace = fmt.Errorf("no keyspace provided")
|
||||
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
||||
ErrClosedSession = fmt.Errorf("session is closed")
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
@ -39,7 +39,29 @@ type Cassandra struct {
|
||||
config *Config
|
||||
}
|
||||
|
||||
func (p *Cassandra) Open(url string) (database.Driver, error) {
|
||||
func WithInstance(session *gocql.Session, config *Config) (database.Driver, error) {
|
||||
if config == nil {
|
||||
return nil, ErrNilConfig
|
||||
} else if isClosed := session.Closed(); isClosed {
|
||||
return nil, ErrClosedSession
|
||||
} else if len(config.KeyspaceName) == 0 {
|
||||
return nil, ErrNoKeyspace
|
||||
}
|
||||
|
||||
if len(config.MigrationsTable) == 0 {
|
||||
config.MigrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
c := &Cassandra{
|
||||
session: session,
|
||||
config: config,
|
||||
}
|
||||
if err := c.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Cassandra) Open(url string) (database.Driver, error) {
|
||||
u, err := nurl.Parse(url)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -55,7 +77,7 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
|
||||
migrationsTable = DefaultMigrationsTable
|
||||
}
|
||||
|
||||
p.config = &Config{
|
||||
c.config = &Config{
|
||||
KeyspaceName: u.Path,
|
||||
MigrationsTable: migrationsTable,
|
||||
}
|
||||
@ -100,45 +122,45 @@ func (p *Cassandra) Open(url string) (database.Driver, error) {
|
||||
cluster.Timeout = timeout
|
||||
}
|
||||
|
||||
p.session, err = cluster.CreateSession()
|
||||
c.session, err = cluster.CreateSession()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := p.ensureVersionTable(); err != nil {
|
||||
if err := c.ensureVersionTable(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return p, nil
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Close() error {
|
||||
p.session.Close()
|
||||
func (c *Cassandra) Close() error {
|
||||
c.session.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Lock() error {
|
||||
if dbLocked {
|
||||
func (c *Cassandra) Lock() error {
|
||||
if c.isLocked {
|
||||
return database.ErrLocked
|
||||
}
|
||||
dbLocked = true
|
||||
c.isLocked = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Unlock() error {
|
||||
dbLocked = false
|
||||
func (c *Cassandra) Unlock() error {
|
||||
c.isLocked = false
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) Run(migration io.Reader) error {
|
||||
func (c *Cassandra) Run(migration io.Reader) error {
|
||||
migr, err := ioutil.ReadAll(migration)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// run migration
|
||||
query := string(migr[:])
|
||||
if err := p.session.Query(query).Exec(); err != nil {
|
||||
if err := c.session.Query(query).Exec(); err != nil {
|
||||
// TODO: cast to Cassandra error and get line number
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
@ -146,14 +168,14 @@ func (p *Cassandra) Run(migration io.Reader) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *Cassandra) SetVersion(version int, dirty bool) error {
|
||||
query := `TRUNCATE "` + p.config.MigrationsTable + `"`
|
||||
if err := p.session.Query(query).Exec(); err != nil {
|
||||
func (c *Cassandra) SetVersion(version int, dirty bool) error {
|
||||
query := `TRUNCATE "` + c.config.MigrationsTable + `"`
|
||||
if err := c.session.Query(query).Exec(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
if version >= 0 {
|
||||
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
|
||||
if err := p.session.Query(query, version, dirty).Exec(); err != nil {
|
||||
query = `INSERT INTO "` + c.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
|
||||
if err := c.session.Query(query, version, dirty).Exec(); err != nil {
|
||||
return &database.Error{OrigErr: err, Query: []byte(query)}
|
||||
}
|
||||
}
|
||||
@ -162,9 +184,9 @@ func (p *Cassandra) SetVersion(version int, dirty bool) error {
|
||||
}
|
||||
|
||||
// Return current keyspace version
|
||||
func (p *Cassandra) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
|
||||
err = p.session.Query(query).Scan(&version, &dirty)
|
||||
func (c *Cassandra) Version() (version int, dirty bool, err error) {
|
||||
query := `SELECT version, dirty FROM "` + c.config.MigrationsTable + `" LIMIT 1`
|
||||
err = c.session.Query(query).Scan(&version, &dirty)
|
||||
switch {
|
||||
case err == gocql.ErrNotFound:
|
||||
return database.NilVersion, false, nil
|
||||
@ -180,31 +202,31 @@ func (p *Cassandra) Version() (version int, dirty bool, err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func (p *Cassandra) Drop() error {
|
||||
func (c *Cassandra) Drop() error {
|
||||
// select all tables in current schema
|
||||
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character
|
||||
iter := p.session.Query(query).Iter()
|
||||
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, c.config.KeyspaceName[1:]) // Skip '/' character
|
||||
iter := c.session.Query(query).Iter()
|
||||
var tableName string
|
||||
for iter.Scan(&tableName) {
|
||||
err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
|
||||
err := c.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Re-create the version table
|
||||
if err := p.ensureVersionTable(); err != nil {
|
||||
if err := c.ensureVersionTable(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure version table exists
|
||||
func (p *Cassandra) ensureVersionTable() error {
|
||||
err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec()
|
||||
func (c *Cassandra) ensureVersionTable() error {
|
||||
err := c.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", c.config.MigrationsTable)).Exec()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, _, err = p.Version(); err != nil {
|
||||
if _, _, err = c.Version(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
Loading…
x
Reference in New Issue
Block a user