Implement rate limiter and integrate it with whisper

Implement rate limiter and integrate it with whisper
This commit is contained in:
Dmitry 2018-09-25 11:38:39 +03:00
parent c72a926c11
commit 64f159412d
9 changed files with 737 additions and 2 deletions

91
ratelimiter/peer.go Normal file
View File

@ -0,0 +1,91 @@
package ratelimiter
import (
"net"
"strings"
"time"
"github.com/ethereum/go-ethereum/p2p"
)
const (
// IDMode enables rate limiting based on peers public key identity.
IDMode = 1 + iota
// IPMode enables rate limiting based on peer external ip address.
IPMode
)
func ipModeFunc(peer *p2p.Peer) []byte {
addr := peer.RemoteAddr().Network()
ip := net.ParseIP(strings.Split(addr, ":")[0])
return []byte(ip)
}
func idModeFunc(peer *p2p.Peer) []byte {
return peer.ID().Bytes()
}
// selectFunc returns idModeFunc by default.
func selectFunc(mode int) func(*p2p.Peer) []byte {
if mode == IPMode {
return ipModeFunc
}
return idModeFunc
}
// NewP2PRateLimiter returns an instance of P2PRateLimiter.
func NewP2PRateLimiter(mode int, ratelimiter Interface) P2PRateLimiter {
return P2PRateLimiter{
modeFunc: selectFunc(mode),
ratelimiter: ratelimiter,
}
}
// P2PRateLimiter implements rate limiter that accepts p2p.Peer as identifier.
type P2PRateLimiter struct {
modeFunc func(*p2p.Peer) []byte
ratelimiter Interface
}
func (r P2PRateLimiter) Config() Config {
return r.ratelimiter.Config()
}
func (r P2PRateLimiter) Create(peer *p2p.Peer) error {
return r.ratelimiter.Create(r.modeFunc(peer))
}
func (r P2PRateLimiter) Remove(peer *p2p.Peer, duration time.Duration) error {
return r.ratelimiter.Remove(r.modeFunc(peer), duration)
}
func (r P2PRateLimiter) TakeAvailable(peer *p2p.Peer, count int64) int64 {
return r.ratelimiter.TakeAvailable(r.modeFunc(peer), count)
}
func (r P2PRateLimiter) Available(peer *p2p.Peer) int64 {
return r.ratelimiter.Available(r.modeFunc(peer))
}
func (r P2PRateLimiter) UpdateConfig(peer *p2p.Peer, config Config) error {
return r.ratelimiter.UpdateConfig(r.modeFunc(peer), config)
}
type Whisper struct {
ingress, egress P2PRateLimiter
}
func ForWhisper(mode int, db DBInterface, ingress, egress Config) Whisper {
return Whisper{
ingress: NewP2PRateLimiter(mode, NewPersisted(db, ingress, []byte("i"))),
egress: NewP2PRateLimiter(mode, NewPersisted(db, egress, []byte("e"))),
}
}
func (w Whisper) I() P2PRateLimiter {
return w.ingress
}
func (w Whisper) E() P2PRateLimiter {
return w.egress
}

29
ratelimiter/peer_test.go Normal file
View File

@ -0,0 +1,29 @@
package ratelimiter
import (
"testing"
"time"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/require"
)
func TestIDMode(t *testing.T) {
cfg := Config{}
peer := p2p.NewPeer(discover.NodeID{1}, "test", nil)
ctrl := gomock.NewController(t)
rl := NewMockInterface(ctrl)
rl.EXPECT().Create(peer.ID().Bytes())
rl.EXPECT().TakeAvailable(peer.ID().Bytes(), int64(0))
rl.EXPECT().Available(peer.ID().Bytes())
rl.EXPECT().Remove(peer.ID().Bytes(), time.Duration(0))
rl.EXPECT().UpdateConfig(peer.ID().Bytes(), cfg)
peerrl := NewP2PRateLimiter(IDMode, rl)
require.NoError(t, peerrl.Create(peer))
peerrl.TakeAvailable(peer, 0)
peerrl.Available(peer)
require.NoError(t, peerrl.Remove(peer, 0))
require.NoError(t, peerrl.UpdateConfig(peer, cfg))
}

201
ratelimiter/ratelimiter.go Normal file
View File

@ -0,0 +1,201 @@
package ratelimiter
import (
"encoding/binary"
"fmt"
"sync"
"time"
"github.com/ethereum/go-ethereum/log"
"github.com/ethereum/go-ethereum/rlp"
"github.com/juju/ratelimit"
"github.com/status-im/status-go/db"
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/opt"
)
// Interface describes common interface methods.
type Interface interface {
Create([]byte) error
Remove([]byte, time.Duration) error
TakeAvailable([]byte, int64) int64
Available([]byte) int64
UpdateConfig([]byte, Config) error
Config() Config
}
// DBInterface defines leveldb methods used by ratelimiter.
type DBInterface interface {
Put(key, value []byte, wo *opt.WriteOptions) error
Get(key []byte, ro *opt.ReadOptions) (value []byte, err error)
Delete(key []byte, wo *opt.WriteOptions) error
}
// Config is a set of options used by rate limiter.
type Config struct {
Interval, Capacity, Quantum uint64
}
// compare config with existing ratelimited bucket.
func compare(c Config, bucket *ratelimit.Bucket) bool {
return int64(c.Capacity) == bucket.Capacity() &&
1e9*float64(c.Quantum)/float64(c.Interval) == bucket.Rate()
}
func newBucket(c Config) *ratelimit.Bucket {
return ratelimit.NewBucketWithQuantum(time.Duration(c.Interval), int64(c.Capacity), int64(c.Quantum))
}
func NewPersisted(db DBInterface, config Config, prefix []byte) *PersistedRateLimiter {
return &PersistedRateLimiter{
db: db,
defaultConfig: config,
initialized: map[string]*ratelimit.Bucket{},
prefix: prefix,
timeFunc: time.Now,
}
}
// PersistedRateLimiter persists latest capacity and updated config per unique ID.
type PersistedRateLimiter struct {
db DBInterface
prefix []byte // TODO move prefix outside of the rate limiter using database interface
defaultConfig Config
mu sync.Mutex
initialized map[string]*ratelimit.Bucket
timeFunc func() time.Time
}
func (r *PersistedRateLimiter) blacklist(id []byte, duration time.Duration) error {
fkey := db.Key(db.RateLimitBlacklist, r.prefix, id)
buf := [8]byte{}
binary.BigEndian.PutUint64(buf[:], uint64(r.timeFunc().Add(duration).Unix()))
if err := r.db.Put(fkey, buf[:], nil); err != nil {
return fmt.Errorf("error blacklisting %x: %v", id, err)
}
return nil
}
func (r *PersistedRateLimiter) Config() Config {
return r.defaultConfig
}
func (r *PersistedRateLimiter) getOrCreate(id []byte, config Config) (bucket *ratelimit.Bucket) {
r.mu.Lock()
defer r.mu.Unlock()
old, exist := r.initialized[string(id)]
if !exist {
bucket = newBucket(config)
r.initialized[string(id)] = bucket
} else {
bucket = old
}
return
}
func (r *PersistedRateLimiter) Create(id []byte) error {
fkey := db.Key(db.RateLimitBlacklist, r.prefix, id)
val, err := r.db.Get(fkey, nil)
if err != leveldb.ErrNotFound {
deadline := binary.BigEndian.Uint64(val)
if deadline >= uint64(r.timeFunc().Unix()) {
return fmt.Errorf("identity %x is blacklisted", id)
}
r.db.Delete(fkey, nil)
}
fkey = db.Key(db.RateLimitConfig, r.prefix, id)
val, err = r.db.Get(fkey, nil)
var cfg Config
if err == leveldb.ErrNotFound {
cfg = r.defaultConfig
} else if err != nil {
log.Error("faield to read config from db. using default", "err", err)
cfg = r.defaultConfig
} else {
if err := rlp.DecodeBytes(val, &cfg); err != nil {
log.Error("failed to decode config. using default", "err", err)
cfg = r.defaultConfig
}
}
bucket := r.getOrCreate(id, cfg)
fkey = db.Key(db.RateLimitCapacity, r.prefix, id)
val, err = r.db.Get(fkey, nil)
if err == leveldb.ErrNotFound {
return nil
} else if len(val) != 16 {
log.Error("stored value is of unexpected length", "expected", 8, "stored", len(val))
return nil
}
bucket.TakeAvailable(int64(binary.BigEndian.Uint64(val[:8])))
// TODO refill rate limiter due to time difference. e.g. if record was stored at T and C seconds passed since T.
// we need to add RATE_PER_SECOND*C to a bucket
return nil
}
// Remove removes key from memory but ensures that the latest information is persisted.
func (r *PersistedRateLimiter) Remove(id []byte, duration time.Duration) error {
if duration != 0 {
if err := r.blacklist(id, duration); err != nil {
return err
}
}
r.mu.Lock()
bucket, exist := r.initialized[string(id)]
delete(r.initialized, string(id))
r.mu.Unlock()
if !exist || bucket == nil {
return nil
}
return r.store(id, bucket)
}
func (r *PersistedRateLimiter) store(id []byte, bucket *ratelimit.Bucket) error {
buf := [16]byte{}
binary.BigEndian.PutUint64(buf[:], uint64(bucket.Capacity()-bucket.Available()))
binary.BigEndian.PutUint64(buf[8:], uint64(r.timeFunc().Unix()))
err := r.db.Put(db.Key(db.RateLimitCapacity, r.prefix, id), buf[:], nil)
if err != nil {
return fmt.Errorf("failed to write current capacicity %d for id %x: %v",
bucket.Capacity(), id, err)
}
return nil
}
func (r *PersistedRateLimiter) TakeAvailable(id []byte, count int64) int64 {
bucket := r.getOrCreate(id, r.defaultConfig)
rst := bucket.TakeAvailable(count)
if err := r.store(id, bucket); err != nil {
log.Error(err.Error())
}
return rst
}
func (r *PersistedRateLimiter) Available(id []byte) int64 {
return r.getOrCreate(id, r.defaultConfig).Available()
}
func (r *PersistedRateLimiter) UpdateConfig(id []byte, config Config) error {
r.mu.Lock()
old, _ := r.initialized[string(id)]
if compare(config, old) {
r.mu.Unlock()
return nil
}
delete(r.initialized, string(id))
r.mu.Unlock()
taken := int64(0)
if old != nil {
taken = old.Capacity() - old.Available()
}
r.getOrCreate(id, config).TakeAvailable(taken)
fkey := db.Key(db.RateLimitConfig, r.prefix, id)
data, err := rlp.EncodeToBytes(config)
if err != nil {
log.Error("failed to update config", "cfg", config, "err", err)
return nil
}
r.db.Put(fkey, data, nil)
return nil
}

View File

@ -0,0 +1,107 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: ratelimiter/interface.go
// Package ratelimiter is a generated GoMock package.
package ratelimiter
import (
reflect "reflect"
time "time"
gomock "github.com/golang/mock/gomock"
)
// MockInterface is a mock of Interface interface
type MockInterface struct {
ctrl *gomock.Controller
recorder *MockInterfaceMockRecorder
}
// MockInterfaceMockRecorder is the mock recorder for MockInterface
type MockInterfaceMockRecorder struct {
mock *MockInterface
}
// NewMockInterface creates a new mock instance
func NewMockInterface(ctrl *gomock.Controller) *MockInterface {
mock := &MockInterface{ctrl: ctrl}
mock.recorder = &MockInterfaceMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockInterface) EXPECT() *MockInterfaceMockRecorder {
return m.recorder
}
// Create mocks base method
func (m *MockInterface) Create(arg0 []byte) error {
ret := m.ctrl.Call(m, "Create", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Create indicates an expected call of Create
func (mr *MockInterfaceMockRecorder) Create(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Create", reflect.TypeOf((*MockInterface)(nil).Create), arg0)
}
// Remove mocks base method
func (m *MockInterface) Remove(arg0 []byte, arg1 time.Duration) error {
ret := m.ctrl.Call(m, "Remove", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// Remove indicates an expected call of Remove
func (mr *MockInterfaceMockRecorder) Remove(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Remove", reflect.TypeOf((*MockInterface)(nil).Remove), arg0, arg1)
}
// TakeAvailable mocks base method
func (m *MockInterface) TakeAvailable(arg0 []byte, arg1 int64) int64 {
ret := m.ctrl.Call(m, "TakeAvailable", arg0, arg1)
ret0, _ := ret[0].(int64)
return ret0
}
// TakeAvailable indicates an expected call of TakeAvailable
func (mr *MockInterfaceMockRecorder) TakeAvailable(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TakeAvailable", reflect.TypeOf((*MockInterface)(nil).TakeAvailable), arg0, arg1)
}
// Available mocks base method
func (m *MockInterface) Available(arg0 []byte) int64 {
ret := m.ctrl.Call(m, "Available", arg0)
ret0, _ := ret[0].(int64)
return ret0
}
// Available indicates an expected call of Available
func (mr *MockInterfaceMockRecorder) Available(arg0 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Available", reflect.TypeOf((*MockInterface)(nil).Available), arg0)
}
// UpdateConfig mocks base method
func (m *MockInterface) UpdateConfig(arg0 []byte, arg1 Config) error {
ret := m.ctrl.Call(m, "UpdateConfig", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UpdateConfig indicates an expected call of UpdateConfig
func (mr *MockInterfaceMockRecorder) UpdateConfig(arg0, arg1 interface{}) *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateConfig", reflect.TypeOf((*MockInterface)(nil).UpdateConfig), arg0, arg1)
}
// Config mocks base method
func (m *MockInterface) Config() Config {
ret := m.ctrl.Call(m, "Config")
ret0, _ := ret[0].(Config)
return ret0
}
// Config indicates an expected call of Config
func (mr *MockInterfaceMockRecorder) Config() *gomock.Call {
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Config", reflect.TypeOf((*MockInterface)(nil).Config))
}

View File

@ -0,0 +1,64 @@
package ratelimiter
import (
"fmt"
"testing"
"time"
"github.com/stretchr/testify/require"
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/storage"
)
func TestLimitIsPersisted(t *testing.T) {
db, err := leveldb.Open(storage.NewMemStorage(), nil)
require.NoError(t, err)
var (
total int64 = 10000
rl = NewPersisted(db, Config{1 << 62, uint64(10000), 1}, nil)
tid = []byte("test")
)
require.NoError(t, rl.Create(tid))
taken := rl.TakeAvailable(tid, total/2)
require.Equal(t, total/2, taken)
require.NoError(t, rl.Remove(tid, 0))
require.NoError(t, rl.Create(tid))
require.Equal(t, total/2, rl.Available(tid))
}
func TestConfigIsPersistedAndFixedOnUpdate(t *testing.T) {
db, err := leveldb.Open(storage.NewMemStorage(), nil)
require.NoError(t, err)
var (
total int64 = 10000
cfg = Config{1 << 62, uint64(10000), 1}
rl = NewPersisted(db, cfg, nil)
tid = []byte("test")
)
require.NoError(t, rl.Create(tid))
taken := rl.TakeAvailable(tid, total/2)
require.Equal(t, total/2, taken)
cfg.Capacity = 6000
require.NoError(t, rl.UpdateConfig(tid, cfg))
require.Equal(t, int64(cfg.Capacity)-total/2, rl.Available(tid))
require.NoError(t, rl.Remove(tid, 0))
require.NoError(t, rl.Create(tid))
require.Equal(t, int64(cfg.Capacity)-total/2, rl.Available(tid))
}
func TestBlacklistedEntityReturnsError(t *testing.T) {
db, err := leveldb.Open(storage.NewMemStorage(), nil)
require.NoError(t, err)
var (
cfg = Config{1 << 62, uint64(10000), 1}
rl = NewPersisted(db, cfg, nil)
tid = []byte("test")
)
require.NoError(t, rl.Create(tid))
require.NoError(t, rl.Remove(tid, 10*time.Minute))
require.EqualError(t, fmt.Errorf("identity %x is blacklisted", tid), rl.Create(tid).Error())
rl.timeFunc = func() time.Time {
return time.Now().Add(11 * time.Minute)
}
require.NoError(t, rl.Create(tid))
}

View File

@ -47,6 +47,7 @@ const (
messagesCode = 1 // normal whisper message
powRequirementCode = 2 // PoW requirement
bloomFilterExCode = 3 // bloom filter exchange
peerRateLimitCode = 8 // update of the peer rate limit
p2pRequestCompleteCode = 125 // peer-to-peer message, used by Dapp protocol
p2pRequestCode = 126 // peer-to-peer message, used by Dapp protocol
p2pMessageCode = 127 // peer-to-peer message (to be consumed by the peer, but not forwarded any further)

View File

@ -19,6 +19,7 @@ package whisperv6
import (
"fmt"
"math"
"math/rand"
"sync"
"time"
@ -192,6 +193,23 @@ func (peer *Peer) expire() {
}
}
func (peer *Peer) reduceBundle(bundle []*Envelope) []*Envelope {
if peer.host.ratelimiter == nil {
return bundle
}
rand.Shuffle(len(bundle), func(i, j int) {
bundle[i], bundle[j] = bundle[j], bundle[i]
})
for i := range bundle {
size := int64(bundle[i].size())
if peer.host.ratelimiter.E().Available(peer.peer) < size {
return bundle[:i]
}
peer.host.ratelimiter.E().TakeAvailable(peer.peer, size)
}
return bundle
}
// broadcast iterates over the collection of envelopes and transmits yet unknown
// ones over the network.
func (peer *Peer) broadcast() error {
@ -199,13 +217,14 @@ func (peer *Peer) broadcast() error {
log.Trace("Waiting for a peer to restore communication", "ID", peer.peer.ID())
return nil
}
envelopes := peer.host.Envelopes()
envelopes := peer.host.Envelopes() // envelopes are read from hash map, so access is already randomized
bundle := make([]*Envelope, 0, len(envelopes))
for _, envelope := range envelopes {
if !peer.marked(envelope) && envelope.PoW() >= peer.powRequirement && peer.bloomMatch(envelope) {
bundle = append(bundle, envelope)
}
}
bundle = peer.reduceBundle(bundle)
if len(bundle) > 0 {
// transmit the batch of envelopes

177
whisperv6/ratelimit_test.go Normal file
View File

@ -0,0 +1,177 @@
package whisperv6
import (
"math"
"math/rand"
"sync"
"testing"
"time"
"github.com/ethereum/go-ethereum/common"
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/p2p/discover"
"github.com/status-im/status-go/ratelimiter"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/syndtr/goleveldb/leveldb"
"github.com/syndtr/goleveldb/leveldb/storage"
)
const (
testCode = 42 // any non-defined code will work
)
func setupOneConnection(t *testing.T, rlconf ratelimiter.Config) (*Whisper, *p2p.MsgPipeRW, chan error) {
db, err := leveldb.Open(storage.NewMemStorage(), nil)
require.NoError(t, err)
rl := ratelimiter.ForWhisper(ratelimiter.IDMode, db, rlconf, rlconf)
conf := &Config{
MinimumAcceptedPOW: 0,
MaxMessageSize: 100 << 10,
}
w := New(conf)
w.UseRateLimiter(rl)
idx, _ := discover.BytesID([]byte{0x01})
p := p2p.NewPeer(idx, "1", []p2p.Cap{{"shh", 6}})
rw1, rw2 := p2p.MsgPipe()
errorc := make(chan error, 1)
go func() {
errorc <- w.HandlePeer(p, rw2)
}()
msg, err := rw1.ReadMsg()
require.NoError(t, err)
require.Equal(t, uint64(0), msg.Code)
require.NoError(t, msg.Discard())
require.NoError(t, p2p.SendItems(rw1, statusCode, ProtocolVersion, math.Float64bits(w.MinPow()), w.BloomFilter(), true))
require.NoError(t, p2p.ExpectMsg(rw1, peerRateLimitCode, nil), "peer must send ingress rate limit after handshake")
return w, rw1, errorc
}
func TestRatePeerDropsConnection(t *testing.T) {
cfg := ratelimiter.Config{Interval: uint64(time.Hour), Capacity: 10 << 10, Quantum: 1 << 10}
_, rw1, errorc := setupOneConnection(t, cfg)
require.NoError(t, p2p.Send(rw1, testCode, make([]byte, 11<<10))) // limit is 1024
select {
case err := <-errorc:
require.Error(t, err)
case <-time.After(time.Second):
require.FailNow(t, "failed waiting for HandlePeer to exit")
}
}
func TestRateLimitedDelivery(t *testing.T) {
cfg := ratelimiter.Config{Interval: uint64(time.Hour), Capacity: 10 << 10, Quantum: 1 << 10}
w, rw1, _ := setupOneConnection(t, cfg)
small1 := Envelope{
Expiry: uint32(time.Now().Add(10 * time.Second).Unix()),
TTL: 10,
Topic: TopicType{1},
Data: make([]byte, 1<<10),
Nonce: 1,
}
small2 := small1
small2.Nonce = 2
big := small1
big.Nonce = 3
big.Data = make([]byte, 11<<10)
require.NoError(t, w.Send(&small1))
require.NoError(t, w.Send(&big))
require.NoError(t, w.Send(&small2))
received := map[common.Hash]struct{}{}
// we can not guarantee that all expected envelopes will be delivered in a one batch
// so allow whisper to write multiple times and read every message
go func() {
time.Sleep(time.Second)
rw1.Close()
}()
for {
msg, err := rw1.ReadMsg()
if err == p2p.ErrPipeClosed {
require.Contains(t, received, small1.Hash())
require.Contains(t, received, small2.Hash())
require.NotContains(t, received, big.Hash())
break
}
require.NoError(t, err)
require.Equal(t, uint64(1), msg.Code)
var rst []*Envelope
require.NoError(t, msg.Decode(&rst))
for _, e := range rst {
received[e.Hash()] = struct{}{}
}
}
}
func TestRateRandomizedDelivery(t *testing.T) {
cfg := ratelimiter.Config{Interval: uint64(time.Hour), Capacity: 10 << 10, Quantum: 1 << 10}
w1, rw1, _ := setupOneConnection(t, cfg)
w2, rw2, _ := setupOneConnection(t, cfg)
w3, rw3, _ := setupOneConnection(t, cfg)
var (
mu sync.Mutex
wg sync.WaitGroup
sent = map[common.Hash]int{}
received = map[int]int64{}
)
for i := uint64(1); i < 15; i++ {
env := &Envelope{
Expiry: uint32(time.Now().Add(10 * time.Second).Unix()),
TTL: 10,
Topic: TopicType{1},
Data: make([]byte, 1<<10-EnvelopeHeaderLength), // so that 10 envelopes are exactly 10kb
Nonce: i,
}
sent[env.Hash()] = 0
for _, w := range []*Whisper{w1, w2, w3} {
go func(w *Whisper, e *Envelope) {
time.Sleep(time.Duration(rand.Int63n(10)) * time.Millisecond)
assert.NoError(t, w.Send(e))
}(w, env)
}
}
for i, rw := range []*p2p.MsgPipeRW{rw1, rw2, rw3} {
received[i] = 0
wg.Add(1)
go func(rw *p2p.MsgPipeRW) {
time.Sleep(time.Second)
rw.Close()
wg.Done()
}(rw)
wg.Add(1)
go func(i int, rw *p2p.MsgPipeRW) {
defer wg.Done()
for {
msg, err := rw.ReadMsg()
if err != nil {
return
}
if !assert.Equal(t, uint64(1), msg.Code) {
return
}
var rst []*Envelope
if !assert.NoError(t, msg.Decode(&rst)) {
return
}
mu.Lock()
for _, e := range rst {
received[i] += int64(len(e.Data))
received[i] += EnvelopeHeaderLength
sent[e.Hash()]++
}
mu.Unlock()
}
}(i, rw)
}
wg.Wait()
for i := range received {
require.Equal(t, int64(10)<<10, received[i], "peer %d didnt' receive 10 kb of data: %d", i, received[i])
}
total := 0
for h := range sent {
total += sent[h]
}
require.Equal(t, 30, total)
}

View File

@ -34,6 +34,7 @@ import (
"github.com/ethereum/go-ethereum/p2p"
"github.com/ethereum/go-ethereum/rlp"
"github.com/ethereum/go-ethereum/rpc"
"github.com/status-im/status-go/ratelimiter"
"github.com/syndtr/goleveldb/leveldb/errors"
"golang.org/x/crypto/pbkdf2"
"golang.org/x/sync/syncmap"
@ -98,6 +99,8 @@ type Whisper struct {
envelopeFeed event.Feed
ratelimiter *ratelimiter.Whisper
timeSource func() time.Time // source of time for whisper
}
@ -150,6 +153,11 @@ func (whisper *Whisper) SetTimeSource(timesource func() time.Time) {
whisper.timeSource = timesource
}
// UseRateLimiter makes whisper to use a specific implementation of the rate limiter
func (whisper *Whisper) UseRateLimiter(ratelimiter ratelimiter.Whisper) {
whisper.ratelimiter = &ratelimiter
}
// SubscribeEnvelopeEvents subscribes to envelopes feed.
// In order to prevent blocking whisper producers events must be amply buffered.
func (whisper *Whisper) SubscribeEnvelopeEvents(events chan<- EnvelopeEvent) event.Subscription {
@ -767,12 +775,32 @@ func (whisper *Whisper) HandlePeer(peer *p2p.Peer, rw p2p.MsgReadWriter) error {
}
whisperPeer.start()
defer whisperPeer.stop()
if whisper.ratelimiter != nil {
if err := whisper.ratelimiter.I().Create(whisperPeer.peer); err != nil {
return err
}
defer whisper.ratelimiter.I().Remove(whisperPeer.peer, 0)
whisper.ratelimiter.E().Create(whisperPeer.peer)
defer whisper.ratelimiter.E().Remove(whisperPeer.peer, 0)
}
return whisper.runMessageLoop(whisperPeer, rw)
}
func (whisper *Whisper) advertiseEgressLimit(p *Peer, rw p2p.MsgReadWriter) error {
if whisper.ratelimiter == nil {
return nil
}
if err := p2p.Send(rw, peerRateLimitCode, whisper.ratelimiter.I().Config()); err != nil {
return fmt.Errorf("failed to send ingress rate limit to a peer %v: %v", p.peer.ID(), err)
}
return nil
}
// runMessageLoop reads and processes inbound messages directly to merge into client-global state.
func (whisper *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error {
if err := whisper.advertiseEgressLimit(p, rw); err != nil {
return err
}
for {
// fetch the next packet
packet, err := rw.ReadMsg()
@ -861,6 +889,17 @@ func (whisper *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error {
whisper.mailServer.DeliverMail(p, &request)
}
case peerRateLimitCode:
if whisper.ratelimiter == nil {
continue
}
var conf ratelimiter.Config
if err := packet.Decode(&conf); err != nil {
return fmt.Errorf("peer %v sent wrong payload for a rate limiter config", p.peer.ID())
}
if err := whisper.ratelimiter.E().UpdateConfig(p.peer, conf); err != nil {
log.Error("error updaing rate limiter config", "peer", p.peer)
}
case p2pRequestCompleteCode:
if p.trusted {
var payload []byte
@ -914,6 +953,13 @@ func (whisper *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error {
}
packet.Discard()
if packet.Code != p2pMessageCode && whisper.ratelimiter != nil {
if whisper.ratelimiter.I().TakeAvailable(p.peer, int64(packet.Size)) < int64(packet.Size) {
whisper.ratelimiter.I().Remove(p.peer, 10*time.Minute)
return fmt.Errorf("peer %v reached traffic limit capacity", p.peer.ID())
}
}
}
}