diff --git a/core/types/block.go b/core/types/block.go index 5cdde4462..c04beae5a 100644 --- a/core/types/block.go +++ b/core/types/block.go @@ -148,6 +148,26 @@ func NewBlockWithHeader(header *Header) *Block { return &Block{header: header} } +func (self *Block) Validate() error { + if self.header == nil { + return fmt.Errorf("header is nil") + } + // check *big.Int fields + if self.header.Difficulty == nil { + return fmt.Errorf("Difficulty undefined") + } + if self.header.GasLimit == nil { + return fmt.Errorf("GasLimit undefined") + } + if self.header.GasUsed == nil { + return fmt.Errorf("GasUsed undefined") + } + if self.header.Number == nil { + return fmt.Errorf("Number undefined") + } + return nil +} + func (self *Block) DecodeRLP(s *rlp.Stream) error { var eb extblock if err := s.Decode(&eb); err != nil { diff --git a/eth/protocol.go b/eth/protocol.go index e32ea233b..0a3f67b62 100644 --- a/eth/protocol.go +++ b/eth/protocol.go @@ -105,7 +105,7 @@ type getBlockHashesMsgData struct { type statusMsgData struct { ProtocolVersion uint32 NetworkId uint32 - TD *big.Int + TD big.Int CurrentBlock common.Hash GenesisBlock common.Hash } @@ -276,6 +276,9 @@ func (self *ethProtocol) handle() error { if err := msg.Decode(&request); err != nil { return self.protoError(ErrDecode, "%v: %v", msg, err) } + if err := request.Block.Validate(); err != nil { + return self.protoError(ErrDecode, "block validation %v: %v", msg, err) + } hash := request.Block.Hash() _, chainHead, _ := self.chainManager.Status() @@ -335,7 +338,7 @@ func (self *ethProtocol) handleStatus() error { return self.protoError(ErrProtocolVersionMismatch, "%d (!= %d)", status.ProtocolVersion, self.protocolVersion) } - _, suspended := self.blockPool.AddPeer(status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) + _, suspended := self.blockPool.AddPeer(&status.TD, status.CurrentBlock, self.id, self.requestBlockHashes, self.requestBlocks, self.protoErrorDisconnect) if suspended { return self.protoError(ErrSuspendedPeer, "") } @@ -366,7 +369,7 @@ func (self *ethProtocol) sendStatus() error { return p2p.Send(self.rw, StatusMsg, &statusMsgData{ ProtocolVersion: uint32(self.protocolVersion), NetworkId: uint32(self.networkId), - TD: td, + TD: *td, CurrentBlock: currentBlock, GenesisBlock: genesisBlock, }) diff --git a/eth/protocol_test.go b/eth/protocol_test.go index 8ca6d1be6..7831e9bc6 100644 --- a/eth/protocol_test.go +++ b/eth/protocol_test.go @@ -1,6 +1,7 @@ package eth import ( + "fmt" "log" "math/big" "os" @@ -63,6 +64,10 @@ func (self *testChainManager) GetBlockHashesFromHash(hash common.Hash, amount ui func (self *testChainManager) Status() (td *big.Int, currentBlock common.Hash, genesisBlock common.Hash) { if self.status != nil { td, currentBlock, genesisBlock = self.status() + } else { + td = common.Big1 + currentBlock = common.Hash{1} + genesisBlock = common.Hash{2} } return } @@ -163,14 +168,29 @@ func (self *ethProtocolTester) run() { self.quit <- err } +func (self *ethProtocolTester) handshake(t *testing.T, mock bool) { + td, currentBlock, genesis := self.chainManager.Status() + // first outgoing msg should be StatusMsg. + err := p2p.ExpectMsg(self, StatusMsg, &statusMsgData{ + ProtocolVersion: ProtocolVersion, + NetworkId: NetworkId, + TD: *td, + CurrentBlock: currentBlock, + GenesisBlock: genesis, + }) + if err != nil { + t.Fatalf("incorrect outgoing status: %v", err) + } + if mock { + go p2p.Send(self, StatusMsg, &statusMsgData{ProtocolVersion, NetworkId, *td, currentBlock, genesis}) + } +} + func TestStatusMsgErrors(t *testing.T) { logInit() eth := newEth(t) - td := common.Big1 - currentBlock := common.Hash{1} - genesis := common.Hash{2} - eth.chainManager.status = func() (*big.Int, common.Hash, common.Hash) { return td, currentBlock, genesis } go eth.run() + td, currentBlock, genesis := eth.chainManager.Status() tests := []struct { code uint64 @@ -182,31 +202,20 @@ func TestStatusMsgErrors(t *testing.T) { wantErrorCode: ErrNoStatusMsg, }, { - code: StatusMsg, data: statusMsgData{10, NetworkId, td, currentBlock, genesis}, + code: StatusMsg, data: statusMsgData{10, NetworkId, *td, currentBlock, genesis}, wantErrorCode: ErrProtocolVersionMismatch, }, { - code: StatusMsg, data: statusMsgData{ProtocolVersion, 999, td, currentBlock, genesis}, + code: StatusMsg, data: statusMsgData{ProtocolVersion, 999, *td, currentBlock, genesis}, wantErrorCode: ErrNetworkIdMismatch, }, { - code: StatusMsg, data: statusMsgData{ProtocolVersion, NetworkId, td, currentBlock, common.Hash{3}}, + code: StatusMsg, data: statusMsgData{ProtocolVersion, NetworkId, *td, currentBlock, common.Hash{3}}, wantErrorCode: ErrGenesisBlockMismatch, }, } for _, test := range tests { - // first outgoing msg should be StatusMsg. - err := p2p.ExpectMsg(eth, StatusMsg, &statusMsgData{ - ProtocolVersion: ProtocolVersion, - NetworkId: NetworkId, - TD: td, - CurrentBlock: currentBlock, - GenesisBlock: genesis, - }) - if err != nil { - t.Fatalf("incorrect outgoing status: %v", err) - } - + eth.handshake(t, false) // the send call might hang until reset because // the protocol might not read the payload. go p2p.Send(eth, test.code, test.data) @@ -216,3 +225,73 @@ func TestStatusMsgErrors(t *testing.T) { go eth.run() } } + +func TestNewBlockMsg(t *testing.T) { + logInit() + eth := newEth(t) + eth.blockPool.addBlock = func(block *types.Block, peerId string) (err error) { + fmt.Printf("Add Block: %v\n", block) + return + } + + var disconnected bool + eth.blockPool.removePeer = func(peerId string) { + fmt.Printf("peer <%s> is disconnected\n", peerId) + disconnected = true + } + + go eth.run() + + eth.handshake(t, true) + err := p2p.ExpectMsg(eth, TxMsg, []interface{}{}) + if err != nil { + t.Errorf("transactions expected, got %v", err) + } + + var tds = make(chan *big.Int) + eth.blockPool.addPeer = func(td *big.Int, currentBlock common.Hash, peerId string, requestHashes func(common.Hash) error, requestBlocks func([]common.Hash) error, peerError func(*errs.Error)) (best bool, suspended bool) { + tds <- td + return + } + + var delay = 1 * time.Second + // eth.reset() + block := types.NewBlock(common.Hash{1}, common.Address{1}, common.Hash{1}, common.Big1, 1, "extra") + + go p2p.Send(eth, NewBlockMsg, &newBlockMsgData{Block: block}) + timer := time.After(delay) + + select { + case td := <-tds: + if td.Cmp(common.Big0) != 0 { + t.Errorf("incorrect td %v, expected %v", td, common.Big0) + } + case <-timer: + t.Errorf("no td recorded after %v", delay) + return + case err := <-eth.quit: + t.Errorf("no error expected, got %v", err) + return + } + + go p2p.Send(eth, NewBlockMsg, &newBlockMsgData{block, common.Big2}) + timer = time.After(delay) + + select { + case td := <-tds: + if td.Cmp(common.Big2) != 0 { + t.Errorf("incorrect td %v, expected %v", td, common.Big2) + } + case <-timer: + t.Errorf("no td recorded after %v", delay) + return + case err := <-eth.quit: + t.Errorf("no error expected, got %v", err) + return + } + + go p2p.Send(eth, NewBlockMsg, []interface{}{}) + // Block.DecodeRLP: validation failed: header is nil + eth.checkError(ErrDecode, delay) + +}