diff --git a/database/postgres/postgres.go b/database/postgres/postgres.go index d4d8ea6..25e06b6 100644 --- a/database/postgres/postgres.go +++ b/database/postgres/postgres.go @@ -175,9 +175,7 @@ func (p *Postgres) Run(migration io.Reader) error { var lineColOK bool if pgErr.Position != "" { if pos, err := strconv.ParseUint(pgErr.Position, 10, 64); err == nil { - if line, col, ok = computeLineFromPos(query, uint(pos)); ok { - lineColOK = true - } + line, col, lineColOK = computeLineFromPos(query, int(pos)) } } message := fmt.Sprintf("migration failed: %s", pgErr.Message) @@ -195,25 +193,39 @@ func (p *Postgres) Run(migration io.Reader) error { return nil } -func computeLineFromPos(s string, pos uint) (uint, uint, bool) { - newLine := "\n" - if i := strings.Index(s, "\r\n"); i >= 0 { - newLine = "\r\n" +func computeLineFromPos(s string, pos int) (line uint, col uint, ok bool) { + // replace crlf with lf + s = strings.Replace(s, "\r\n", "\n", -1) + // pg docs: pos uses index 1 for the first character, and positions are measured in characters not bytes + runes := []rune(s) + if pos > len(runes) { + return 0, 0, false } - lines := strings.Split(s, newLine) - remaining := int(pos) - lineNr := 1 - var curr int - for _, line := range lines { - lineLength := len(line) - curr += lineLength + 1 - if remaining < lineLength { - return uint(lineNr), uint(remaining), true + sel := runes[:pos] + line = uint(runesCount(sel, newLine) + 1) + col = uint(pos - 1 - runesLastIndex(sel, newLine)) + return line, col, true +} + +const newLine = '\n' + +func runesCount(input []rune, target rune) int { + var count int + for _, r := range input { + if r == target { + count++ } - remaining -= lineLength + 1 - lineNr++ } - return 0, 0, false + return count +} + +func runesLastIndex(input []rune, target rune) int { + for i := len(input) - 1; i >= 0; i-- { + if input[i] == target { + return i + } + } + return -1 } func (p *Postgres) SetVersion(version int, dirty bool) error { diff --git a/database/postgres/postgres_test.go b/database/postgres/postgres_test.go index aed3d38..fced138 100644 --- a/database/postgres/postgres_test.go +++ b/database/postgres/postgres_test.go @@ -10,6 +10,7 @@ import ( "fmt" "io" "strconv" + "strings" "testing" dt "github.com/golang-migrate/migrate/database/testing" @@ -219,49 +220,80 @@ func TestPostgres_Lock(t *testing.T) { func Test_computeLineFromPos(t *testing.T) { testcases := []struct { - pos uint + pos int wantLine uint wantCol uint input string wantOk bool }{ { - 9, 2, 1, "foo bar\nother foo bar", true, + 15, 2, 6, "SELECT *\nFROM foo", true, // foo table does not exists }, { - 9, 2, 1, "foo bar\r\nother foo bar", true, + 16, 3, 6, "SELECT *\n\nFROM foo", true, // foo table does not exists, empty line }, { - 0, 1, 0, "foo bar\nother foo bar", true, + 25, 3, 7, "SELECT *\nFROM foo\nWHERE x", true, // x column error }, { - 6, 1, 6, "foo bar\nother foo bar", true, + 27, 5, 7, "SELECT *\n\nFROM foo\n\nWHERE x", true, // x column error, empty lines }, { - 8, 2, 0, "foo bar\nother foo bar", true, + 10, 2, 1, "SELECT *\nFROMM foo", true, // FROMM typo }, { - 15, 3, 6, "foo bar\n\nother foo bar", true, + 11, 3, 1, "SELECT *\n\nFROMM foo", true, // FROMM typo, empty line }, { - 15, 3, 6, "foo bar\r\n\r\nother foo bar", true, + 17, 2, 8, "SELECT *\nFROM foo", true, // last character }, { - 999, 0, 0, "foo bar\nother foo bar", false, + 18, 0, 0, "SELECT *\nFROM foo", false, // invalid position }, } for i, tc := range testcases { t.Run("tc"+strconv.Itoa(i), func(t *testing.T) { - gotLine, gotCol, gotOK := computeLineFromPos(tc.input, tc.pos) - if gotOK != tc.wantOk { - t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK) - } - if gotLine != tc.wantLine { - t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine) - } - if gotCol != tc.wantCol { - t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol) + run := func(crlf bool, nonASCII bool) { + var name string + if crlf { + name = "crlf" + } else { + name = "lf" + } + if nonASCII { + name += "-nonascii" + } else { + name += "-ascii" + } + t.Run(name, func(t *testing.T) { + input := tc.input + if crlf { + input = strings.Replace(input, "\n", "\r\n", -1) + } + if nonASCII { + input = strings.Replace(input, "FROM", "FRÖM", -1) + } + gotLine, gotCol, gotOK := computeLineFromPos(input, tc.pos) + + if tc.wantOk { + t.Logf("pos %d, want %d:%d, %#v", tc.pos, tc.wantLine, tc.wantCol, input) + } + + if gotOK != tc.wantOk { + t.Fatalf("expected ok %v but got %v", tc.wantOk, gotOK) + } + if gotLine != tc.wantLine { + t.Fatalf("expected line %d but got %d", tc.wantLine, gotLine) + } + if gotCol != tc.wantCol { + t.Fatalf("expected col %d but got %d", tc.wantCol, gotCol) + } + }) } + run(false, false) + run(true, false) + run(false, true) + run(true, true) }) }