2014-08-27 01:46:41 +00:00
|
|
|
// Package cassandra implements the Driver interface.
|
|
|
|
package cassandra
|
|
|
|
|
|
|
|
import (
	"fmt"
	"net/url"
	"time"

	"github.com/gocql/gocql"
	"github.com/mattes/migrate/file"
	"github.com/mattes/migrate/migrate/direction"
)
|
|
|
|
|
|
|
|
type Driver struct {
|
|
|
|
session *gocql.Session
|
|
|
|
}
|
|
|
|
|
2014-11-17 10:42:19 +00:00
|
|
|
// Name of the table holding the schema version, and the fixed primary-key
// value of the single row inside it that stores the version counter.
const (
	tableName  = "schema_migrations"
	versionRow = 1
)

// counterStmt picks which CQL update moves the version counter:
// true means increment by one, false means decrement by one.
type counterStmt bool

// String renders the CQL UPDATE for this direction. The statement has a
// single "?" placeholder that callers bind to versionRow.
func (c counterStmt) String() string {
	op := "-"
	if c {
		op = "+"
	}
	return "UPDATE " + tableName + " SET version = version " + op + " 1 where versionRow = ?"
}

// Statements that move the stored schema version one step up or down.
const (
	up   counterStmt = true
	down counterStmt = false
)
|
2014-08-27 01:46:41 +00:00
|
|
|
|
|
|
|
// Cassandra Driver URL format:
|
|
|
|
// cassandra://host:port/keyspace
|
|
|
|
//
|
|
|
|
// Example:
|
|
|
|
// cassandra://localhost/SpaceOfKeys
|
|
|
|
func (driver *Driver) Initialize(rawurl string) error {
|
|
|
|
u, err := url.Parse(rawurl)
|
|
|
|
|
|
|
|
cluster := gocql.NewCluster(u.Host)
|
|
|
|
cluster.Keyspace = u.Path[1:len(u.Path)]
|
|
|
|
cluster.Consistency = gocql.All
|
|
|
|
cluster.Timeout = 1 * time.Minute
|
|
|
|
|
|
|
|
driver.session, err = cluster.CreateSession()
|
|
|
|
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := driver.ensureVersionTableExists(); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (driver *Driver) Close() error {
|
|
|
|
driver.session.Close()
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (driver *Driver) ensureVersionTableExists() error {
|
2014-08-27 02:43:05 +00:00
|
|
|
err := driver.session.Query("CREATE TABLE IF NOT EXISTS " + tableName + " (version counter, versionRow bigint primary key);").Exec()
|
2014-08-27 03:19:13 +00:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2014-08-27 03:24:01 +00:00
|
|
|
_, err = driver.Version()
|
|
|
|
if err != nil {
|
2014-11-17 10:42:19 +00:00
|
|
|
driver.session.Query(up.String(), versionRow).Exec()
|
2014-08-27 03:24:01 +00:00
|
|
|
}
|
2014-08-27 03:19:13 +00:00
|
|
|
|
2014-08-27 01:46:41 +00:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func (driver *Driver) FilenameExtension() string {
|
|
|
|
return "cql"
|
|
|
|
}
|
|
|
|
|
2014-11-17 10:42:19 +00:00
|
|
|
func (driver *Driver) version(d direction.Direction, invert bool) error {
|
|
|
|
var stmt counterStmt
|
|
|
|
switch d {
|
|
|
|
case direction.Up:
|
|
|
|
stmt = up
|
|
|
|
case direction.Down:
|
|
|
|
stmt = down
|
|
|
|
}
|
|
|
|
if invert {
|
|
|
|
stmt = !stmt
|
|
|
|
}
|
|
|
|
return driver.session.Query(stmt.String(), versionRow).Exec()
|
|
|
|
}
|
2014-08-27 01:46:41 +00:00
|
|
|
|
2014-11-17 10:42:19 +00:00
|
|
|
func (driver *Driver) Migrate(f file.File, pipe chan interface{}) {
|
|
|
|
var err error
|
|
|
|
defer func() {
|
2014-08-27 01:46:41 +00:00
|
|
|
if err != nil {
|
2014-11-17 10:42:19 +00:00
|
|
|
// Invert version direction if we couldn't apply the changes for some reason.
|
|
|
|
if err := driver.version(f.Direction, true); err != nil {
|
|
|
|
pipe <- err
|
|
|
|
}
|
2014-08-27 01:46:41 +00:00
|
|
|
pipe <- err
|
|
|
|
}
|
2014-11-17 10:42:19 +00:00
|
|
|
close(pipe)
|
|
|
|
}()
|
2014-08-27 01:46:41 +00:00
|
|
|
|
2014-11-17 10:42:19 +00:00
|
|
|
pipe <- f
|
|
|
|
if err = driver.version(f.Direction, false); err != nil {
|
2014-08-27 01:46:41 +00:00
|
|
|
return
|
|
|
|
}
|
|
|
|
|
2014-11-17 10:42:19 +00:00
|
|
|
if err = f.ReadContent(); err != nil {
|
2014-08-27 01:46:41 +00:00
|
|
|
return
|
|
|
|
}
|
2014-11-17 10:42:19 +00:00
|
|
|
|
|
|
|
err = driver.session.Query(string(f.Content)).Exec()
|
2014-08-27 01:46:41 +00:00
|
|
|
}
|
|
|
|
|
|
|
|
func (driver *Driver) Version() (uint64, error) {
|
|
|
|
var version int64
|
2014-08-27 02:43:05 +00:00
|
|
|
err := driver.session.Query("SELECT version FROM "+tableName+" WHERE versionRow = ?", versionRow).Scan(&version)
|
2014-08-27 03:19:13 +00:00
|
|
|
return uint64(version) - 1, err
|
2014-08-27 01:46:41 +00:00
|
|
|
}
|