diff --git a/mailserver/db_key.go b/mailserver/db_key.go new file mode 100644 index 000000000..5b3e58b07 --- /dev/null +++ b/mailserver/db_key.go @@ -0,0 +1,65 @@ +package mailserver + +import ( + "encoding/binary" + "errors" + + "github.com/ethereum/go-ethereum/common" +) + +const ( + // DBKeyLength is a size of the envelope key. + DBKeyLength = common.HashLength + timestampLength +) + +var ( + // ErrInvalidByteSize is returned when DBKey can't be created + // from a byte slice because it has invalid length. + ErrInvalidByteSize = errors.New("byte slice has invalid length") +) + +// DBKey key to be stored in a db. +type DBKey struct { + timestamp uint32 + hash common.Hash + raw []byte +} + +// Bytes returns a bytes representation of the DBKey. +func (k *DBKey) Bytes() []byte { + return k.raw +} + +// NewDBKey creates a new DBKey with the given values. +func NewDBKey(timestamp uint32, h common.Hash) *DBKey { + var k DBKey + k.timestamp = timestamp + k.hash = h + k.raw = make([]byte, DBKeyLength) + binary.BigEndian.PutUint32(k.raw, k.timestamp) + copy(k.raw[4:], k.hash[:]) + return &k +} + +// NewDBKeyFromBytes creates a DBKey from a byte slice. +func NewDBKeyFromBytes(b []byte) (*DBKey, error) { + if len(b) != DBKeyLength { + return nil, ErrInvalidByteSize + } + + return &DBKey{ + raw: b, + timestamp: binary.BigEndian.Uint32(b), + hash: common.BytesToHash(b[4:]), + }, nil +} + +// mustNewDBKeyFromBytes panics if creating a key from a byte slice fails. +// Check if a byte slice has DBKeyLength length before using it. +func mustNewDBKeyFromBytes(b []byte) *DBKey { + k, err := NewDBKeyFromBytes(b) + if err != nil { + panic(err) + } + return k +} diff --git a/mailserver/limiter.go b/mailserver/limiter.go index 12616ea7a..64310b610 100644 --- a/mailserver/limiter.go +++ b/mailserver/limiter.go @@ -5,45 +5,83 @@ import ( "time" ) -type limiter struct { - mu sync.RWMutex +type rateLimiter struct { + sync.RWMutex - timeout time.Duration - db map[string]time.Time + lifespan time.Duration // duration of the limit + db map[string]time.Time + + period time.Duration + cancel chan struct{} } -func newLimiter(timeout time.Duration) *limiter { - return &limiter{ - timeout: timeout, - db: make(map[string]time.Time), +func newRateLimiter(duration time.Duration) *rateLimiter { + return &rateLimiter{ + lifespan: duration, + db: make(map[string]time.Time), + period: time.Second, } } -func (l *limiter) add(id string) { - l.mu.Lock() - defer l.mu.Unlock() +func (l *rateLimiter) Start() { + cancel := make(chan struct{}) - l.db[id] = time.Now() + l.Lock() + l.cancel = cancel + l.Unlock() + + go l.cleanUp(l.period, cancel) } -func (l *limiter) isAllowed(id string) bool { - l.mu.RLock() - defer l.mu.RUnlock() +func (l *rateLimiter) Stop() { + l.Lock() + defer l.Unlock() + + if l.cancel == nil { + return + } + close(l.cancel) + l.cancel = nil +} + +func (l *rateLimiter) Add(id string) { + l.Lock() + l.db[id] = time.Now() + l.Unlock() +} + +func (l *rateLimiter) IsAllowed(id string) bool { + l.RLock() + defer l.RUnlock() if lastRequestTime, ok := l.db[id]; ok { - return lastRequestTime.Add(l.timeout).Before(time.Now()) + return lastRequestTime.Add(l.lifespan).Before(time.Now()) } return true } -func (l *limiter) deleteExpired() { - l.mu.Lock() - defer l.mu.Unlock() +func (l *rateLimiter) cleanUp(period time.Duration, cancel <-chan struct{}) { + t := time.NewTicker(period) + defer t.Stop() + + for { + select { + case <-t.C: + l.deleteExpired() + case <-cancel: + return + } + } +} + +func (l *rateLimiter) deleteExpired() { + l.Lock() + defer l.Unlock() now := time.Now() for id, lastRequestTime := range l.db { - if lastRequestTime.Add(l.timeout).Before(now) { + if lastRequestTime.Add(l.lifespan).Before(now) { delete(l.db, id) } } diff --git a/mailserver/limiter_test.go b/mailserver/limiter_test.go index 0e3905785..2da3c7946 100644 --- a/mailserver/limiter_test.go +++ b/mailserver/limiter_test.go @@ -48,16 +48,16 @@ func TestIsAllowed(t *testing.T) { for _, tc := range testCases { t.Run(tc.info, func(*testing.T) { - l := newLimiter(tc.t) + l := newRateLimiter(tc.t) l.db = tc.db() - assert.Equal(t, tc.shouldBeAllowed, l.isAllowed(peerID)) + assert.Equal(t, tc.shouldBeAllowed, l.IsAllowed(peerID)) }) } } func TestRemoveExpiredRateLimits(t *testing.T) { peer := "peer" - l := newLimiter(time.Duration(5) * time.Second) + l := newRateLimiter(time.Duration(5) * time.Second) for i := 0; i < 10; i++ { peerID := fmt.Sprintf("%s%d", peer, i) l.db[peerID] = time.Now().Add(time.Duration(i*(-2)) * time.Second) @@ -78,11 +78,31 @@ func TestRemoveExpiredRateLimits(t *testing.T) { } } +func TestCleaningUpExpiredRateLimits(t *testing.T) { + l := newRateLimiter(5 * time.Second) + l.period = time.Millisecond * 10 + l.Start() + defer l.Stop() + + l.db["peer01"] = time.Now().Add(-1 * time.Second) + l.db["peer02"] = time.Now().Add(-2 * time.Second) + l.db["peer03"] = time.Now().Add(-10 * time.Second) + + time.Sleep(time.Millisecond * 20) + + _, ok := l.db["peer01"] + assert.True(t, ok) + _, ok = l.db["peer02"] + assert.True(t, ok) + _, ok = l.db["peer03"] + assert.False(t, ok) +} + func TestAddingLimts(t *testing.T) { peerID := "peerAdding" - l := newLimiter(time.Duration(5) * time.Second) + l := newRateLimiter(time.Duration(5) * time.Second) pre := time.Now() - l.add(peerID) + l.Add(peerID) post := time.Now() assert.True(t, l.db[peerID].After(pre)) assert.True(t, l.db[peerID].Before(post)) diff --git a/mailserver/mailserver.go b/mailserver/mailserver.go index 2ef00a29e..554c78797 100644 --- a/mailserver/mailserver.go +++ b/mailserver/mailserver.go @@ -27,7 +27,6 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/log" - "github.com/ethereum/go-ethereum/metrics" "github.com/ethereum/go-ethereum/rlp" "github.com/status-im/status-go/db" "github.com/status-im/status-go/params" @@ -46,28 +45,9 @@ const ( var ( errDirectoryNotProvided = errors.New("data directory not provided") errDecryptionMethodNotProvided = errors.New("decryption method is not provided") - // By default go-ethereum/metrics creates dummy metrics that don't register anything. - // Real metrics are collected only if -metrics flag is set - requestProcessTimer = metrics.NewRegisteredTimer("mailserver/requestProcessTime", nil) - requestProcessNetTimer = metrics.NewRegisteredTimer("mailserver/requestProcessNetTime", nil) - requestsMeter = metrics.NewRegisteredMeter("mailserver/requests", nil) - requestsBatchedCounter = metrics.NewRegisteredCounter("mailserver/requestsBatched", nil) - requestErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestErrors", nil) - sentEnvelopesMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopes", nil) - sentEnvelopesSizeMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopesSize", nil) - archivedMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopes", nil) - archivedSizeMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopesSize", nil) - archivedErrorsCounter = metrics.NewRegisteredCounter("mailserver/archiveErrors", nil) - requestValidationErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestValidationErrors", nil) - processRequestErrorsCounter = metrics.NewRegisteredCounter("mailserver/processRequestErrors", nil) - historicResponseErrorsCounter = metrics.NewRegisteredCounter("mailserver/historicResponseErrors", nil) - syncRequestsMeter = metrics.NewRegisteredMeter("mailserver/syncRequests", nil) ) const ( - // DBKeyLength is a size of the envelope key. - DBKeyLength = common.HashLength + timestampLength - timestampLength = 4 requestLimitLength = 4 requestTimeRangeLength = timestampLength * 2 @@ -95,41 +75,8 @@ type WMailServer struct { symFilter *whisper.Filter asymFilter *whisper.Filter - muLimiter sync.RWMutex - limiter *limiter - tick *ticker -} - -// DBKey key to be stored on db. -type DBKey struct { - timestamp uint32 - hash common.Hash - raw []byte -} - -// Bytes returns a bytes representation of the DBKey. -func (k *DBKey) Bytes() []byte { - return k.raw -} - -// NewDBKey creates a new DBKey with the given values. -func NewDBKey(t uint32, h common.Hash) *DBKey { - var k DBKey - k.timestamp = t - k.hash = h - k.raw = make([]byte, DBKeyLength) - binary.BigEndian.PutUint32(k.raw, k.timestamp) - copy(k.raw[4:], k.hash[:]) - return &k -} - -// NewDBKeyFromBytes creates a DBKey from a byte slice. -func NewDBKeyFromBytes(b []byte) *DBKey { - return &DBKey{ - raw: b, - timestamp: binary.BigEndian.Uint32(b), - hash: common.BytesToHash(b[4:]), - } + muRateLimiter sync.RWMutex + rateLimiter *rateLimiter } // Init initializes mailServer. @@ -150,7 +97,7 @@ func (s *WMailServer) Init(shh *whisper.Whisper, config *params.WhisperConfig) e if err := s.setupRequestMessageDecryptor(config); err != nil { return err } - s.setupLimiter(time.Duration(config.MailServerRateLimit) * time.Second) + s.setupRateLimiter(time.Duration(config.MailServerRateLimit) * time.Second) // Open database in the last step in order not to init with error // and leave the database open by accident. @@ -163,12 +110,12 @@ func (s *WMailServer) Init(shh *whisper.Whisper, config *params.WhisperConfig) e return nil } -// setupLimiter in case limit is bigger than 0 it will setup an automated +// setupRateLimiter in case limit is bigger than 0 it will setup an automated // limit db cleanup. -func (s *WMailServer) setupLimiter(limit time.Duration) { +func (s *WMailServer) setupRateLimiter(limit time.Duration) { if limit > 0 { - s.limiter = newLimiter(limit) - s.setupMailServerCleanup(limit) + s.rateLimiter = newRateLimiter(limit) + s.rateLimiter.Start() } } @@ -203,15 +150,6 @@ func (s *WMailServer) setupRequestMessageDecryptor(config *params.WhisperConfig) return nil } -// setupMailServerCleanup periodically runs an expired entries deleteion for -// stored limits. -func (s *WMailServer) setupMailServerCleanup(period time.Duration) { - if s.tick == nil { - s.tick = &ticker{} - } - go s.tick.run(period, s.limiter.deleteExpired) -} - // Close the mailserver and its associated db connection. func (s *WMailServer) Close() { if s.db != nil { @@ -219,8 +157,8 @@ func (s *WMailServer) Close() { log.Error(fmt.Sprintf("s.db.Close failed: %s", err)) } } - if s.tick != nil { - s.tick.stop() + if s.rateLimiter != nil { + s.rateLimiter.Stop() } } @@ -450,18 +388,21 @@ func (s *WMailServer) SyncMail(peer *whisper.Peer, request whisper.SyncMailReque // exceedsPeerRequests in case limit its been setup on the current server and limit // allows the query, it will store/update new query time for the current peer. func (s *WMailServer) exceedsPeerRequests(peer []byte) bool { - s.muLimiter.RLock() - defer s.muLimiter.RUnlock() + s.muRateLimiter.RLock() + defer s.muRateLimiter.RUnlock() - if s.limiter != nil { - peerID := string(peer) - if !s.limiter.isAllowed(peerID) { - log.Info("peerID exceeded the number of requests per second") - return true - } - s.limiter.add(peerID) + if s.rateLimiter == nil { + return false } - return false + + peerID := string(peer) + if s.rateLimiter.IsAllowed(peerID) { + s.rateLimiter.Add(peerID) + return false + } + + log.Info("peerID exceeded the number of requests per second") + return true } func (s *WMailServer) createIterator(lower, upper uint32, cursor []byte) iterator.Iterator { @@ -472,7 +413,7 @@ func (s *WMailServer) createIterator(lower, upper uint32, cursor []byte) iterato kl = NewDBKey(lower, emptyHash) if len(cursor) == DBKeyLength { - ku = NewDBKeyFromBytes(cursor) + ku = mustNewDBKeyFromBytes(cursor) } else { ku = NewDBKey(upper+1, emptyHash) } diff --git a/mailserver/mailserver_test.go b/mailserver/mailserver_test.go index 10c426b19..f90e83a9d 100644 --- a/mailserver/mailserver_test.go +++ b/mailserver/mailserver_test.go @@ -173,7 +173,7 @@ func (s *MailserverSuite) TestInit() { } if tc.config.MailServerRateLimit > 0 { - s.NotNil(mailServer.limiter) + s.NotNil(mailServer.rateLimiter) } }) } @@ -273,15 +273,15 @@ func (s *MailserverSuite) TestArchive() { } func (s *MailserverSuite) TestManageLimits() { - s.server.limiter = newLimiter(time.Duration(5) * time.Millisecond) + s.server.rateLimiter = newRateLimiter(time.Duration(5) * time.Millisecond) s.False(s.server.exceedsPeerRequests([]byte("peerID"))) - s.Equal(1, len(s.server.limiter.db)) - firstSaved := s.server.limiter.db["peerID"] + s.Equal(1, len(s.server.rateLimiter.db)) + firstSaved := s.server.rateLimiter.db["peerID"] // second call when limit is not accomplished does not store a new limit s.True(s.server.exceedsPeerRequests([]byte("peerID"))) - s.Equal(1, len(s.server.limiter.db)) - s.Equal(firstSaved, s.server.limiter.db["peerID"]) + s.Equal(1, len(s.server.rateLimiter.db)) + s.Equal(firstSaved, s.server.rateLimiter.db["peerID"]) } func (s *MailserverSuite) TestDBKey() { diff --git a/mailserver/metrics.go b/mailserver/metrics.go new file mode 100644 index 000000000..49b2b8a0f --- /dev/null +++ b/mailserver/metrics.go @@ -0,0 +1,22 @@ +package mailserver + +import "github.com/ethereum/go-ethereum/metrics" + +var ( + // By default go-ethereum/metrics creates dummy metrics that don't register anything. + // Real metrics are collected only if -metrics flag is set + requestProcessTimer = metrics.NewRegisteredTimer("mailserver/requestProcessTime", nil) + requestProcessNetTimer = metrics.NewRegisteredTimer("mailserver/requestProcessNetTime", nil) + requestsMeter = metrics.NewRegisteredMeter("mailserver/requests", nil) + requestsBatchedCounter = metrics.NewRegisteredCounter("mailserver/requestsBatched", nil) + requestErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestErrors", nil) + sentEnvelopesMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopes", nil) + sentEnvelopesSizeMeter = metrics.NewRegisteredMeter("mailserver/sentEnvelopesSize", nil) + archivedMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopes", nil) + archivedSizeMeter = metrics.NewRegisteredMeter("mailserver/archivedEnvelopesSize", nil) + archivedErrorsCounter = metrics.NewRegisteredCounter("mailserver/archiveErrors", nil) + requestValidationErrorsCounter = metrics.NewRegisteredCounter("mailserver/requestValidationErrors", nil) + processRequestErrorsCounter = metrics.NewRegisteredCounter("mailserver/processRequestErrors", nil) + historicResponseErrorsCounter = metrics.NewRegisteredCounter("mailserver/historicResponseErrors", nil) + syncRequestsMeter = metrics.NewRegisteredMeter("mailserver/syncRequests", nil) +) diff --git a/mailserver/ticker.go b/mailserver/ticker.go deleted file mode 100644 index 4a39d3d53..000000000 --- a/mailserver/ticker.go +++ /dev/null @@ -1,33 +0,0 @@ -package mailserver - -import ( - "sync" - "time" -) - -type ticker struct { - mu sync.RWMutex - timeTicker *time.Ticker -} - -func (t *ticker) run(period time.Duration, fn func()) { - if t.timeTicker != nil { - return - } - - tt := time.NewTicker(period) - t.mu.Lock() - t.timeTicker = tt - t.mu.Unlock() - go func() { - for range tt.C { - fn() - } - }() -} - -func (t *ticker) stop() { - t.mu.RLock() - t.timeTicker.Stop() - t.mu.RUnlock() -} diff --git a/peers/peerpool_test.go b/peers/peerpool_test.go index af47c1a25..6c93c1dd4 100644 --- a/peers/peerpool_test.go +++ b/peers/peerpool_test.go @@ -448,6 +448,9 @@ func (s *PeerPoolSimulationSuite) TestUpdateTopicLimits() { func (s *PeerPoolSimulationSuite) TestMailServerPeersDiscovery() { s.setupEthV5() + // eliminate peer we won't use + s.peers[2].Stop() + // Buffered channels must be used because we expect the events // to be in the same order. Use a buffer length greater than // the expected number of events to avoid deadlock. @@ -515,5 +518,4 @@ func (s *PeerPoolSimulationSuite) TestMailServerPeersDiscovery() { disconnectedPeer := s.getPeerFromEvent(events, p2p.PeerEventTypeDrop) s.Equal(s.peers[0].Self().ID().String(), disconnectedPeer.String()) s.Equal(signal.EventDiscoverySummary, s.getPoolEvent(poolEvents)) - s.Len(<-summaries, 0) }