mirror of https://github.com/status-im/migrate.git
Add postgres lib/pq error parsing
This commit is contained in:
parent
157433893c
commit
3eb26a65d3
|
@ -3,13 +3,15 @@
|
|||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
nurl "net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"context"
|
||||
"github.com/golang-migrate/migrate"
|
||||
"github.com/golang-migrate/migrate/database"
|
||||
"github.com/lib/pq"
|
||||
|
@ -167,13 +169,53 @@ func (p *Postgres) Run(migration io.Reader) error {
|
|||
// run migration
|
||||
query := string(migr[:])
|
||||
if _, err := p.conn.ExecContext(context.Background(), query); err != nil {
|
||||
// TODO: cast to postgress error and get line number
|
||||
if pgErr, ok := err.(*pq.Error); ok {
|
||||
var line uint
|
||||
var col uint
|
||||
var lineColOK bool
|
||||
if pgErr.Position != "" {
|
||||
if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil {
|
||||
if line, col, ok = computeLineFromPos(query, uint(pos)); ok {
|
||||
lineColOK = true
|
||||
}
|
||||
}
|
||||
}
|
||||
message := fmt.Sprintf("migration failed: %s", pgErr.Message)
|
||||
if lineColOK {
|
||||
message = fmt.Sprintf("%s (column %d)", message, col)
|
||||
}
|
||||
if pgErr.Detail != "" {
|
||||
message = fmt.Sprintf("%s, %s", message, pgErr.Detail)
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: message, Query: migr, Line: line}
|
||||
}
|
||||
return database.Error{OrigErr: err, Err: "migration failed", Query: migr}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func computeLineFromPos(s string, pos uint) (uint, uint, bool) {
|
||||
newLine := "\n"
|
||||
if i := strings.Index(s, "\r\n"); i >= 0 {
|
||||
newLine = "\r\n"
|
||||
}
|
||||
lines := strings.Split(s, newLine)
|
||||
remaining := int(pos)
|
||||
lineNr := 1
|
||||
var curr int
|
||||
for _, line := range lines {
|
||||
lineLength := len(line)
|
||||
curr += lineLength + 1
|
||||
if remaining < lineLength {
|
||||
return uint(lineNr), uint(remaining), true
|
||||
}
|
||||
remaining -= lineLength + 1
|
||||
lineNr++
|
||||
}
|
||||
return 0, 0, false
|
||||
}
|
||||
|
||||
func (p *Postgres) SetVersion(version int, dirty bool) error {
|
||||
tx, err := p.conn.BeginTx(context.Background(), &sql.TxOptions{})
|
||||
if err != nil {
|
||||
|
|
|
@ -4,16 +4,16 @@ package postgres
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"database/sql"
|
||||
sqldriver "database/sql/driver"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"context"
|
||||
dt "github.com/golang-migrate/migrate/database/testing"
|
||||
mt "github.com/golang-migrate/migrate/testing"
|
||||
// "github.com/lib/pq"
|
||||
)
|
||||
|
||||
var versions = []mt.Version{
|
||||
|
@ -24,8 +24,12 @@ var versions = []mt.Version{
|
|||
{Image: "postgres:9.2"},
|
||||
}
|
||||
|
||||
func pgConnectionString(host string, port uint) string {
|
||||
return fmt.Sprintf("postgres://postgres@%s:%v/postgres?sslmode=disable", host, port)
|
||||
}
|
||||
|
||||
func isReady(i mt.Instance) bool {
|
||||
db, err := sql.Open("postgres", fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable", i.Host(), i.Port()))
|
||||
db, err := sql.Open("postgres", pgConnectionString(i.Host(), i.Port()))
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
|
@ -47,7 +51,7 @@ func Test(t *testing.T) {
|
|||
mt.ParallelTest(t, versions, isReady,
|
||||
func(t *testing.T, i mt.Instance) {
|
||||
p := &Postgres{}
|
||||
addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable", i.Host(), i.Port())
|
||||
addr := pgConnectionString(i.Host(), i.Port())
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
|
@ -61,7 +65,7 @@ func TestMultiStatement(t *testing.T) {
|
|||
mt.ParallelTest(t, versions, isReady,
|
||||
func(t *testing.T, i mt.Instance) {
|
||||
p := &Postgres{}
|
||||
addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable", i.Host(), i.Port())
|
||||
addr := pgConnectionString(i.Host(), i.Port())
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
|
@ -82,6 +86,27 @@ func TestMultiStatement(t *testing.T) {
|
|||
})
|
||||
}
|
||||
|
||||
func TestErrorParsing(t *testing.T) {
|
||||
mt.ParallelTest(t, versions, isReady,
|
||||
func(t *testing.T, i mt.Instance) {
|
||||
p := &Postgres{}
|
||||
addr := pgConnectionString(i.Host(), i.Port())
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
}
|
||||
defer d.Close()
|
||||
|
||||
wantErr := `migration failed: syntax error at or near "TABLEE" (column 37) in line 1: CREATE TABLE foo ` +
|
||||
`(foo text); CREATE TABLEE bar (bar text); (details: pq: syntax error at or near "TABLEE")`
|
||||
if err := d.Run(bytes.NewReader([]byte("CREATE TABLE foo (foo text); CREATE TABLEE bar (bar text);"))); err == nil {
|
||||
t.Fatal("expected err but got nil")
|
||||
} else if err.Error() != wantErr {
|
||||
t.Fatalf("expected '%s' but got '%s'", wantErr, err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestFilterCustomQuery(t *testing.T) {
|
||||
mt.ParallelTest(t, versions, isReady,
|
||||
func(t *testing.T, i mt.Instance) {
|
||||
|
@ -99,7 +124,7 @@ func TestWithSchema(t *testing.T) {
|
|||
mt.ParallelTest(t, versions, isReady,
|
||||
func(t *testing.T, i mt.Instance) {
|
||||
p := &Postgres{}
|
||||
addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable", i.Host(), i.Port())
|
||||
addr := pgConnectionString(i.Host(), i.Port())
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
|
@ -160,7 +185,7 @@ func TestPostgres_Lock(t *testing.T) {
|
|||
mt.ParallelTest(t, versions, isReady,
|
||||
func(t *testing.T, i mt.Instance) {
|
||||
p := &Postgres{}
|
||||
addr := fmt.Sprintf("postgres://postgres@%v:%v/postgres?sslmode=disable", i.Host(), i.Port())
|
||||
addr := pgConnectionString(i.Host(), i.Port())
|
||||
d, err := p.Open(addr)
|
||||
if err != nil {
|
||||
t.Fatalf("%v", err)
|
||||
|
@ -191,3 +216,53 @@ func TestPostgres_Lock(t *testing.T) {
|
|||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_computeLineFromPos(t *testing.T) {
|
||||
testcases := []struct {
|
||||
pos uint
|
||||
wantLine uint
|
||||
wantCol uint
|
||||
input string
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
9, 2, 1, "foo bar\nother foo bar", true,
|
||||
},
|
||||
{
|
||||
9, 2, 1, "foo bar\r\nother foo bar", true,
|
||||
},
|
||||
{
|
||||
0, 1, 0, "foo bar\nother foo bar", true,
|
||||
},
|
||||
{
|
||||
6, 1, 6, "foo bar\nother foo bar", true,
|
||||
},
|
||||
{
|
||||
8, 2, 0, "foo bar\nother foo bar", true,
|
||||
},
|
||||
{
|
||||
15, 3, 6, "foo bar\n\nother foo bar", true,
|
||||
},
|
||||
{
|
||||
15, 3, 6, "foo bar\r\n\r\nother foo bar", true,
|
||||
},
|
||||
{
|
||||
999, 0, 0, "foo bar\nother foo bar", false,
|
||||
},
|
||||
}
|
||||
for i, tc := range testcases {
|
||||
t.Run("tc"+strconv.Itoa(i), func(t *testing.T) {
|
||||
gotLine, gotCol, gotOK := computeLineFromPos(tc.input, tc.pos)
|
||||
if gotOK != tc.wantOk {
|
||||
t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK)
|
||||
}
|
||||
if gotLine != tc.wantLine {
|
||||
t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine)
|
||||
}
|
||||
if gotCol != tc.wantCol {
|
||||
t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue