status-go/protocol/transport/processed_message_ids_cache.go
2021-09-22 08:04:28 -04:00

101 lines
2.1 KiB
Go

package transport
import (
"context"
"database/sql"
"strings"
)
// ProcessedMessageIDsCache stores message IDs, each with a timestamp, in the
// transport_message_cache table and answers membership queries against it.
type ProcessedMessageIDsCache struct {
	db *sql.DB // handle used for every statement run against transport_message_cache
}
// NewProcessedMessageIDsCache returns a new cache that runs all of its
// queries through db. The db handle is stored as-is; no query is issued here.
func NewProcessedMessageIDsCache(db *sql.DB) *ProcessedMessageIDsCache {
	cache := new(ProcessedMessageIDsCache)
	cache.db = db
	return cache
}
// Clear deletes every row from transport_message_cache, emptying the cache.
// It returns the error reported by the database, if any.
func (c *ProcessedMessageIDsCache) Clear() error {
	const deleteAll = "DELETE FROM transport_message_cache"
	_, execErr := c.db.Exec(deleteAll)
	return execErr
}
// Hits reports which of ids are already present in transport_message_cache.
//
// The returned map has an entry set to true for every id that was found; ids
// that are not in the table have no entry. An empty or nil ids yields an empty,
// non-nil map. On any query, scan or iteration error the map is discarded and
// (nil, err) is returned.
func (c *ProcessedMessageIDsCache) Hits(ids []string) (map[string]bool, error) {
	hits := make(map[string]bool)
	// The IN clause is split into batches because SQLite caps the number of
	// host parameters per statement at SQLITE_MAX_VARIABLE_NUMBER, whose
	// default is 999 in older SQLite releases.
	const batchSize = 999
	for start := 0; start < len(ids); start += batchSize {
		end := start + batchSize
		if end > len(ids) {
			end = len(ids)
		}
		if err := c.collectHits(ids[start:end], hits); err != nil {
			return nil, err
		}
	}
	return hits, nil
}

// collectHits runs a single IN query for batch (which must be non-empty) and
// marks every returned id as true in hits. The result set is closed before
// returning, so each batch releases its rows before the next query starts.
func (c *ProcessedMessageIDsCache) collectHits(batch []string, hits map[string]bool) error {
	args := make([]interface{}, len(batch))
	for i, id := range batch {
		args[i] = id
	}
	// Only "?" placeholders are concatenated into the SQL; the ids themselves
	// are bound as parameters.
	placeholders := strings.Repeat("?, ", len(batch)-1) + "?"
	query := "SELECT id FROM transport_message_cache WHERE id IN (" + placeholders + ")" // nolint: gosec
	rows, err := c.db.Query(query, args...)
	if err != nil {
		return err
	}
	defer rows.Close()
	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return err
		}
		hits[id] = true
	}
	// rows.Next returns false both when the rows are exhausted and when
	// iteration fails; only rows.Err distinguishes the two.
	return rows.Err()
}
// Add inserts every id into transport_message_cache with the given timestamp.
//
// All inserts run in one transaction. If every insert succeeds, the transaction
// is committed and any Commit error is returned. Otherwise it is rolled back
// and the first error encountered is returned; the Rollback error, if any, is
// discarded so it does not mask that original error.
func (c *ProcessedMessageIDsCache) Add(ids []string, timestamp uint64) (err error) {
	var tx *sql.Tx
	tx, err = c.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	// Deferred first, so it runs last: after stmt.Close below.
	// err must only ever be assigned with "=" in this function so this
	// handler sees the real outcome.
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	// Prepare the INSERT once and reuse it for every id, rather than
	// preparing a new statement per row.
	var stmt *sql.Stmt
	stmt, err = tx.Prepare(`INSERT INTO transport_message_cache(id,timestamp) VALUES (?, ?)`)
	if err != nil {
		return err
	}
	defer stmt.Close()

	for _, id := range ids {
		if _, err = stmt.Exec(id, timestamp); err != nil {
			return err
		}
	}
	return nil
}
// Clean deletes every cache entry whose timestamp is strictly less than the
// given timestamp; entries stamped exactly at timestamp are kept.
func (c *ProcessedMessageIDsCache) Clean(timestamp uint64) error {
	const deleteOlder = `DELETE FROM transport_message_cache WHERE timestamp < ?`
	if _, err := c.db.Exec(deleteOlder, timestamp); err != nil {
		return err
	}
	return nil
}