101 lines
2.1 KiB
Go
101 lines
2.1 KiB
Go
package transport
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"strings"
|
|
)
|
|
|
|
type ProcessedMessageIDsCache struct {
|
|
db *sql.DB
|
|
}
|
|
|
|
func NewProcessedMessageIDsCache(db *sql.DB) *ProcessedMessageIDsCache {
|
|
return &ProcessedMessageIDsCache{db: db}
|
|
}
|
|
|
|
func (c *ProcessedMessageIDsCache) Clear() error {
|
|
_, err := c.db.Exec("DELETE FROM transport_message_cache")
|
|
return err
|
|
}
|
|
|
|
func (c *ProcessedMessageIDsCache) Hits(ids []string) (map[string]bool, error) {
|
|
hits := make(map[string]bool)
|
|
|
|
// Split the results into batches of 999 items.
|
|
// To prevent excessive memory allocations, the maximum value of a host parameter number
|
|
// is SQLITE_MAX_VARIABLE_NUMBER, which defaults to 999
|
|
batch := 999
|
|
for i := 0; i < len(ids); i += batch {
|
|
j := i + batch
|
|
if j > len(ids) {
|
|
j = len(ids)
|
|
}
|
|
|
|
currentBatch := ids[i:j]
|
|
|
|
idsArgs := make([]interface{}, 0, len(currentBatch))
|
|
for _, id := range currentBatch {
|
|
idsArgs = append(idsArgs, id)
|
|
}
|
|
|
|
inVector := strings.Repeat("?, ", len(currentBatch)-1) + "?"
|
|
query := "SELECT id FROM transport_message_cache WHERE id IN (" + inVector + ")" // nolint: gosec
|
|
|
|
rows, err := c.db.Query(query, idsArgs...)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
defer rows.Close()
|
|
|
|
for rows.Next() {
|
|
var id string
|
|
err := rows.Scan(&id)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
hits[id] = true
|
|
}
|
|
}
|
|
|
|
return hits, nil
|
|
}
|
|
|
|
func (c *ProcessedMessageIDsCache) Add(ids []string, timestamp uint64) (err error) {
|
|
var tx *sql.Tx
|
|
tx, err = c.db.BeginTx(context.Background(), &sql.TxOptions{})
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
defer func() {
|
|
if err == nil {
|
|
err = tx.Commit()
|
|
return
|
|
}
|
|
// don't shadow original error
|
|
_ = tx.Rollback()
|
|
}()
|
|
|
|
for _, id := range ids {
|
|
|
|
var stmt *sql.Stmt
|
|
stmt, err = tx.Prepare(`INSERT INTO transport_message_cache(id,timestamp) VALUES (?, ?)`)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
_, err = stmt.Exec(id, timestamp)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func (c *ProcessedMessageIDsCache) Clean(timestamp uint64) error {
|
|
_, err := c.db.Exec(`DELETE FROM transport_message_cache WHERE timestamp < ?`, timestamp)
|
|
return err
|
|
}
|