mirror of
https://github.com/logos-messaging/go-libp2p-rendezvous.git
synced 2026-01-02 12:53:13 +00:00
Merge pull request #3 from n0izn0iz/sqlcipher
This commit is contained in:
commit
09965cd647
474
db/sqlcipher/db.go
Normal file
474
db/sqlcipher/db.go
Normal file
@ -0,0 +1,474 @@
|
||||
package sqlcipher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"database/sql"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
dbi "github.com/libp2p/go-libp2p-rendezvous/db"
|
||||
|
||||
_ "github.com/mutecomm/go-sqlcipher/v4"
|
||||
|
||||
logging "github.com/ipfs/go-log/v2"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
)
|
||||
|
||||
var log = logging.Logger("rendezvous/db")
|
||||
|
||||
type DB struct {
|
||||
db *sql.DB
|
||||
|
||||
insertPeerRegistration *sql.Stmt
|
||||
deletePeerRegistrations *sql.Stmt
|
||||
deletePeerRegistrationsNs *sql.Stmt
|
||||
countPeerRegistrations *sql.Stmt
|
||||
selectPeerRegistrations *sql.Stmt
|
||||
selectPeerRegistrationsNS *sql.Stmt
|
||||
selectPeerRegistrationsC *sql.Stmt
|
||||
selectPeerRegistrationsNSC *sql.Stmt
|
||||
deleteExpiredRegistrations *sql.Stmt
|
||||
getCounter *sql.Stmt
|
||||
|
||||
nonce []byte
|
||||
|
||||
cancel func()
|
||||
}
|
||||
|
||||
func OpenDB(ctx context.Context, path string) (*DB, error) {
|
||||
var create bool
|
||||
if path == ":memory:" {
|
||||
create = true
|
||||
} else {
|
||||
_, err := os.Stat(path)
|
||||
switch {
|
||||
case os.IsNotExist(err):
|
||||
create = true
|
||||
case err != nil:
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
db, err := sql.Open("sqlite3", path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if path == ":memory:" {
|
||||
// this is necessary to avoid creating a new database on each connection
|
||||
db.SetMaxOpenConns(1)
|
||||
}
|
||||
|
||||
rdb := &DB{db: db}
|
||||
if create {
|
||||
err = rdb.prepareDB()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
err = rdb.loadNonce()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = rdb.prepareStmts()
|
||||
if err != nil {
|
||||
db.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
bgctx, cancel := context.WithCancel(ctx)
|
||||
rdb.cancel = cancel
|
||||
go rdb.background(bgctx)
|
||||
|
||||
return rdb, nil
|
||||
}
|
||||
|
||||
func (db *DB) Close() error {
|
||||
db.cancel()
|
||||
return db.db.Close()
|
||||
}
|
||||
|
||||
func (db *DB) prepareDB() error {
|
||||
_, err := db.db.Exec("CREATE TABLE Registrations (counter INTEGER PRIMARY KEY AUTOINCREMENT, peer VARCHAR(64), ns VARCHAR, expire INTEGER, addrs VARBINARY)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.db.Exec("CREATE TABLE Nonce (nonce VARBINARY)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
nonce := make([]byte, 32)
|
||||
_, err = rand.Read(nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = db.db.Exec("INSERT INTO Nonce VALUES (?)", nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.nonce = nonce
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) loadNonce() error {
|
||||
var nonce []byte
|
||||
row := db.db.QueryRow("SELECT nonce FROM Nonce")
|
||||
err := row.Scan(&nonce)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
db.nonce = nonce
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) prepareStmts() error {
|
||||
stmt, err := db.db.Prepare("INSERT INTO Registrations VALUES (NULL, ?, ?, ?, ?)")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.insertPeerRegistration = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE peer = ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.deletePeerRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE peer = ? AND ns = ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.deletePeerRegistrationsNs = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT COUNT(*) FROM Registrations WHERE peer = ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.countPeerRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE ns = ? AND expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrationsNS = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE counter > ? AND expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrationsC = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT * FROM Registrations WHERE counter > ? AND ns = ? AND expire > ? LIMIT ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.selectPeerRegistrationsNSC = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("DELETE FROM Registrations WHERE expire < ?")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.deleteExpiredRegistrations = stmt
|
||||
|
||||
stmt, err = db.db.Prepare("SELECT MAX(counter) FROM Registrations")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
db.getCounter = stmt
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *DB) Register(p peer.ID, ns string, addrs [][]byte, ttl int) (uint64, error) {
|
||||
pid := p.Pretty()
|
||||
maddrs := packAddrs(addrs)
|
||||
expire := time.Now().Unix() + int64(ttl)
|
||||
|
||||
tx, err := db.db.Begin()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
delOld := tx.Stmt(db.deletePeerRegistrationsNs)
|
||||
insertNew := tx.Stmt(db.insertPeerRegistration)
|
||||
getCounter := tx.Stmt(db.getCounter)
|
||||
|
||||
_, err = delOld.Exec(pid, ns)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
_, err = insertNew.Exec(pid, ns, expire, maddrs)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var counter uint64
|
||||
row := getCounter.QueryRow()
|
||||
err = row.Scan(&counter)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return 0, err
|
||||
}
|
||||
|
||||
err = tx.Commit()
|
||||
return counter, err
|
||||
}
|
||||
|
||||
func (db *DB) CountRegistrations(p peer.ID) (int, error) {
|
||||
pid := p.Pretty()
|
||||
|
||||
row := db.countPeerRegistrations.QueryRow(pid)
|
||||
|
||||
var count int
|
||||
err := row.Scan(&count)
|
||||
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (db *DB) Unregister(p peer.ID, ns string) error {
|
||||
pid := p.Pretty()
|
||||
|
||||
var err error
|
||||
|
||||
if ns == "" {
|
||||
_, err = db.deletePeerRegistrations.Exec(pid)
|
||||
} else {
|
||||
_, err = db.deletePeerRegistrationsNs.Exec(pid, ns)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *DB) Discover(ns string, cookie []byte, limit int) ([]dbi.RegistrationRecord, []byte, error) {
|
||||
now := time.Now().Unix()
|
||||
|
||||
var (
|
||||
counter int64
|
||||
rows *sql.Rows
|
||||
err error
|
||||
)
|
||||
|
||||
if cookie != nil {
|
||||
counter, err = unpackCookie(cookie)
|
||||
if err != nil {
|
||||
log.Errorf("error unpacking cookie: %s", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if counter > 0 {
|
||||
if ns == "" {
|
||||
rows, err = db.selectPeerRegistrationsC.Query(counter, now, limit)
|
||||
} else {
|
||||
rows, err = db.selectPeerRegistrationsNSC.Query(counter, ns, now, limit)
|
||||
}
|
||||
} else {
|
||||
if ns == "" {
|
||||
rows, err = db.selectPeerRegistrations.Query(now, limit)
|
||||
} else {
|
||||
rows, err = db.selectPeerRegistrationsNS.Query(ns, now, limit)
|
||||
}
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
log.Errorf("query error: %s", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
defer rows.Close()
|
||||
|
||||
regs := make([]dbi.RegistrationRecord, 0, limit)
|
||||
for rows.Next() {
|
||||
var (
|
||||
reg dbi.RegistrationRecord
|
||||
rid string
|
||||
rns string
|
||||
expire int64
|
||||
raddrs []byte
|
||||
addrs [][]byte
|
||||
p peer.ID
|
||||
)
|
||||
|
||||
err = rows.Scan(&counter, &rid, &rns, &expire, &raddrs)
|
||||
if err != nil {
|
||||
log.Errorf("row scan error: %s", err.Error())
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
p, err = peer.Decode(rid)
|
||||
if err != nil {
|
||||
log.Errorf("error decoding peer id: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
addrs, err := unpackAddrs(raddrs)
|
||||
if err != nil {
|
||||
log.Errorf("error unpacking address: %s", err.Error())
|
||||
continue
|
||||
}
|
||||
|
||||
reg.Id = p
|
||||
reg.Addrs = addrs
|
||||
reg.Ttl = int(expire - now)
|
||||
|
||||
if ns == "" {
|
||||
reg.Ns = rns
|
||||
}
|
||||
|
||||
regs = append(regs, reg)
|
||||
}
|
||||
|
||||
err = rows.Err()
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if counter > 0 {
|
||||
cookie = packCookie(counter, ns, db.nonce)
|
||||
}
|
||||
|
||||
return regs, cookie, nil
|
||||
}
|
||||
|
||||
func (db *DB) ValidCookie(ns string, cookie []byte) bool {
|
||||
return validCookie(cookie, ns, db.nonce)
|
||||
}
|
||||
|
||||
func (db *DB) background(ctx context.Context) {
|
||||
for {
|
||||
db.cleanupExpired()
|
||||
|
||||
select {
|
||||
case <-time.After(15 * time.Minute):
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (db *DB) cleanupExpired() {
|
||||
now := time.Now().Unix()
|
||||
_, err := db.deleteExpiredRegistrations.Exec(now)
|
||||
if err != nil {
|
||||
log.Errorf("error deleting expired registrations: %s", err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
func packAddrs(addrs [][]byte) []byte {
|
||||
packlen := 0
|
||||
for _, addr := range addrs {
|
||||
packlen = packlen + 2 + len(addr)
|
||||
}
|
||||
|
||||
packed := make([]byte, packlen)
|
||||
buf := packed
|
||||
for _, addr := range addrs {
|
||||
binary.BigEndian.PutUint16(buf, uint16(len(addr)))
|
||||
buf = buf[2:]
|
||||
copy(buf, addr)
|
||||
buf = buf[len(addr):]
|
||||
}
|
||||
|
||||
return packed
|
||||
}
|
||||
|
||||
func unpackAddrs(packed []byte) ([][]byte, error) {
|
||||
var addrs [][]byte
|
||||
|
||||
buf := packed
|
||||
for len(buf) > 1 {
|
||||
l := binary.BigEndian.Uint16(buf)
|
||||
buf = buf[2:]
|
||||
if len(buf) < int(l) {
|
||||
return nil, fmt.Errorf("bad packed address: not enough bytes %v %v", packed, buf)
|
||||
}
|
||||
addr := make([]byte, l)
|
||||
copy(addr, buf[:l])
|
||||
buf = buf[l:]
|
||||
addrs = append(addrs, addr)
|
||||
}
|
||||
|
||||
if len(buf) > 0 {
|
||||
return nil, fmt.Errorf("bad packed address: unprocessed bytes: %v %v", packed, buf)
|
||||
}
|
||||
|
||||
return addrs, nil
|
||||
}
|
||||
|
||||
// cookie: counter:SHA256(nonce + ns + counter)
|
||||
func packCookie(counter int64, ns string, nonce []byte) []byte {
|
||||
cbits := make([]byte, 8)
|
||||
binary.BigEndian.PutUint64(cbits, uint64(counter))
|
||||
|
||||
hash := sha256.New()
|
||||
_, err := hash.Write(nonce)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write([]byte(ns))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write(cbits)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return hash.Sum(cbits)
|
||||
}
|
||||
|
||||
func unpackCookie(cookie []byte) (int64, error) {
|
||||
if len(cookie) < 8 {
|
||||
return 0, fmt.Errorf("bad packed cookie: not enough bytes: %v", cookie)
|
||||
}
|
||||
|
||||
counter := binary.BigEndian.Uint64(cookie[:8])
|
||||
return int64(counter), nil
|
||||
}
|
||||
|
||||
func validCookie(cookie []byte, ns string, nonce []byte) bool {
|
||||
if len(cookie) != 40 {
|
||||
return false
|
||||
}
|
||||
|
||||
cbits := cookie[:8]
|
||||
hash := sha256.New()
|
||||
_, err := hash.Write(nonce)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write([]byte(ns))
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
_, err = hash.Write(cbits)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
hbits := hash.Sum(nil)
|
||||
|
||||
return bytes.Equal(cookie[8:], hbits)
|
||||
}
|
||||
512
db/sqlcipher/db_test.go
Normal file
512
db/sqlcipher/db_test.go
Normal file
@ -0,0 +1,512 @@
|
||||
package sqlcipher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"math/rand"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
func TestPackAddrs(t *testing.T) {
|
||||
addrs := make([][]byte, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
addrs[i] = make([]byte, rand.Intn(256))
|
||||
}
|
||||
|
||||
packed := packAddrs(addrs)
|
||||
unpacked, err := unpackAddrs(packed)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !equalAddrs(addrs, unpacked) {
|
||||
t.Fatal("unpacked addr not equal to original")
|
||||
}
|
||||
}
|
||||
|
||||
func equalAddrs(addrs1, addrs2 [][]byte) bool {
|
||||
if len(addrs1) != len(addrs2) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, addr1 := range addrs1 {
|
||||
addr2 := addrs2[i]
|
||||
if !bytes.Equal(addr1, addr2) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func TestPackCookie(t *testing.T) {
|
||||
nonce := make([]byte, 16)
|
||||
_, err := rand.Read(nonce)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
counter := rand.Int63()
|
||||
ns := "blah"
|
||||
|
||||
cookie := packCookie(counter, ns, nonce)
|
||||
|
||||
if !validCookie(cookie, ns, nonce) {
|
||||
t.Fatal("packed an invalid cookie")
|
||||
}
|
||||
|
||||
xcounter, err := unpackCookie(cookie)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if counter != xcounter {
|
||||
t.Fatal("unpacked cookie counter not equal to original")
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCloseMemDB(t *testing.T) {
|
||||
db, err := OpenDB(context.Background(), ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// let the flush goroutine run its cleanup act
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenCloseFSDB(t *testing.T) {
|
||||
db, err := OpenDB(context.Background(), "/tmp/rendezvous-test.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nonce1 := db.nonce
|
||||
|
||||
// let the flush goroutine run its cleanup act
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
db, err = OpenDB(context.Background(), "/tmp/rendezvous-test.db")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
nonce2 := db.nonce
|
||||
|
||||
// let the flush goroutine run its cleanup act
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
err = db.Close()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(nonce1, nonce2) {
|
||||
t.Fatal("persistent db nonces are not equal")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDBRegistrationAndDiscovery(t *testing.T) {
|
||||
db, err := OpenDB(context.Background(), ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p1, err := peer.Decode("QmVr26fY1tKyspEJBniVhqxQeEjhF78XerGiqWAwraVLQH")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p2, err := peer.Decode("QmUkUQgxXeggyaD5Ckv8ZqfW8wHBX6cYyeiyqvVZYzq5Bi")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr1, err := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/9999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs1 := [][]byte{addr1.Bytes()}
|
||||
|
||||
addr2, err := ma.NewMultiaddr("/ip4/2.2.2.2/tcp/9999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs2 := [][]byte{addr2.Bytes()}
|
||||
|
||||
// register p1 and do discovery
|
||||
_, err = db.Register(p1, "foo1", addrs1, 60)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err := db.CountRegistrations(p1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatal("registrations for p1 should be 1")
|
||||
}
|
||||
|
||||
rrs, cookie, err := db.Discover("foo1", nil, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 1 {
|
||||
t.Fatal("should have got 1 registration")
|
||||
}
|
||||
rr := rrs[0]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
// register p2 and do progressive discovery
|
||||
_, err = db.Register(p2, "foo1", addrs2, 60)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err = db.CountRegistrations(p2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatal("registrations for p2 should be 1")
|
||||
}
|
||||
|
||||
rrs, cookie, err = db.Discover("foo1", cookie, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 1 {
|
||||
t.Fatal("should have got 1 registration")
|
||||
}
|
||||
rr = rrs[0]
|
||||
if rr.Id != p2 {
|
||||
t.Fatal("expected p2 ID in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs2) {
|
||||
t.Fatal("expected p2's addrs in registration")
|
||||
}
|
||||
|
||||
// reregister p1 and do progressive discovery
|
||||
_, err = db.Register(p1, "foo1", addrs1, 60)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err = db.CountRegistrations(p1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatal("registrations for p1 should be 1")
|
||||
}
|
||||
|
||||
rrs, cookie, err = db.Discover("foo1", cookie, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 1 {
|
||||
t.Fatal("should have got 1 registration")
|
||||
}
|
||||
rr = rrs[0]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
// do a full discovery
|
||||
rrs, _, err = db.Discover("foo1", nil, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 2 {
|
||||
t.Fatal("should have got 2 registration")
|
||||
}
|
||||
rr = rrs[0]
|
||||
if rr.Id != p2 {
|
||||
t.Fatal("expected p2 ID in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs2) {
|
||||
t.Fatal("expected p2's addrs in registration")
|
||||
}
|
||||
|
||||
rr = rrs[1]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
// unregister p2 and redo discovery
|
||||
err = db.Unregister(p2, "foo1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err = db.CountRegistrations(p2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("registrations for p2 should be 0")
|
||||
}
|
||||
|
||||
rrs, _, err = db.Discover("foo1", nil, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 1 {
|
||||
t.Fatal("should have got 1 registration")
|
||||
}
|
||||
rr = rrs[0]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func TestDBRegistrationAndDiscoveryMultipleNS(t *testing.T) {
|
||||
db, err := OpenDB(context.Background(), ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p1, err := peer.Decode("QmVr26fY1tKyspEJBniVhqxQeEjhF78XerGiqWAwraVLQH")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p2, err := peer.Decode("QmUkUQgxXeggyaD5Ckv8ZqfW8wHBX6cYyeiyqvVZYzq5Bi")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr1, err := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/9999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs1 := [][]byte{addr1.Bytes()}
|
||||
|
||||
addr2, err := ma.NewMultiaddr("/ip4/2.2.2.2/tcp/9999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs2 := [][]byte{addr2.Bytes()}
|
||||
|
||||
_, err = db.Register(p1, "foo1", addrs1, 60)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Register(p1, "foo2", addrs1, 60)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err := db.CountRegistrations(p1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("registrations for p1 should be 2")
|
||||
}
|
||||
|
||||
rrs, cookie, err := db.Discover("", nil, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 2 {
|
||||
t.Fatal("should have got 2 registrations")
|
||||
}
|
||||
rr := rrs[0]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if rr.Ns != "foo1" {
|
||||
t.Fatal("expected namespace foo1 in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
rr = rrs[1]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if rr.Ns != "foo2" {
|
||||
t.Fatal("expected namespace foo1 in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
_, err = db.Register(p2, "foo1", addrs2, 60)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = db.Register(p2, "foo2", addrs2, 60)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err = db.CountRegistrations(p2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 2 {
|
||||
t.Fatal("registrations for p2 should be 2")
|
||||
}
|
||||
|
||||
rrs, cookie, err = db.Discover("", cookie, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 2 {
|
||||
t.Fatal("should have got 2 registrations")
|
||||
}
|
||||
rr = rrs[0]
|
||||
if rr.Id != p2 {
|
||||
t.Fatal("expected p2 ID in registration")
|
||||
}
|
||||
if rr.Ns != "foo1" {
|
||||
t.Fatal("expected namespace foo1 in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs2) {
|
||||
t.Fatal("expected p2's addrs in registration")
|
||||
}
|
||||
|
||||
rr = rrs[1]
|
||||
if rr.Id != p2 {
|
||||
t.Fatal("expected p2 ID in registration")
|
||||
}
|
||||
if rr.Ns != "foo2" {
|
||||
t.Fatal("expected namespace foo1 in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs2) {
|
||||
t.Fatal("expected p2's addrs in registration")
|
||||
}
|
||||
|
||||
err = db.Unregister(p2, "")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err = db.CountRegistrations(p2)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("registrations for p2 should be 0")
|
||||
}
|
||||
|
||||
rrs, _, err = db.Discover("", nil, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 2 {
|
||||
t.Fatal("should have got 2 registrations")
|
||||
}
|
||||
rr = rrs[0]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if rr.Ns != "foo1" {
|
||||
t.Fatal("expected namespace foo1 in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
rr = rrs[1]
|
||||
if rr.Id != p1 {
|
||||
t.Fatal("expected p1 ID in registration")
|
||||
}
|
||||
if rr.Ns != "foo2" {
|
||||
t.Fatal("expected namespace foo1 in registration")
|
||||
}
|
||||
if !equalAddrs(rr.Addrs, addrs1) {
|
||||
t.Fatal("expected p1's addrs in registration")
|
||||
}
|
||||
|
||||
db.Close()
|
||||
}
|
||||
|
||||
func TestDBCleanup(t *testing.T) {
|
||||
db, err := OpenDB(context.Background(), ":memory:")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
p1, err := peer.Decode("QmVr26fY1tKyspEJBniVhqxQeEjhF78XerGiqWAwraVLQH")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
addr1, err := ma.NewMultiaddr("/ip4/1.1.1.1/tcp/9999")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
addrs1 := [][]byte{addr1.Bytes()}
|
||||
|
||||
_, err = db.Register(p1, "foo1", addrs1, 1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
count, err := db.CountRegistrations(p1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 1 {
|
||||
t.Fatal("registrations for p1 should be 1")
|
||||
}
|
||||
|
||||
time.Sleep(2 * time.Second)
|
||||
|
||||
db.cleanupExpired()
|
||||
|
||||
count, err = db.CountRegistrations(p1)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if count != 0 {
|
||||
t.Fatal("registrations for p1 should be 0")
|
||||
}
|
||||
|
||||
rrs, _, err := db.Discover("foo1", nil, 100)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(rrs) != 0 {
|
||||
t.Fatal("should have got 0 registrations")
|
||||
}
|
||||
|
||||
db.Close()
|
||||
}
|
||||
1
go.mod
1
go.mod
@ -10,4 +10,5 @@ require (
|
||||
github.com/libp2p/go-libp2p-swarm v0.6.0
|
||||
github.com/mattn/go-sqlite3 v1.14.4
|
||||
github.com/multiformats/go-multiaddr v0.4.0
|
||||
github.com/mutecomm/go-sqlcipher/v4 v4.4.2
|
||||
)
|
||||
|
||||
2
go.sum
2
go.sum
@ -444,6 +444,8 @@ github.com/multiformats/go-varint v0.0.2/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXS
|
||||
github.com/multiformats/go-varint v0.0.5/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE=
|
||||
github.com/multiformats/go-varint v0.0.6 h1:gk85QWKxh3TazbLxED/NlDVv8+q+ReFJk7Y2W/KhfNY=
|
||||
github.com/multiformats/go-varint v0.0.6/go.mod h1:3Ls8CIEsrijN6+B7PbrXRPxHRPuXSrVKRY101jdMZYE=
|
||||
github.com/mutecomm/go-sqlcipher/v4 v4.4.2 h1:eM10bFtI4UvibIsKr10/QT7Yfz+NADfjZYh0GKrXUNc=
|
||||
github.com/mutecomm/go-sqlcipher/v4 v4.4.2/go.mod h1:mF2UmIpBnzFeBdu/ypTDb/LdbS0nk0dfSN1WUsWTjMA=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20161129095857-cc309e4a2223/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/mwitkow/go-conntrack v0.0.0-20190716064945-2f068394615f/go.mod h1:qRWi+5nqEBWmkhHvq77mSJWrCKwh8bxhgT7d/eI7P4U=
|
||||
github.com/nats-io/jwt v0.3.0/go.mod h1:fRYCDE99xlTsqUzISS1Bi75UBJ6ljOJQOAAu5VglpSg=
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user