diff --git a/les/client_handler.go b/les/client_handler.go index 77a0ea5c6..d7ca1c54f 100644 --- a/les/client_handler.go +++ b/les/client_handler.go @@ -17,6 +17,7 @@ package les import ( + "context" "math/big" "sync" "sync/atomic" @@ -200,14 +201,23 @@ func (h *clientHandler) handleMsg(p *serverPeer) error { p.fcServer.ReceivedReply(resp.ReqID, resp.BV) p.answeredRequest(resp.ReqID) - // Filter out any explicitly requested headers, deliver the rest to the downloader - filter := len(headers) == 1 - if filter { - headers = h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) - } - if len(headers) != 0 || !filter { - if err := h.downloader.DeliverHeaders(p.id, headers); err != nil { - log.Debug("Failed to deliver headers", "err", err) + // Filter out the explicitly requested header by the retriever + if h.backend.retriever.requested(resp.ReqID) { + deliverMsg = &Msg{ + MsgType: MsgBlockHeaders, + ReqID: resp.ReqID, + Obj: resp.Headers, + } + } else { + // Filter out any explicitly requested headers, deliver the rest to the downloader + filter := len(headers) == 1 + if filter { + headers = h.fetcher.deliverHeaders(p, resp.ReqID, resp.Headers) + } + if len(headers) != 0 || !filter { + if err := h.downloader.DeliverHeaders(p.id, headers); err != nil { + log.Debug("Failed to deliver headers", "err", err) + } } } case BlockBodiesMsg: @@ -394,6 +404,42 @@ func (pc *peerConnection) RequestHeadersByNumber(origin uint64, amount int, skip return nil } +// RetrieveSingleHeaderByNumber requests a single header by the specified block +// number. This function will wait the response until it's timeout or delivered. +func (pc *peerConnection) RetrieveSingleHeaderByNumber(context context.Context, number uint64) (*types.Header, error) { + reqID := genReqID() + rq := &distReq{ + getCost: func(dp distPeer) uint64 { + peer := dp.(*serverPeer) + return peer.getRequestCost(GetBlockHeadersMsg, 1) + }, + canSend: func(dp distPeer) bool { + return dp.(*serverPeer) == pc.peer + }, + request: func(dp distPeer) func() { + peer := dp.(*serverPeer) + cost := peer.getRequestCost(GetBlockHeadersMsg, 1) + peer.fcServer.QueuedRequest(reqID, cost) + return func() { peer.requestHeadersByNumber(reqID, number, 1, 0, false) } + }, + } + var header *types.Header + if err := pc.handler.backend.retriever.retrieve(context, reqID, rq, func(peer distPeer, msg *Msg) error { + if msg.MsgType != MsgBlockHeaders { + return errInvalidMessageType + } + headers := msg.Obj.([]*types.Header) + if len(headers) != 1 { + return errInvalidEntryCount + } + header = headers[0] + return nil + }, nil); err != nil { + return nil, err + } + return header, nil +} + // downloaderPeerNotify implements peerSetNotify type downloaderPeerNotify clientHandler diff --git a/les/odr.go b/les/odr.go index f8469cc10..2c36f512d 100644 --- a/les/odr.go +++ b/les/odr.go @@ -24,7 +24,6 @@ import ( "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/ethdb" "github.com/ethereum/go-ethereum/light" - "github.com/ethereum/go-ethereum/log" ) // LesOdr implements light.OdrBackend @@ -83,7 +82,8 @@ func (odr *LesOdr) IndexerConfig() *light.IndexerConfig { } const ( - MsgBlockBodies = iota + MsgBlockHeaders = iota + MsgBlockBodies MsgCode MsgReceipts MsgProofsV2 @@ -122,13 +122,17 @@ func (odr *LesOdr) Retrieve(ctx context.Context, req light.OdrRequest) (err erro return func() { lreq.Request(reqID, p) } }, } - sent := mclock.Now() - if err = odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err == nil { - // retrieved from network, store in db - req.StoreResult(odr.db) + + defer func(sent mclock.AbsTime) { + if err != nil { + return + } requestRTT.Update(time.Duration(mclock.Now() - sent)) - } else { - log.Debug("Failed to retrieve data from network", "err", err) + }(mclock.Now()) + + if err := odr.retriever.retrieve(ctx, reqID, rq, func(p distPeer, msg *Msg) error { return lreq.Validate(odr.db, msg) }, odr.stop); err != nil { + return err } - return + req.StoreResult(odr.db) + return nil } diff --git a/les/odr_requests.go b/les/odr_requests.go index 3704436a0..eb1d3602e 100644 --- a/les/odr_requests.go +++ b/les/odr_requests.go @@ -327,9 +327,6 @@ func (r *ChtRequest) CanSend(peer *serverPeer) bool { peer.lock.RLock() defer peer.lock.RUnlock() - if r.Untrusted { - return peer.headInfo.Number >= r.BlockNum && peer.id == r.PeerId - } return peer.headInfo.Number >= r.Config.ChtConfirms && r.ChtNum <= (peer.headInfo.Number-r.Config.ChtConfirms)/r.Config.ChtSize } @@ -369,39 +366,34 @@ func (r *ChtRequest) Validate(db ethdb.Database, msg *Msg) error { if err := rlp.DecodeBytes(headerEnc, header); err != nil { return errHeaderUnavailable } - // Verify the CHT - // Note: For untrusted CHT request, there is no proof response but - // header data. - var node light.ChtNode - if !r.Untrusted { - var encNumber [8]byte - binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) + var ( + node light.ChtNode + encNumber [8]byte + ) + binary.BigEndian.PutUint64(encNumber[:], r.BlockNum) - reads := &readTraceDB{db: nodeSet} - value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) - if err != nil { - return fmt.Errorf("merkle proof verification failed: %v", err) - } - if len(reads.reads) != nodeSet.KeyCount() { - return errUselessNodes - } - - if err := rlp.DecodeBytes(value, &node); err != nil { - return err - } - if node.Hash != header.Hash() { - return errCHTHashMismatch - } - if r.BlockNum != header.Number.Uint64() { - return errCHTNumberMismatch - } + reads := &readTraceDB{db: nodeSet} + value, err := trie.VerifyProof(r.ChtRoot, encNumber[:], reads) + if err != nil { + return fmt.Errorf("merkle proof verification failed: %v", err) + } + if len(reads.reads) != nodeSet.KeyCount() { + return errUselessNodes + } + if err := rlp.DecodeBytes(value, &node); err != nil { + return err + } + if node.Hash != header.Hash() { + return errCHTHashMismatch + } + if r.BlockNum != header.Number.Uint64() { + return errCHTNumberMismatch } // Verifications passed, store and return r.Header = header r.Proof = nodeSet - r.Td = node.Td // For untrusted request, td here is nil, todo improve the les/2 protocol - + r.Td = node.Td return nil } diff --git a/les/retrieve.go b/les/retrieve.go index 4f77004f2..ca4f867ea 100644 --- a/les/retrieve.go +++ b/les/retrieve.go @@ -155,6 +155,15 @@ func (rm *retrieveManager) sendReq(reqID uint64, req *distReq, val validatorFunc return r } +// requested reports whether the request with given reqid is sent by the retriever. +func (rm *retrieveManager) requested(reqId uint64) bool { + rm.lock.RLock() + defer rm.lock.RUnlock() + + _, ok := rm.sentReqs[reqId] + return ok +} + // deliver is called by the LES protocol manager to deliver reply messages to waiting requests func (rm *retrieveManager) deliver(peer distPeer, msg *Msg) error { rm.lock.RLock() diff --git a/les/sync.go b/les/sync.go index d2568d45b..ad3a0e0f3 100644 --- a/les/sync.go +++ b/les/sync.go @@ -56,8 +56,8 @@ func (h *clientHandler) validateCheckpoint(peer *serverPeer) error { defer cancel() // Fetch the block header corresponding to the checkpoint registration. - cp := peer.checkpoint - header, err := light.GetUntrustedHeaderByNumber(ctx, h.backend.odr, peer.checkpointNumber, peer.id) + wrapPeer := &peerConnection{handler: h, peer: peer} + header, err := wrapPeer.RetrieveSingleHeaderByNumber(ctx, peer.checkpointNumber) if err != nil { return err } @@ -66,7 +66,7 @@ func (h *clientHandler) validateCheckpoint(peer *serverPeer) error { if err != nil { return err } - events := h.backend.oracle.Contract().LookupCheckpointEvents(logs, cp.SectionIndex, cp.Hash()) + events := h.backend.oracle.Contract().LookupCheckpointEvents(logs, peer.checkpoint.SectionIndex, peer.checkpoint.Hash()) if len(events) == 0 { return errInvalidCheckpoint } diff --git a/les/sync_test.go b/les/sync_test.go index 6924e7b43..2eb0f88bf 100644 --- a/les/sync_test.go +++ b/les/sync_test.go @@ -53,7 +53,7 @@ func testCheckpointSyncing(t *testing.T, protocol int, syncMode int) { time.Sleep(10 * time.Millisecond) } } - // Generate 512+4 blocks (totally 1 CHT sections) + // Generate 128+1 blocks (totally 1 CHT sections) server, client, tearDown := newClientServerEnv(t, int(config.ChtSize+config.ChtConfirms), protocol, waitIndexers, nil, 0, false, false, true) defer tearDown() diff --git a/light/odr.go b/light/odr.go index 7016ef8ef..bb243f915 100644 --- a/light/odr.go +++ b/light/odr.go @@ -135,8 +135,6 @@ func (req *ReceiptsRequest) StoreResult(db ethdb.Database) { // ChtRequest is the ODR request type for retrieving header by Canonical Hash Trie type ChtRequest struct { - Untrusted bool // Indicator whether the result retrieved is trusted or not - PeerId string // The specified peer id from which to retrieve data. Config *IndexerConfig ChtNum, BlockNum uint64 ChtRoot common.Hash @@ -148,12 +146,9 @@ type ChtRequest struct { // StoreResult stores the retrieved data in local database func (req *ChtRequest) StoreResult(db ethdb.Database) { hash, num := req.Header.Hash(), req.Header.Number.Uint64() - - if !req.Untrusted { - rawdb.WriteHeader(db, req.Header) - rawdb.WriteTd(db, hash, num, req.Td) - rawdb.WriteCanonicalHash(db, hash, num) - } + rawdb.WriteHeader(db, req.Header) + rawdb.WriteTd(db, hash, num, req.Td) + rawdb.WriteCanonicalHash(db, hash, num) } // BloomRequest is the ODR request type for retrieving bloom filters from a CHT structure diff --git a/light/odr_util.go b/light/odr_util.go index aec0c7b69..ec2d1e649 100644 --- a/light/odr_util.go +++ b/light/odr_util.go @@ -19,20 +19,23 @@ package light import ( "bytes" "context" + "errors" "math/big" "github.com/ethereum/go-ethereum/common" "github.com/ethereum/go-ethereum/core" "github.com/ethereum/go-ethereum/core/rawdb" "github.com/ethereum/go-ethereum/core/types" - "github.com/ethereum/go-ethereum/crypto" "github.com/ethereum/go-ethereum/rlp" ) -var sha3Nil = crypto.Keccak256Hash(nil) +// errNonCanonicalHash is returned if the requested chain data doesn't belong +// to the canonical chain. ODR can only retrieve the canonical chain data covered +// by the CHT or Bloom trie for verification. +var errNonCanonicalHash = errors.New("hash is not currently canonical") // GetHeaderByNumber retrieves the canonical block header corresponding to the -// given number. +// given number. The returned header is proven by local CHT. func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*types.Header, error) { // Try to find it in the local database first. db := odr.Database() @@ -63,25 +66,6 @@ func GetHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64) (*typ return r.Header, nil } -// GetUntrustedHeaderByNumber retrieves specified block header without -// correctness checking. Note this function should only be used in light -// client checkpoint syncing. -func GetUntrustedHeaderByNumber(ctx context.Context, odr OdrBackend, number uint64, peerId string) (*types.Header, error) { - // todo(rjl493456442) it's a hack to retrieve headers which is not covered - // by CHT. Fix it in LES4 - r := &ChtRequest{ - BlockNum: number, - ChtNum: number / odr.IndexerConfig().ChtSize, - Untrusted: true, - PeerId: peerId, - Config: odr.IndexerConfig(), - } - if err := odr.Retrieve(ctx, r); err != nil { - return nil, err - } - return r.Header, nil -} - // GetCanonicalHash retrieves the canonical block hash corresponding to the number. func GetCanonicalHash(ctx context.Context, odr OdrBackend, number uint64) (common.Hash, error) { hash := rawdb.ReadCanonicalHash(odr.Database(), number) @@ -102,10 +86,13 @@ func GetTd(ctx context.Context, odr OdrBackend, hash common.Hash, number uint64) if td != nil { return td, nil } - _, err := GetHeaderByNumber(ctx, odr, number) + header, err := GetHeaderByNumber(ctx, odr, number) if err != nil { return nil, err } + if header.Hash() != hash { + return nil, errNonCanonicalHash + } // -> td mapping already be stored in db, get it. return rawdb.ReadTd(odr.Database(), hash, number), nil } @@ -120,6 +107,9 @@ func GetBodyRLP(ctx context.Context, odr OdrBackend, hash common.Hash, number ui if err != nil { return nil, errNoHeader } + if header.Hash() != hash { + return nil, errNonCanonicalHash + } r := &BlockRequest{Hash: hash, Number: number, Header: header} if err := odr.Retrieve(ctx, r); err != nil { return nil, err @@ -167,6 +157,9 @@ func GetBlockReceipts(ctx context.Context, odr OdrBackend, hash common.Hash, num if err != nil { return nil, errNoHeader } + if header.Hash() != hash { + return nil, errNonCanonicalHash + } r := &ReceiptsRequest{Hash: hash, Number: number, Header: header} if err := odr.Retrieve(ctx, r); err != nil { return nil, err diff --git a/light/trie.go b/light/trie.go index 3eb05f4a3..0516b9448 100644 --- a/light/trie.go +++ b/light/trie.go @@ -30,6 +30,10 @@ import ( "github.com/ethereum/go-ethereum/trie" ) +var ( + sha3Nil = crypto.Keccak256Hash(nil) +) + func NewState(ctx context.Context, head *types.Header, odr OdrBackend) *state.StateDB { state, _ := state.New(head.Root, NewStateDatabase(ctx, head, odr), nil) return state