255 lines
6.7 KiB
Go
255 lines
6.7 KiB
Go
// Copyright (c) 2018 David Crawshaw <david@zentus.com>
|
|
// Copyright (c) 2021 Ross Light <rosss@zombiezen.com>
|
|
//
|
|
// Permission to use, copy, modify, and distribute this software for any
|
|
// purpose with or without fee is hereby granted, provided that the above
|
|
// copyright notice and this permission notice appear in all copies.
|
|
//
|
|
// THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
|
|
// WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
|
|
// MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
|
|
// ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
|
|
// WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
|
|
// ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
|
|
// OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
|
|
//
|
|
// SPDX-License-Identifier: ISC
|
|
|
|
package sqlitex
|
|
|
|
import (
|
|
"fmt"
|
|
"runtime"
|
|
"strings"
|
|
|
|
"zombiezen.com/go/sqlite"
|
|
)
|
|
|
|
// Save creates a named SQLite transaction using SAVEPOINT.
|
|
//
|
|
// On success Savepoint returns a releaseFn that will call either
|
|
// RELEASE or ROLLBACK depending on whether the parameter *error
|
|
// points to a nil or non-nil error. This is designed to be deferred.
|
|
//
|
|
// Example:
|
|
//
|
|
// func doWork(conn *sqlite.Conn) (err error) {
|
|
// defer sqlitex.Save(conn)(&err)
|
|
//
|
|
// // ... do work in the transaction
|
|
// }
|
|
//
|
|
// https://www.sqlite.org/lang_savepoint.html
|
|
func Save(conn *sqlite.Conn) (releaseFn func(*error)) {
|
|
name := "sqlitex.Save" // safe as names can be reused
|
|
var pc [3]uintptr
|
|
if n := runtime.Callers(0, pc[:]); n > 0 {
|
|
frames := runtime.CallersFrames(pc[:n])
|
|
if _, more := frames.Next(); more { // runtime.Callers
|
|
if _, more := frames.Next(); more { // savepoint.Save
|
|
frame, _ := frames.Next() // caller we care about
|
|
if frame.Function != "" {
|
|
name = frame.Function
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
releaseFn, err := savepoint(conn, name)
|
|
if err != nil {
|
|
if sqlite.ErrCode(err) == sqlite.ResultInterrupt {
|
|
return func(errp *error) {
|
|
if *errp == nil {
|
|
*errp = err
|
|
}
|
|
}
|
|
}
|
|
panic(err)
|
|
}
|
|
return releaseFn
|
|
}
|
|
|
|
func savepoint(conn *sqlite.Conn, name string) (releaseFn func(*error), err error) {
|
|
if strings.Contains(name, `"`) {
|
|
return nil, fmt.Errorf("sqlitex.Savepoint: invalid name: %q", name)
|
|
}
|
|
if err := Exec(conn, fmt.Sprintf("SAVEPOINT %q;", name), nil); err != nil {
|
|
return nil, err
|
|
}
|
|
// TODO(maybe)
|
|
// tracer := conn.Tracer()
|
|
// if tracer != nil {
|
|
// tracer.Push("TX " + name)
|
|
// }
|
|
releaseFn = func(errp *error) {
|
|
// TODO(maybe)
|
|
// if tracer != nil {
|
|
// tracer.Pop()
|
|
// }
|
|
recoverP := recover()
|
|
|
|
// If a query was interrupted or if a user exec'd COMMIT or
|
|
// ROLLBACK, then everything was already rolled back
|
|
// automatically, thus returning the connection to autocommit
|
|
// mode.
|
|
if conn.AutocommitEnabled() {
|
|
// There is nothing to rollback.
|
|
if recoverP != nil {
|
|
panic(recoverP)
|
|
}
|
|
return
|
|
}
|
|
|
|
if *errp == nil && recoverP == nil {
|
|
// Success path. Release the savepoint successfully.
|
|
*errp = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
|
|
if *errp == nil {
|
|
return
|
|
}
|
|
// Possible interrupt. Fall through to the error path.
|
|
if conn.AutocommitEnabled() {
|
|
// There is nothing to rollback.
|
|
if recoverP != nil {
|
|
panic(recoverP)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
orig := ""
|
|
if *errp != nil {
|
|
orig = (*errp).Error() + "\n\t"
|
|
}
|
|
|
|
// Error path.
|
|
|
|
// Always run ROLLBACK even if the connection has been interrupted.
|
|
oldDoneCh := conn.SetInterrupt(nil)
|
|
defer conn.SetInterrupt(oldDoneCh)
|
|
|
|
err := Exec(conn, fmt.Sprintf("ROLLBACK TO %q;", name), nil)
|
|
if err != nil {
|
|
panic(orig + err.Error())
|
|
}
|
|
err = Exec(conn, fmt.Sprintf("RELEASE %q;", name), nil)
|
|
if err != nil {
|
|
panic(orig + err.Error())
|
|
}
|
|
|
|
if recoverP != nil {
|
|
panic(recoverP)
|
|
}
|
|
}
|
|
return releaseFn, nil
|
|
}
|
|
|
|
// Transaction creates a DEFERRED SQLite transaction.
|
|
//
|
|
// On success Transaction returns an endFn that will call either
|
|
// COMMIT or ROLLBACK depending on whether the parameter *error
|
|
// points to a nil or non-nil error. This is designed to be deferred.
|
|
//
|
|
// https://www.sqlite.org/lang_transaction.html
|
|
func Transaction(conn *sqlite.Conn) (endFn func(*error)) {
|
|
endFn, err := transaction(conn, "DEFERRED")
|
|
if err != nil {
|
|
if sqlite.ErrCode(err) == sqlite.ResultInterrupt {
|
|
return func(errp *error) {
|
|
if *errp == nil {
|
|
*errp = err
|
|
}
|
|
}
|
|
}
|
|
panic(err)
|
|
}
|
|
return endFn
|
|
}
|
|
|
|
// ImmediateTransaction creates an IMMEDIATE SQLite transaction.
|
|
//
|
|
// On success ImmediateTransaction returns an endFn that will call either
|
|
// COMMIT or ROLLBACK depending on whether the parameter *error
|
|
// points to a nil or non-nil error. This is designed to be deferred.
|
|
//
|
|
// https://www.sqlite.org/lang_transaction.html
|
|
func ImmediateTransaction(conn *sqlite.Conn) (endFn func(*error), err error) {
|
|
endFn, err = transaction(conn, "IMMEDIATE")
|
|
if err != nil {
|
|
return func(*error) {}, err
|
|
}
|
|
return endFn, nil
|
|
}
|
|
|
|
// ExclusiveTransaction creates an EXCLUSIVE SQLite transaction.
|
|
//
|
|
// On success ImmediateTransaction returns an endFn that will call either
|
|
// COMMIT or ROLLBACK depending on whether the parameter *error
|
|
// points to a nil or non-nil error. This is designed to be deferred.
|
|
//
|
|
// https://www.sqlite.org/lang_transaction.html
|
|
func ExclusiveTransaction(conn *sqlite.Conn) (endFn func(*error), err error) {
|
|
endFn, err = transaction(conn, "EXCLUSIVE")
|
|
if err != nil {
|
|
return func(*error) {}, err
|
|
}
|
|
return endFn, nil
|
|
}
|
|
|
|
func transaction(conn *sqlite.Conn, mode string) (endFn func(*error), err error) {
|
|
if err := Exec(conn, "BEGIN "+mode+";", nil); err != nil {
|
|
return nil, err
|
|
}
|
|
endFn = func(errp *error) {
|
|
recoverP := recover()
|
|
|
|
// If a query was interrupted or if a user exec'd COMMIT or
|
|
// ROLLBACK, then everything was already rolled back
|
|
// automatically, thus returning the connection to autocommit
|
|
// mode.
|
|
if conn.AutocommitEnabled() {
|
|
// There is nothing to rollback.
|
|
if recoverP != nil {
|
|
panic(recoverP)
|
|
}
|
|
return
|
|
}
|
|
|
|
if *errp == nil && recoverP == nil {
|
|
// Success path. Commit the transaction.
|
|
*errp = Exec(conn, "COMMIT;", nil)
|
|
if *errp == nil {
|
|
return
|
|
}
|
|
// Possible interrupt. Fall through to the error path.
|
|
if conn.AutocommitEnabled() {
|
|
// There is nothing to rollback.
|
|
if recoverP != nil {
|
|
panic(recoverP)
|
|
}
|
|
return
|
|
}
|
|
}
|
|
|
|
orig := ""
|
|
if *errp != nil {
|
|
orig = (*errp).Error() + "\n\t"
|
|
}
|
|
|
|
// Error path.
|
|
|
|
// Always run ROLLBACK even if the connection has been interrupted.
|
|
oldDoneCh := conn.SetInterrupt(nil)
|
|
defer conn.SetInterrupt(oldDoneCh)
|
|
|
|
err := Exec(conn, "ROLLBACK;", nil)
|
|
if err != nil {
|
|
panic(orig + err.Error())
|
|
}
|
|
|
|
if recoverP != nil {
|
|
panic(recoverP)
|
|
}
|
|
}
|
|
return endFn, nil
|
|
}
|