diff --git a/sqlcipher.go b/sqlcipher.go new file mode 100644 index 0000000..03c632c --- /dev/null +++ b/sqlcipher.go @@ -0,0 +1,34 @@ +package sqlite3 + +import ( + "bytes" + "errors" + "os" +) + +// sqlite3Header defines the header string used by SQLite 3. +var sqlite3Header = []byte("SQLite format 3\000") + +// IsEncrypted returns true, if the file with filename is encrypted, false +// otherwise. If the file cannot be read properly an error is returned. +func IsEncrypted(filename string) (bool, error) { + // open file + db, err := os.Open(filename) + if err != nil { + return false, err + } + defer db.Close() + // read header + var header [16]byte + n, err := db.Read(header[:]) + if err != nil { + return false, err + } + if n != len(header) { + return false, errors.New("go-sqlcipher: could not read full header") + } + // SQLCipher encrypts also the header, the file is encrypted if the read + // header does not equal the header string used by SQLite 3. + encrypted := !bytes.Equal(header[:], sqlite3Header) + return encrypted, nil +} diff --git a/sqlcipher_test.go b/sqlcipher_test.go index a0f3add..05996c7 100644 --- a/sqlcipher_test.go +++ b/sqlcipher_test.go @@ -1,41 +1,60 @@ package sqlite3_test import ( + "crypto/rand" "database/sql" + "encoding/hex" + "errors" "fmt" + "io" "io/ioutil" + "os" "path/filepath" "testing" - _ "github.com/mutecomm/go-sqlcipher" + "github.com/mutecomm/go-sqlcipher" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) -var db *sql.DB +var ( + db *sql.DB + testDir = "go-sqlcipher_test" + tables = ` +CREATE TABLE KeyValueStore ( + KeyEntry TEXT NOT NULL UNIQUE, + ValueEntry TEXT NOT NULL +);` +) func init() { // create DB key := "passphrase" - tmpdir, err := ioutil.TempDir("", "sqlcipher_test") + tmpdir, err := ioutil.TempDir("", testDir) if err != nil { panic(err) } dbname := filepath.Join(tmpdir, "sqlcipher_test") - dbname += fmt.Sprintf("?_pragma_key=%s&_pragma_cipher_page_size=4096", key) - db, err = sql.Open("sqlite3", dbname) + dbnameWithDSN := dbname + fmt.Sprintf("?_pragma_key=%s&_pragma_cipher_page_size=4096", key) + db, err = sql.Open("sqlite3", dbnameWithDSN) if err != nil { panic(err) } - _, err = db.Exec(` -CREATE TABLE KeyValueStore ( - KeyEntry TEXT NOT NULL UNIQUE, - ValueEntry TEXT NOT NULL -);`) + _, err = db.Exec(tables) if err != nil { panic(err) } db.Close() + // make sure DB is encrypted + encrypted, err := sqlite3.IsEncrypted(dbname) + if err != nil { + panic(err) + } + if !encrypted { + panic(errors.New("go-sqlcipher: DB not encrypted")) + } // open DB for testing - db, err = sql.Open("sqlite3", dbname) + db, err = sql.Open("sqlite3", dbnameWithDSN) if err != nil { panic(err) } @@ -54,14 +73,10 @@ var mapping = map[string]string{ func TestSQLCipherParallelInsert(t *testing.T) { t.Parallel() insertValueQuery, err := db.Prepare("INSERT INTO KeyValueStore (KeyEntry, ValueEntry) VALUES (?, ?);") - if err != nil { - t.Fatal(err) - } + require.NoError(t, err) for key, value := range mapping { _, err := insertValueQuery.Exec(key, value) - if err != nil { - t.Error(err) - } + assert.NoError(t, err) } } @@ -75,11 +90,46 @@ func TestSQLCipherParallelSelect(t *testing.T) { var val string err := getValueQuery.QueryRow(key).Scan(&val) if err != sql.ErrNoRows { - if err != nil { - t.Error(err) - } else if val != value { - t.Errorf("%s != %s", val, value) + if assert.NoError(t, err) { + assert.Equal(t, value, val) } } } } + +func TestSQLCipherIsEncryptedFalse(t *testing.T) { + tmpdir, err := ioutil.TempDir("", testDir) + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + dbname := filepath.Join(tmpdir, "unencrypted.sqlite") + db, err := sql.Open("sqlite3", dbname) + require.NoError(t, err) + defer db.Close() + _, err = db.Exec(tables) + require.NoError(t, err) + encrypted, err := sqlite3.IsEncrypted(dbname) + if assert.NoError(t, err) { + assert.False(t, encrypted) + } +} + +func TestSQLCipherIsEncryptedTrue(t *testing.T) { + tmpdir, err := ioutil.TempDir("", testDir) + require.NoError(t, err) + defer os.RemoveAll(tmpdir) + dbname := filepath.Join(tmpdir, "encrypted.sqlite") + var key [32]byte + _, err = io.ReadFull(rand.Reader, key[:]) + require.NoError(t, err) + dbnameWithDSN := dbname + fmt.Sprintf("?_pragma_key=x'%s'", + hex.EncodeToString(key[:])) + db, err := sql.Open("sqlite3", dbnameWithDSN) + require.NoError(t, err) + defer db.Close() + _, err = db.Exec(tables) + require.NoError(t, err) + encrypted, err := sqlite3.IsEncrypted(dbname) + if assert.NoError(t, err) { + assert.True(t, encrypted) + } +}