diff --git a/whisperv6/whisper.go b/whisperv6/whisper.go index 362e808..e508cb1 100644 --- a/whisperv6/whisper.go +++ b/whisperv6/whisper.go @@ -21,6 +21,8 @@ import ( "crypto/ecdsa" "crypto/sha256" "fmt" + "io" + "io/ioutil" "math" "runtime" "sync" @@ -418,17 +420,23 @@ func (whisper *Whisper) SendHistoricMessageResponse(peer *Peer, payload []byte) } // SendP2PMessage sends a peer-to-peer message to a specific peer. -func (whisper *Whisper) SendP2PMessage(peerID []byte, envelope *Envelope) error { +func (whisper *Whisper) SendP2PMessage(peerID []byte, envelopes ...*Envelope) error { p, err := whisper.getPeer(peerID) if err != nil { return err } - return whisper.SendP2PDirect(p, envelope) + return whisper.SendP2PDirect(p, envelopes...) } // SendP2PDirect sends a peer-to-peer message to a specific peer. -func (whisper *Whisper) SendP2PDirect(peer *Peer, envelope *Envelope) error { - return p2p.Send(peer.ws, p2pMessageCode, envelope) +// If only a single envelope is given, data is sent as a single object +// rather than a slice. This is important to keep this method backward compatible +// as it used to send only single envelopes. +func (whisper *Whisper) SendP2PDirect(peer *Peer, envelopes ...*Envelope) error { + if len(envelopes) == 1 { + return p2p.Send(peer.ws, p2pMessageCode, envelopes[0]) + } + return p2p.Send(peer.ws, p2pMessageCode, envelopes) } // NewKeyPair generates a new cryptographic identity for the client, and injects @@ -843,12 +851,46 @@ func (whisper *Whisper) runMessageLoop(p *Peer, rw p2p.MsgReadWriter) error { // therefore might not satisfy the PoW, expiry and other requirements. // these messages are only accepted from the trusted peer. if p.trusted { - var envelope Envelope - if err := packet.Decode(&envelope); err != nil { - log.Warn("failed to decode direct message, peer will be disconnected", "peer", p.peer.ID(), "err", err) - return errors.New("invalid direct message") + var ( + envelope *Envelope + envelopes []*Envelope + err error + ) + + // Read all data as we will try to decode it possibly twice + // to keep backward compatibility. + data, err := ioutil.ReadAll(packet.Payload) + if err != nil { + return fmt.Errorf("invalid direct messages: %v", err) + } + r := bytes.NewReader(data) + + packet.Payload = r + + if err = packet.Decode(&envelopes); err == nil { + for _, envelope := range envelopes { + whisper.postEvent(envelope, true) + } + continue + } + + // As we failed to decode envelopes, let's set the offset + // to the beginning and try decode data again. + // Decoding to a single Envelope is required + // to be backward compatible. + if _, err := r.Seek(0, io.SeekStart); err != nil { + return fmt.Errorf("invalid direct messages: %v", err) + } + + if err = packet.Decode(&envelope); err == nil { + whisper.postEvent(envelope, true) + continue + } + + if err != nil { + log.Warn("failed to decode direct message, peer will be disconnected", "peer", p.peer.ID(), "err", err) + return fmt.Errorf("invalid direct message: %v", err) } - whisper.postEvent(&envelope, true) } case p2pRequestCode: // Must be processed if mail server is implemented. Otherwise ignore. diff --git a/whisperv6/whisper_test.go b/whisperv6/whisper_test.go index 0df876e..e115dbd 100644 --- a/whisperv6/whisper_test.go +++ b/whisperv6/whisper_test.go @@ -25,6 +25,10 @@ import ( "time" "github.com/ethereum/go-ethereum/common" + "github.com/ethereum/go-ethereum/p2p" + "github.com/ethereum/go-ethereum/p2p/discover" + "github.com/ethereum/go-ethereum/rlp" + "github.com/syndtr/goleveldb/leveldb/errors" "golang.org/x/crypto/pbkdf2" ) @@ -890,3 +894,150 @@ func TestBloom(t *testing.T) { t.Fatalf("retireved wrong bloom filter") } } + +func TestSendP2PDirect(t *testing.T) { + InitSingleTest() + + w := New(&DefaultConfig) + w.SetMinimumPowTest(0.0000001) + defer w.SetMinimumPowTest(DefaultMinimumPoW) + w.Start(nil) + defer w.Stop() + + rwStub := &rwP2PMessagesStub{} + peerW := newPeer(w, p2p.NewPeer(discover.NodeID{}, "test", []p2p.Cap{}), rwStub) + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + params.TTL = 1 + + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params, time.Now()) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + // verify sending a single envelope + err = w.SendP2PDirect(peerW, env) + if err != nil { + t.Fatalf("failed to send envelope with seed %d: %s.", seed, err) + } + if len(rwStub.messages) != 1 { + t.Fatalf("invalid number of messages sent to peer: %d, expected 1", len(rwStub.messages)) + } + var envelope Envelope + if err := rwStub.messages[0].Decode(&envelope); err != nil { + t.Fatalf("failed to decode envelopes: %s", err) + } + if envelope.Hash() != env.Hash() { + t.Fatalf("invalid envelope %d, expected %d", envelope.Hash(), env.Hash()) + } + rwStub.messages = nil + + // send a batch of envelopes + err = w.SendP2PDirect(peerW, env, env, env) + if err != nil { + t.Fatalf("failed to send envelope with seed %d: %s.", seed, err) + } + if len(rwStub.messages) != 1 { + t.Fatalf("invalid number of messages sent to peer: %d, expected 1", len(rwStub.messages)) + } + var envelopes []*Envelope + if err := rwStub.messages[0].Decode(&envelopes); err != nil { + t.Fatalf("failed to decode envelopes: %s", err) + } + if len(envelopes) != 3 { + t.Fatalf("invalid number of envelopes in a message: %d, expected 3", len(envelopes)) + } + rwStub.messages = nil + envelopes = nil +} + +func TestHandleP2PMessageCode(t *testing.T) { + InitSingleTest() + + w := New(&DefaultConfig) + w.SetMinimumPowTest(0.0000001) + defer w.SetMinimumPowTest(DefaultMinimumPoW) + w.Start(nil) + defer w.Stop() + + envelopeEvents := make(chan EnvelopeEvent, 10) + sub := w.SubscribeEnvelopeEvents(envelopeEvents) + defer sub.Unsubscribe() + + params, err := generateMessageParams() + if err != nil { + t.Fatalf("failed generateMessageParams with seed %d: %s.", seed, err) + } + params.TTL = 1 + + msg, err := NewSentMessage(params) + if err != nil { + t.Fatalf("failed to create new message with seed %d: %s.", seed, err) + } + env, err := msg.Wrap(params, time.Now()) + if err != nil { + t.Fatalf("failed Wrap with seed %d: %s.", seed, err) + } + + // read a single envelope + rwStub := &rwP2PMessagesStub{} + rwStub.payload = []interface{}{env} + + peer := newPeer(nil, p2p.NewPeer(discover.NodeID{}, "test", []p2p.Cap{}), nil) + peer.trusted = true + + err = w.runMessageLoop(peer, rwStub) + if err != nil && err != errRWStub { + t.Fatalf("failed run message loop: %s", err) + } + if e := <-envelopeEvents; e.Hash != env.Hash() { + t.Fatalf("received envelope %s while expected %s", e.Hash, env.Hash()) + } + + // read a batch of envelopes + rwStub = &rwP2PMessagesStub{} + rwStub.payload = []interface{}{[]*Envelope{env, env, env}} + + err = w.runMessageLoop(peer, rwStub) + if err != nil && err != errRWStub { + t.Fatalf("failed run message loop: %s", err) + } + for i := 0; i < 3; i++ { + if e := <-envelopeEvents; e.Hash != env.Hash() { + t.Fatalf("received envelope %s while expected %s", e.Hash, env.Hash()) + } + } +} + +var errRWStub = errors.New("no more messages") + +type rwP2PMessagesStub struct { + // payload stores individual messages that will be sent returned + // on ReadMsg() class + payload []interface{} + messages []p2p.Msg +} + +func (stub *rwP2PMessagesStub) ReadMsg() (p2p.Msg, error) { + if len(stub.payload) == 0 { + return p2p.Msg{}, errRWStub + } + size, r, err := rlp.EncodeToReader(stub.payload[0]) + if err != nil { + return p2p.Msg{}, err + } + stub.payload = stub.payload[1:] + return p2p.Msg{Code: p2pMessageCode, Size: uint32(size), Payload: r}, nil +} + +func (stub *rwP2PMessagesStub) WriteMsg(m p2p.Msg) error { + stub.messages = append(stub.messages, m) + return nil +}