mirror of https://github.com/status-im/migrate.git
224 lines
5.0 KiB
Go
224 lines
5.0 KiB
Go
package cassandra
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"io/ioutil"
|
|
nurl "net/url"
|
|
"github.com/gocql/gocql"
|
|
"time"
|
|
"github.com/mattes/migrate/database"
|
|
"strconv"
|
|
)
|
|
|
|
func init() {
|
|
db := new(Cassandra)
|
|
database.Register("cassandra", db)
|
|
}
|
|
|
|
var DefaultMigrationsTable = "schema_migrations"
|
|
var dbLocked = false
|
|
|
|
var (
|
|
ErrNilConfig = fmt.Errorf("no config")
|
|
ErrNoKeyspace = fmt.Errorf("no keyspace provided")
|
|
ErrDatabaseDirty = fmt.Errorf("database is dirty")
|
|
)
|
|
|
|
type Config struct {
|
|
MigrationsTable string
|
|
KeyspaceName string
|
|
}
|
|
|
|
type Cassandra struct {
|
|
session *gocql.Session
|
|
isLocked bool
|
|
|
|
// Open and WithInstance need to guarantee that config is never nil
|
|
config *Config
|
|
}
|
|
|
|
func (p *Cassandra) Open(url string) (database.Driver, error) {
|
|
u, err := nurl.Parse(url)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Check for missing mandatory attributes
|
|
if len(u.Path) == 0 {
|
|
return nil, ErrNoKeyspace
|
|
}
|
|
|
|
migrationsTable := u.Query().Get("x-migrations-table")
|
|
if len(migrationsTable) == 0 {
|
|
migrationsTable = DefaultMigrationsTable
|
|
}
|
|
|
|
p.config = &Config{
|
|
KeyspaceName: u.Path,
|
|
MigrationsTable: migrationsTable,
|
|
}
|
|
|
|
cluster := gocql.NewCluster(u.Host)
|
|
cluster.Keyspace = u.Path[1:len(u.Path)]
|
|
cluster.Consistency = gocql.All
|
|
cluster.Timeout = 1 * time.Minute
|
|
|
|
|
|
// Retrieve query string configuration
|
|
if len(u.Query().Get("consistency")) > 0 {
|
|
var consistency gocql.Consistency
|
|
consistency, err = parseConsistency(u.Query().Get("consistency"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
cluster.Consistency = consistency
|
|
}
|
|
if len(u.Query().Get("protocol")) > 0 {
|
|
var protoversion int
|
|
protoversion, err = strconv.Atoi(u.Query().Get("protocol"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cluster.ProtoVersion = protoversion
|
|
}
|
|
if len(u.Query().Get("timeout")) > 0 {
|
|
var timeout time.Duration
|
|
timeout, err = time.ParseDuration(u.Query().Get("timeout"))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cluster.Timeout = timeout
|
|
}
|
|
|
|
p.session, err = cluster.CreateSession()
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := p.ensureVersionTable(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return p, nil
|
|
}
|
|
|
|
func (p *Cassandra) Close() error {
|
|
p.session.Close()
|
|
return nil
|
|
}
|
|
|
|
func (p *Cassandra) Lock() error {
|
|
if (dbLocked) {
|
|
return database.ErrLocked
|
|
}
|
|
dbLocked = true
|
|
return nil
|
|
}
|
|
|
|
func (p *Cassandra) Unlock() error {
|
|
dbLocked = false
|
|
return nil
|
|
}
|
|
|
|
func (p *Cassandra) Run(migration io.Reader) error {
|
|
migr, err := ioutil.ReadAll(migration)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
// run migration
|
|
query := string(migr[:])
|
|
if err := p.session.Query(query).Exec(); err != nil {
|
|
// TODO: cast to Cassandra error and get line number
|
|
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (p *Cassandra) SetVersion(version int, dirty bool) error {
|
|
query := `TRUNCATE "` + p.config.MigrationsTable + `"`
|
|
if err := p.session.Query(query).Exec(); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
if version >= 0 {
|
|
query = `INSERT INTO "` + p.config.MigrationsTable + `" (version, dirty) VALUES (?, ?)`
|
|
if err := p.session.Query(query, version, dirty).Exec(); err != nil {
|
|
return &database.Error{OrigErr: err, Query: []byte(query)}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
|
|
// Return current keyspace version
|
|
func (p *Cassandra) Version() (version int, dirty bool, err error) {
|
|
query := `SELECT version, dirty FROM "` + p.config.MigrationsTable + `" LIMIT 1`
|
|
err = p.session.Query(query).Scan(&version, &dirty)
|
|
switch {
|
|
case err == gocql.ErrNotFound:
|
|
return database.NilVersion, false, nil
|
|
|
|
case err != nil:
|
|
if _, ok := err.(*gocql.Error); ok {
|
|
return database.NilVersion, false, nil
|
|
}
|
|
return 0, false, &database.Error{OrigErr: err, Query: []byte(query)}
|
|
|
|
default:
|
|
return version, dirty, nil
|
|
}
|
|
}
|
|
|
|
func (p *Cassandra) Drop() error {
|
|
// select all tables in current schema
|
|
query := fmt.Sprintf(`SELECT table_name from system_schema.tables WHERE keyspace_name='%s'`, p.config.KeyspaceName[1:]) // Skip '/' character
|
|
iter := p.session.Query(query).Iter()
|
|
var tableName string
|
|
for iter.Scan(&tableName) {
|
|
err := p.session.Query(fmt.Sprintf(`DROP TABLE %s`, tableName)).Exec()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
}
|
|
// Re-create the version table
|
|
if err := p.ensureVersionTable(); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
|
|
// Ensure version table exists
|
|
func (p *Cassandra) ensureVersionTable() error {
|
|
err := p.session.Query(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version bigint, dirty boolean, PRIMARY KEY(version))", p.config.MigrationsTable)).Exec()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
if _, _, err = p.Version(); err != nil {
|
|
return err
|
|
}
|
|
return nil
|
|
}
|
|
|
|
|
|
// ParseConsistency wraps gocql.ParseConsistency
|
|
// to return an error instead of a panicking.
|
|
func parseConsistency(consistencyStr string) (consistency gocql.Consistency, err error) {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
var ok bool
|
|
err, ok = r.(error)
|
|
if !ok {
|
|
err = fmt.Errorf("Failed to parse consistency \"%s\": %v", consistencyStr, r)
|
|
}
|
|
}
|
|
}()
|
|
consistency = gocql.ParseConsistency(consistencyStr)
|
|
|
|
return consistency, nil
|
|
}
|