add IsEncrypted() and tests to check encryption
This commit is contained in:
parent
6bf6f0208d
commit
81d39a96e8
|
@ -0,0 +1,34 @@
|
|||
package sqlite3
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"os"
|
||||
)
|
||||
|
||||
// sqlite3Header defines the header string used by SQLite 3.
|
||||
var sqlite3Header = []byte("SQLite format 3\000")
|
||||
|
||||
// IsEncrypted returns true, if the file with filename is encrypted, false
|
||||
// otherwise. If the file cannot be read properly an error is returned.
|
||||
func IsEncrypted(filename string) (bool, error) {
|
||||
// open file
|
||||
db, err := os.Open(filename)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
defer db.Close()
|
||||
// read header
|
||||
var header [16]byte
|
||||
n, err := db.Read(header[:])
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if n != len(header) {
|
||||
return false, errors.New("go-sqlcipher: could not read full header")
|
||||
}
|
||||
// SQLCipher encrypts also the header, the file is encrypted if the read
|
||||
// header does not equal the header string used by SQLite 3.
|
||||
encrypted := !bytes.Equal(header[:], sqlite3Header)
|
||||
return encrypted, nil
|
||||
}
|
|
@ -1,41 +1,60 @@
|
|||
package sqlite3_test
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"database/sql"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
_ "github.com/mutecomm/go-sqlcipher"
|
||||
"github.com/mutecomm/go-sqlcipher"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var db *sql.DB
|
||||
var (
|
||||
db *sql.DB
|
||||
testDir = "go-sqlcipher_test"
|
||||
tables = `
|
||||
CREATE TABLE KeyValueStore (
|
||||
KeyEntry TEXT NOT NULL UNIQUE,
|
||||
ValueEntry TEXT NOT NULL
|
||||
);`
|
||||
)
|
||||
|
||||
func init() {
|
||||
// create DB
|
||||
key := "passphrase"
|
||||
tmpdir, err := ioutil.TempDir("", "sqlcipher_test")
|
||||
tmpdir, err := ioutil.TempDir("", testDir)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
dbname := filepath.Join(tmpdir, "sqlcipher_test")
|
||||
dbname += fmt.Sprintf("?_pragma_key=%s&_pragma_cipher_page_size=4096", key)
|
||||
db, err = sql.Open("sqlite3", dbname)
|
||||
dbnameWithDSN := dbname + fmt.Sprintf("?_pragma_key=%s&_pragma_cipher_page_size=4096", key)
|
||||
db, err = sql.Open("sqlite3", dbnameWithDSN)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = db.Exec(`
|
||||
CREATE TABLE KeyValueStore (
|
||||
KeyEntry TEXT NOT NULL UNIQUE,
|
||||
ValueEntry TEXT NOT NULL
|
||||
);`)
|
||||
_, err = db.Exec(tables)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
db.Close()
|
||||
// make sure DB is encrypted
|
||||
encrypted, err := sqlite3.IsEncrypted(dbname)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if !encrypted {
|
||||
panic(errors.New("go-sqlcipher: DB not encrypted"))
|
||||
}
|
||||
// open DB for testing
|
||||
db, err = sql.Open("sqlite3", dbname)
|
||||
db, err = sql.Open("sqlite3", dbnameWithDSN)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
@ -54,14 +73,10 @@ var mapping = map[string]string{
|
|||
func TestSQLCipherParallelInsert(t *testing.T) {
|
||||
t.Parallel()
|
||||
insertValueQuery, err := db.Prepare("INSERT INTO KeyValueStore (KeyEntry, ValueEntry) VALUES (?, ?);")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
require.NoError(t, err)
|
||||
for key, value := range mapping {
|
||||
_, err := insertValueQuery.Exec(key, value)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -75,11 +90,46 @@ func TestSQLCipherParallelSelect(t *testing.T) {
|
|||
var val string
|
||||
err := getValueQuery.QueryRow(key).Scan(&val)
|
||||
if err != sql.ErrNoRows {
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
} else if val != value {
|
||||
t.Errorf("%s != %s", val, value)
|
||||
if assert.NoError(t, err) {
|
||||
assert.Equal(t, value, val)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCipherIsEncryptedFalse(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir("", testDir)
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpdir)
|
||||
dbname := filepath.Join(tmpdir, "unencrypted.sqlite")
|
||||
db, err := sql.Open("sqlite3", dbname)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
_, err = db.Exec(tables)
|
||||
require.NoError(t, err)
|
||||
encrypted, err := sqlite3.IsEncrypted(dbname)
|
||||
if assert.NoError(t, err) {
|
||||
assert.False(t, encrypted)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSQLCipherIsEncryptedTrue(t *testing.T) {
|
||||
tmpdir, err := ioutil.TempDir("", testDir)
|
||||
require.NoError(t, err)
|
||||
defer os.RemoveAll(tmpdir)
|
||||
dbname := filepath.Join(tmpdir, "encrypted.sqlite")
|
||||
var key [32]byte
|
||||
_, err = io.ReadFull(rand.Reader, key[:])
|
||||
require.NoError(t, err)
|
||||
dbnameWithDSN := dbname + fmt.Sprintf("?_pragma_key=x'%s'",
|
||||
hex.EncodeToString(key[:]))
|
||||
db, err := sql.Open("sqlite3", dbnameWithDSN)
|
||||
require.NoError(t, err)
|
||||
defer db.Close()
|
||||
_, err = db.Exec(tables)
|
||||
require.NoError(t, err)
|
||||
encrypted, err := sqlite3.IsEncrypted(dbname)
|
||||
if assert.NoError(t, err) {
|
||||
assert.True(t, encrypted)
|
||||
}
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue