mailserver: refactor mailserver's rate limiter (#1341)
This commit is contained in:
parent
a84dee4934
commit
8f2e347e4f
|
@ -0,0 +1,65 @@
|
|||
package mailserver
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"errors"
|
||||
|
||||
"github.com/ethereum/go-ethereum/common"
|
||||
)
|
||||
|
||||
const (
|
||||
// DBKeyLength is a size of the envelope key.
|
||||
DBKeyLength = common.HashLength + timestampLength
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrInvalidByteSize is returned when DBKey can't be created
|
||||
// from a byte slice because it has invalid length.
|
||||
ErrInvalidByteSize = errors.New("byte slice has invalid length")
|
||||
)
|
||||
|
||||
// DBKey key to be stored in a db.
|
||||
type DBKey struct {
|
||||
timestamp uint32
|
||||
hash common.Hash
|
||||
raw []byte
|
||||
}
|
||||
|
||||
// Bytes returns a bytes representation of the DBKey.
|
||||
func (k *DBKey) Bytes() []byte {
|
||||
return k.raw
|
||||
}
|
||||
|
||||
// NewDBKey creates a new DBKey with the given values.
|
||||
func NewDBKey(timestamp uint32, h common.Hash) *DBKey {
|
||||
var k DBKey
|
||||
k.timestamp = timestamp
|
||||
k.hash = h
|
||||
k.raw = make([]byte, DBKeyLength)
|
||||
binary.BigEndian.PutUint32(k.raw, k.timestamp)
|
||||
copy(k.raw[4:], k.hash[:])
|
||||
return &k
|
||||
}
|
||||
|
||||
// NewDBKeyFromBytes creates a DBKey from a byte slice.
|
||||
func NewDBKeyFromBytes(b []byte) (*DBKey, error) {
|
||||
if len(b) != DBKeyLength {
|
||||
return nil, ErrInvalidByteSize
|
||||
}
|
||||
|
||||
return &DBKey{
|
||||
raw: b,
|
||||
timestamp: binary.BigEndian.Uint32(b),
|
||||
hash: common.BytesToHash(b[4:]),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// mustNewDBKeyFromBytes panics if creating a key from a byte slice fails.
|
||||
// Check if a byte slice has DBKeyLength length before using it.
|
||||
func mustNewDBKeyFromBytes(b []byte) *DBKey {
|
||||
k, err := NewDBKeyFromBytes(b)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return k
|
||||
}
|
|
@ -5,45 +5,83 @@ import (
|
|||
"time"
|
||||
)
|
||||
|
||||
type limiter struct {
|
||||
mu sync.RWMutex
|
||||
type rateLimiter struct {
|
||||
sync.RWMutex
|
||||
|
||||
timeout time.Duration
|
||||
lifespan time.Duration // duration of the limit
|
||||
db map[string]time.Time
|
||||
|
||||
period time.Duration
|
||||
cancel chan struct{}
|
||||
}
|
||||
|
||||
func newLimiter(timeout time.Duration) *limiter {
|
||||
return &limiter{
|
||||
timeout: timeout,
|
||||
func newRateLimiter(duration time.Duration) *rateLimiter {
|
||||
return &rateLimiter{
|
||||
lifespan: duration,
|
||||
db: make(map[string]time.Time),
|
||||
period: time.Second,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *limiter) add(id string) {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
func (l *rateLimiter) Start() {
|
||||
cancel := make(chan struct{})
|
||||
|
||||
l.Lock()
|
||||
l.cancel = cancel
|
||||
l.Unlock()
|
||||
|
||||
go l.cleanUp(l.period, cancel)
|
||||
}
|
||||
|
||||
func (l *rateLimiter) Stop() {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
if l.cancel == nil {
|
||||
return
|
||||
}
|
||||
close(l.cancel)
|
||||
l.cancel = nil
|
||||
}
|
||||
|
||||
func (l *rateLimiter) Add(id string) {
|
||||
l.Lock()
|
||||
l.db[id] = time.Now()
|
||||
l.Unlock()
|
||||
}
|
||||
|
||||
func (l *limiter) isAllowed(id string) bool {
|
||||
l.mu.RLock()
|
||||
defer l.mu.RUnlock()
|
||||
func (l *rateLimiter) IsAllowed(id string) bool {
|
||||
l.RLock()
|
||||
defer l.RUnlock()
|
||||
|
||||
if lastRequestTime, ok := l.db[id]; ok {
|
||||
return lastRequestTime.Add(l.timeout).Before(time.Now())
|
||||
return lastRequestTime.Add(l.lifespan).Before(time.Now())
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (l *limiter) deleteExpired() {
|
||||
l.mu.Lock()
|
||||
defer l.mu.Unlock()
|
||||
func (l *rateLimiter) cleanUp(period time.Duration, cancel <-chan struct{}) {
|
||||
t := time.NewTicker(period)
|
||||
defer t.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-t.C:
|
||||
l.deleteExpired()
|
||||
case <-cancel:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *rateLimiter) deleteExpired() {
|
||||
l.Lock()
|
||||
defer l.Unlock()
|
||||
|
||||
now := time.Now()
|
||||
for id, lastRequestTime := range l.db {
|
||||
if lastRequestTime.Add(l.timeout).Before(now) {
|
||||
if lastRequestTime.Add(l.lifespan).Before(now) {
|
||||
delete(l.db, id)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -48,16 +48,16 @@ func TestIsAllowed(t *testing.T) {
|
|||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.info, func(*testing.T) {
|
||||
l := newLimiter(tc.t)
|
||||
l := newRateLimiter(tc.t)
|
||||
l.db = tc.db()
|
||||
assert.Equal(t, tc.shouldBeAllowed, l.isAllowed(peerID))
|
||||
assert.Equal(t, tc.shouldBeAllowed, l.IsAllowed(peerID))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveExpiredRateLimits(t *testing.T) {
|
||||
peer := "peer"
|
||||
l := newLimiter(time.Duration(5) * time.Second)
|
||||
l := newRateLimiter(time.Duration(5) * time.Second)
|
||||
for i := 0; i < 10; i++ {
|
||||
peerID := fmt.Sprintf("%s%d", peer, i)
|
||||
l.db[peerID] = time.Now().Add(time.Duration(i*(-2)) * time.Second)
|
||||
|
@ -78,11 +78,31 @@ func TestRemoveExpiredRateLimits(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCleaningUpExpiredRateLimits(t *testing.T) {
|
||||
l := newRateLimiter(5 * time.Second)
|
||||
l.period = time.Millisecond * 10
|
||||
l.Start()
|
||||
defer l.Stop()
|
||||
|
||||
l.db["peer01"] = time.Now().Add(-1 * time.Second)
|
||||
l.db["peer02"] = time.Now().Add(-2 * time.Second)
|
||||
l.db["peer03"] = time.Now().Add(-10 * time.Second)
|
||||
|
||||
time.Sleep(time.Millisecond * 20)
|
||||
|
||||
_, ok := l.db["peer01"]
|
||||
assert.True(t, ok)
|
||||
_, ok = l.db["peer02"]
|
||||
assert.True(t, ok)
|
||||
_, ok = l.db["peer03"]
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func TestAddingLimts(t *testing.T) {
|
||||
peerID := "peerAdding"
|
||||
l := newLimiter(time.Duration(5) * time.Second)
|
||||
l := newRateLimiter(time.Duration(5) * time.Second)
|
||||
pre := time.Now()
|
||||
l.add(peerID)
|
||||
l.Add(peerID)
|
||||
post := time.Now()
|
||||
assert.True(t, l.db[peerID].After(pre))
|
||||
assert.True(t, l.db[peerID].Before(post))
|
||||
|
|
|
@ -27,7 +27,6 @@ import (
|
|||
"github.com/ethereum/go-ethereum/common"
|
||||
"github.com/ethereum/go-ethereum/crypto"
|
||||
"github.com/ethereum/go-ethereum/log"
|
||||
"github.com/ethereum/go-ethereum/metrics"
|
||||
"github.com/ethereum/go-ethereum/rlp"
|
||||
"github.com/status-im/status-go/db"
|
||||
"github.com/status-im/status-go/params"
|
||||
|
@ -46,28 +45,9 @@ const (
|
|||
var (
|
||||
errDirectoryNotProvided = errors.New("data directory not provided")
|
||||
errDecryptionMethodNotProvided = errors.New("decryption method is not provided")
|
||||
// By default go-ethereum/metrics creates dummy metrics that don't register anything.
|
||||
// Real metrics are collected only if -metrics flag is set
|
||||
requestProcessTimer = metrics.NewRegisteredTimer("mailserver/requestProcessTime", nil)
|
||||
requestProcessNetTimer = metrics.NewRegisteredTimer("mailserver/requestProcessNetTime", nil)
|
||||
requestsMeter = metrics.NewRegisteredMeter("mailserver/requests", nil)
|
||||
requestsBatchedCounter = metrics.NewRegisteredCounter("mailserver/requestsBatched", nil)
|
||||
requestErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestErrors", nil)
|
||||
sentEnvelopesMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopes", nil)
|
||||
sentEnvelopesSizeMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopesSize", nil)
|
||||
archivedMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopes", nil)
|
||||
archivedSizeMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopesSize", nil)
|
||||
archivedErrorsCounter = metrics.NewRegisteredCounter("mailserver/archiveErrors", nil)
|
||||
requestValidationErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestValidationErrors", nil)
|
||||
processRequestErrorsCounter = metrics.NewRegisteredCounter("mailserver/processRequestErrors", nil)
|
||||
historicResponseErrorsCounter = metrics.NewRegisteredCounter("mailserver/historicResponseErrors", nil)
|
||||
syncRequestsMeter = metrics.NewRegisteredMeter("mailserver/syncRequests", nil)
|
||||
)
|
||||
|
||||
const (
|
||||
// DBKeyLength is a size of the envelope key.
|
||||
DBKeyLength = common.HashLength + timestampLength
|
||||
|
||||
timestampLength = 4
|
||||
requestLimitLength = 4
|
||||
requestTimeRangeLength = timestampLength * 2
|
||||
|
@ -95,41 +75,8 @@ type WMailServer struct {
|
|||
symFilter *whisper.Filter
|
||||
asymFilter *whisper.Filter
|
||||
|
||||
muLimiter sync.RWMutex
|
||||
limiter *limiter
|
||||
tick *ticker
|
||||
}
|
||||
|
||||
// DBKey key to be stored on db.
|
||||
type DBKey struct {
|
||||
timestamp uint32
|
||||
hash common.Hash
|
||||
raw []byte
|
||||
}
|
||||
|
||||
// Bytes returns a bytes representation of the DBKey.
|
||||
func (k *DBKey) Bytes() []byte {
|
||||
return k.raw
|
||||
}
|
||||
|
||||
// NewDBKey creates a new DBKey with the given values.
|
||||
func NewDBKey(t uint32, h common.Hash) *DBKey {
|
||||
var k DBKey
|
||||
k.timestamp = t
|
||||
k.hash = h
|
||||
k.raw = make([]byte, DBKeyLength)
|
||||
binary.BigEndian.PutUint32(k.raw, k.timestamp)
|
||||
copy(k.raw[4:], k.hash[:])
|
||||
return &k
|
||||
}
|
||||
|
||||
// NewDBKeyFromBytes creates a DBKey from a byte slice.
|
||||
func NewDBKeyFromBytes(b []byte) *DBKey {
|
||||
return &DBKey{
|
||||
raw: b,
|
||||
timestamp: binary.BigEndian.Uint32(b),
|
||||
hash: common.BytesToHash(b[4:]),
|
||||
}
|
||||
muRateLimiter sync.RWMutex
|
||||
rateLimiter *rateLimiter
|
||||
}
|
||||
|
||||
// Init initializes mailServer.
|
||||
|
@ -150,7 +97,7 @@ func (s *WMailServer) Init(shh *whisper.Whisper, config *params.WhisperConfig) e
|
|||
if err := s.setupRequestMessageDecryptor(config); err != nil {
|
||||
return err
|
||||
}
|
||||
s.setupLimiter(time.Duration(config.MailServerRateLimit) * time.Second)
|
||||
s.setupRateLimiter(time.Duration(config.MailServerRateLimit) * time.Second)
|
||||
|
||||
// Open database in the last step in order not to init with error
|
||||
// and leave the database open by accident.
|
||||
|
@ -163,12 +110,12 @@ func (s *WMailServer) Init(shh *whisper.Whisper, config *params.WhisperConfig) e
|
|||
return nil
|
||||
}
|
||||
|
||||
// setupLimiter in case limit is bigger than 0 it will setup an automated
|
||||
// setupRateLimiter in case limit is bigger than 0 it will setup an automated
|
||||
// limit db cleanup.
|
||||
func (s *WMailServer) setupLimiter(limit time.Duration) {
|
||||
func (s *WMailServer) setupRateLimiter(limit time.Duration) {
|
||||
if limit > 0 {
|
||||
s.limiter = newLimiter(limit)
|
||||
s.setupMailServerCleanup(limit)
|
||||
s.rateLimiter = newRateLimiter(limit)
|
||||
s.rateLimiter.Start()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -203,15 +150,6 @@ func (s *WMailServer) setupRequestMessageDecryptor(config *params.WhisperConfig)
|
|||
return nil
|
||||
}
|
||||
|
||||
// setupMailServerCleanup periodically runs an expired entries deleteion for
|
||||
// stored limits.
|
||||
func (s *WMailServer) setupMailServerCleanup(period time.Duration) {
|
||||
if s.tick == nil {
|
||||
s.tick = &ticker{}
|
||||
}
|
||||
go s.tick.run(period, s.limiter.deleteExpired)
|
||||
}
|
||||
|
||||
// Close the mailserver and its associated db connection.
|
||||
func (s *WMailServer) Close() {
|
||||
if s.db != nil {
|
||||
|
@ -219,8 +157,8 @@ func (s *WMailServer) Close() {
|
|||
log.Error(fmt.Sprintf("s.db.Close failed: %s", err))
|
||||
}
|
||||
}
|
||||
if s.tick != nil {
|
||||
s.tick.stop()
|
||||
if s.rateLimiter != nil {
|
||||
s.rateLimiter.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -450,19 +388,22 @@ func (s *WMailServer) SyncMail(peer *whisper.Peer, request whisper.SyncMailReque
|
|||
// exceedsPeerRequests in case limit its been setup on the current server and limit
|
||||
// allows the query, it will store/update new query time for the current peer.
|
||||
func (s *WMailServer) exceedsPeerRequests(peer []byte) bool {
|
||||
s.muLimiter.RLock()
|
||||
defer s.muLimiter.RUnlock()
|
||||
s.muRateLimiter.RLock()
|
||||
defer s.muRateLimiter.RUnlock()
|
||||
|
||||
if s.rateLimiter == nil {
|
||||
return false
|
||||
}
|
||||
|
||||
if s.limiter != nil {
|
||||
peerID := string(peer)
|
||||
if !s.limiter.isAllowed(peerID) {
|
||||
if s.rateLimiter.IsAllowed(peerID) {
|
||||
s.rateLimiter.Add(peerID)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Info("peerID exceeded the number of requests per second")
|
||||
return true
|
||||
}
|
||||
s.limiter.add(peerID)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (s *WMailServer) createIterator(lower, upper uint32, cursor []byte) iterator.Iterator {
|
||||
var (
|
||||
|
@ -472,7 +413,7 @@ func (s *WMailServer) createIterator(lower, upper uint32, cursor []byte) iterato
|
|||
|
||||
kl = NewDBKey(lower, emptyHash)
|
||||
if len(cursor) == DBKeyLength {
|
||||
ku = NewDBKeyFromBytes(cursor)
|
||||
ku = mustNewDBKeyFromBytes(cursor)
|
||||
} else {
|
||||
ku = NewDBKey(upper+1, emptyHash)
|
||||
}
|
||||
|
|
|
@ -173,7 +173,7 @@ func (s *MailserverSuite) TestInit() {
|
|||
}
|
||||
|
||||
if tc.config.MailServerRateLimit > 0 {
|
||||
s.NotNil(mailServer.limiter)
|
||||
s.NotNil(mailServer.rateLimiter)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
@ -273,15 +273,15 @@ func (s *MailserverSuite) TestArchive() {
|
|||
}
|
||||
|
||||
func (s *MailserverSuite) TestManageLimits() {
|
||||
s.server.limiter = newLimiter(time.Duration(5) * time.Millisecond)
|
||||
s.server.rateLimiter = newRateLimiter(time.Duration(5) * time.Millisecond)
|
||||
s.False(s.server.exceedsPeerRequests([]byte("peerID")))
|
||||
s.Equal(1, len(s.server.limiter.db))
|
||||
firstSaved := s.server.limiter.db["peerID"]
|
||||
s.Equal(1, len(s.server.rateLimiter.db))
|
||||
firstSaved := s.server.rateLimiter.db["peerID"]
|
||||
|
||||
// second call when limit is not accomplished does not store a new limit
|
||||
s.True(s.server.exceedsPeerRequests([]byte("peerID")))
|
||||
s.Equal(1, len(s.server.limiter.db))
|
||||
s.Equal(firstSaved, s.server.limiter.db["peerID"])
|
||||
s.Equal(1, len(s.server.rateLimiter.db))
|
||||
s.Equal(firstSaved, s.server.rateLimiter.db["peerID"])
|
||||
}
|
||||
|
||||
func (s *MailserverSuite) TestDBKey() {
|
||||
|
|
|
@ -0,0 +1,22 @@
|
|||
package mailserver
|
||||
|
||||
import "github.com/ethereum/go-ethereum/metrics"
|
||||
|
||||
var (
|
||||
// By default go-ethereum/metrics creates dummy metrics that don't register anything.
|
||||
// Real metrics are collected only if -metrics flag is set
|
||||
requestProcessTimer = metrics.NewRegisteredTimer("mailserver/requestProcessTime", nil)
|
||||
requestProcessNetTimer = metrics.NewRegisteredTimer("mailserver/requestProcessNetTime", nil)
|
||||
requestsMeter = metrics.NewRegisteredMeter("mailserver/requests", nil)
|
||||
requestsBatchedCounter = metrics.NewRegisteredCounter("mailserver/requestsBatched", nil)
|
||||
requestErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestErrors", nil)
|
||||
sentEnvelopesMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopes", nil)
|
||||
sentEnvelopesSizeMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopesSize", nil)
|
||||
archivedMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopes", nil)
|
||||
archivedSizeMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopesSize", nil)
|
||||
archivedErrorsCounter = metrics.NewRegisteredCounter("mailserver/archiveErrors", nil)
|
||||
requestValidationErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestValidationErrors", nil)
|
||||
processRequestErrorsCounter = metrics.NewRegisteredCounter("mailserver/processRequestErrors", nil)
|
||||
historicResponseErrorsCounter = metrics.NewRegisteredCounter("mailserver/historicResponseErrors", nil)
|
||||
syncRequestsMeter = metrics.NewRegisteredMeter("mailserver/syncRequests", nil)
|
||||
)
|
|
@ -1,33 +0,0 @@
|
|||
package mailserver
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type ticker struct {
|
||||
mu sync.RWMutex
|
||||
timeTicker *time.Ticker
|
||||
}
|
||||
|
||||
func (t *ticker) run(period time.Duration, fn func()) {
|
||||
if t.timeTicker != nil {
|
||||
return
|
||||
}
|
||||
|
||||
tt := time.NewTicker(period)
|
||||
t.mu.Lock()
|
||||
t.timeTicker = tt
|
||||
t.mu.Unlock()
|
||||
go func() {
|
||||
for range tt.C {
|
||||
fn()
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (t *ticker) stop() {
|
||||
t.mu.RLock()
|
||||
t.timeTicker.Stop()
|
||||
t.mu.RUnlock()
|
||||
}
|
|
@ -448,6 +448,9 @@ func (s *PeerPoolSimulationSuite) TestUpdateTopicLimits() {
|
|||
func (s *PeerPoolSimulationSuite) TestMailServerPeersDiscovery() {
|
||||
s.setupEthV5()
|
||||
|
||||
// eliminate peer we won't use
|
||||
s.peers[2].Stop()
|
||||
|
||||
// Buffered channels must be used because we expect the events
|
||||
// to be in the same order. Use a buffer length greater than
|
||||
// the expected number of events to avoid deadlock.
|
||||
|
@ -515,5 +518,4 @@ func (s *PeerPoolSimulationSuite) TestMailServerPeersDiscovery() {
|
|||
disconnectedPeer := s.getPeerFromEvent(events, p2p.PeerEventTypeDrop)
|
||||
s.Equal(s.peers[0].Self().ID().String(), disconnectedPeer.String())
|
||||
s.Equal(signal.EventDiscoverySummary, s.getPoolEvent(poolEvents))
|
||||
s.Len(<-summaries, 0)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue