status-go/protocol/transport/processed_message_ids_cache.go

82 lines
1.6 KiB
Go

package transport
import (
"context"
"database/sql"
"strings"
)
// ProcessedMessageIDsCache tracks message IDs that have already been
// processed by keeping them in the transport_message_cache table of the
// wrapped database.
type ProcessedMessageIDsCache struct {
	db *sql.DB // handle used for every read and write of the cache table
}

// NewProcessedMessageIDsCache returns a cache backed by db.
func NewProcessedMessageIDsCache(db *sql.DB) *ProcessedMessageIDsCache {
	cache := new(ProcessedMessageIDsCache)
	cache.db = db
	return cache
}
// Hits reports which of ids are already present in transport_message_cache.
//
// The returned map holds true for every ID found in the table; IDs that were
// not found have no entry. An empty ids slice yields an empty (non-nil) map
// without touching the database. On error the map is nil.
//
// NOTE(review): very large ids slices may exceed the database driver's limit
// on bound parameters — confirm callers keep batches small.
func (c *ProcessedMessageIDsCache) Hits(ids []string) (map[string]bool, error) {
	hits := make(map[string]bool, len(ids))
	// strings.Repeat panics on a negative count and "IN ()" is invalid SQL,
	// so an empty input must short-circuit before building the query.
	if len(ids) == 0 {
		return hits, nil
	}

	idsArgs := make([]interface{}, len(ids))
	for i, id := range ids {
		idsArgs[i] = id
	}

	// Only "?" placeholders are concatenated into the statement; the IDs are
	// bound as parameters, so this concatenation is not an injection vector.
	inVector := strings.Repeat("?, ", len(ids)-1) + "?"
	query := "SELECT id FROM transport_message_cache WHERE id IN (" + inVector + ")" // nolint: gosec
	rows, err := c.db.Query(query, idsArgs...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()

	for rows.Next() {
		var id string
		if err := rows.Scan(&id); err != nil {
			return nil, err
		}
		hits[id] = true
	}
	// rows.Next returns false both when rows are exhausted and when iteration
	// fails; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, err
	}
	return hits, nil
}
// Add inserts every ID in ids into transport_message_cache with the given
// timestamp.
//
// All inserts run inside a single transaction. It is committed when every
// insert succeeds and rolled back on the first failure. A commit error is
// returned to the caller. An empty ids slice is a no-op and opens no
// transaction.
func (c *ProcessedMessageIDsCache) Add(ids []string, timestamp uint64) (err error) {
	if len(ids) == 0 {
		return nil
	}

	var tx *sql.Tx
	tx, err = c.db.BeginTx(context.Background(), &sql.TxOptions{})
	if err != nil {
		return err
	}
	// err is a named result so this deferred func can choose between commit
	// and rollback, and can report a commit failure to the caller.
	defer func() {
		if err == nil {
			err = tx.Commit()
			return
		}
		// don't shadow original error
		_ = tx.Rollback()
	}()

	// The statement text does not depend on the loop variable: prepare it
	// once and reuse it for every ID rather than re-preparing per iteration.
	var stmt *sql.Stmt
	stmt, err = tx.Prepare(`INSERT INTO transport_message_cache(id,timestamp) VALUES (?, ?)`)
	if err != nil {
		return err
	}
	// Deferred after the commit/rollback func, so it runs first (LIFO).
	// A close error is irrelevant here: the transaction outcome decides err.
	defer func() { _ = stmt.Close() }()

	for _, id := range ids {
		if _, err = stmt.Exec(id, timestamp); err != nil {
			return err
		}
	}
	return nil
}
// Clean deletes every cached ID whose stored timestamp is strictly less than
// timestamp. IDs stored with exactly timestamp are kept.
func (c *ProcessedMessageIDsCache) Clean(timestamp uint64) error {
	_, err := c.db.Exec(`DELETE FROM transport_message_cache WHERE timestamp < ?`, timestamp)
	return err
}