diff --git a/core/blockchain.go b/core/blockchain.go index 22f130ce6..d173b2de2 100644 --- a/core/blockchain.go +++ b/core/blockchain.go @@ -561,6 +561,17 @@ func (bc *BlockChain) HasBlock(hash common.Hash, number uint64) bool { return rawdb.HasBody(bc.db, hash, number) } +// HasFastBlock checks if a fast block is fully present in the database or not. +func (bc *BlockChain) HasFastBlock(hash common.Hash, number uint64) bool { + if !bc.HasBlock(hash, number) { + return false + } + if bc.receiptsCache.Contains(hash) { + return true + } + return rawdb.HasReceipts(bc.db, hash, number) +} + // HasState checks if state trie is fully present in the database or not. func (bc *BlockChain) HasState(hash common.Hash) bool { _, err := bc.stateCache.OpenTrie(hash) @@ -618,12 +629,10 @@ func (bc *BlockChain) GetReceiptsByHash(hash common.Hash) types.Receipts { if receipts, ok := bc.receiptsCache.Get(hash); ok { return receipts.(types.Receipts) } - number := rawdb.ReadHeaderNumber(bc.db, hash) if number == nil { return nil } - receipts := rawdb.ReadReceipts(bc.db, hash, *number) bc.receiptsCache.Add(hash, receipts) return receipts diff --git a/core/rawdb/accessors_chain.go b/core/rawdb/accessors_chain.go index 6660e17de..491a125c6 100644 --- a/core/rawdb/accessors_chain.go +++ b/core/rawdb/accessors_chain.go @@ -271,6 +271,15 @@ func DeleteTd(db DatabaseDeleter, hash common.Hash, number uint64) { } } +// HasReceipts verifies the existence of all the transaction receipts belonging +// to a block. +func HasReceipts(db DatabaseReader, hash common.Hash, number uint64) bool { + if has, err := db.Has(blockReceiptsKey(number, hash)); !has || err != nil { + return false + } + return true +} + // ReadReceipts retrieves all the transaction receipts belonging to a block. func ReadReceipts(db DatabaseReader, hash common.Hash, number uint64) types.Receipts { // Retrieve the flattened receipt slice diff --git a/eth/downloader/downloader.go b/eth/downloader/downloader.go index 56c54c8ed..f81a5cbac 100644 --- a/eth/downloader/downloader.go +++ b/eth/downloader/downloader.go @@ -181,6 +181,9 @@ type BlockChain interface { // HasBlock verifies a block's presence in the local chain. HasBlock(common.Hash, uint64) bool + // HasFastBlock verifies a fast block's presence in the local chain. + HasFastBlock(common.Hash, uint64) bool + // GetBlockByHash retrieves a block from the local chain. GetBlockByHash(common.Hash) *types.Block @@ -430,7 +433,7 @@ func (d *Downloader) syncWithPeer(p *peerConnection, hash common.Hash, td *big.I } height := latest.Number.Uint64() - origin, err := d.findAncestor(p, height) + origin, err := d.findAncestor(p, latest) if err != nil { return err } @@ -587,41 +590,86 @@ func (d *Downloader) fetchHeight(p *peerConnection) (*types.Header, error) { } } +// calculateRequestSpan calculates what headers to request from a peer when trying to determine the +// common ancestor. +// It returns parameters to be used for peer.RequestHeadersByNumber: +// from - starting block number +// count - number of headers to request +// skip - number of headers to skip +// and also returns 'max', the last block which is expected to be returned by the remote peers, +// given the (from,count,skip) +func calculateRequestSpan(remoteHeight, localHeight uint64) (int64, int, int, uint64) { + var ( + from int + count int + MaxCount = MaxHeaderFetch / 16 + ) + // requestHead is the highest block that we will ask for. If requestHead is not offset, + // the highest block that we will get is 16 blocks back from head, which means we + // will fetch 14 or 15 blocks unnecessarily in the case the height difference + // between us and the peer is 1-2 blocks, which is most common + requestHead := int(remoteHeight) - 1 + if requestHead < 0 { + requestHead = 0 + } + // requestBottom is the lowest block we want included in the query + // Ideally, we want to include just below own head + requestBottom := int(localHeight - 1) + if requestBottom < 0 { + requestBottom = 0 + } + totalSpan := requestHead - requestBottom + span := 1 + totalSpan/MaxCount + if span < 2 { + span = 2 + } + if span > 16 { + span = 16 + } + + count = 1 + totalSpan/span + if count > MaxCount { + count = MaxCount + } + if count < 2 { + count = 2 + } + from = requestHead - (count-1)*span + if from < 0 { + from = 0 + } + max := from + (count-1)*span + return int64(from), count, span - 1, uint64(max) +} + // findAncestor tries to locate the common ancestor link of the local chain and // a remote peers blockchain. In the general case when our node was in sync and // on the correct chain, checking the top N links should already get us a match. // In the rare scenario when we ended up on a long reorganisation (i.e. none of // the head links match), we do a binary search to find the common ancestor. -func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, error) { +func (d *Downloader) findAncestor(p *peerConnection, remoteHeader *types.Header) (uint64, error) { // Figure out the valid ancestor range to prevent rewrite attacks - floor, ceil := int64(-1), d.lightchain.CurrentHeader().Number.Uint64() + var ( + floor = int64(-1) + localHeight uint64 + remoteHeight = remoteHeader.Number.Uint64() + ) + switch d.mode { + case FullSync: + localHeight = d.blockchain.CurrentBlock().NumberU64() + case FastSync: + localHeight = d.blockchain.CurrentFastBlock().NumberU64() + default: + localHeight = d.lightchain.CurrentHeader().Number.Uint64() + } + p.log.Debug("Looking for common ancestor", "local", localHeight, "remote", remoteHeight) + if localHeight >= MaxForkAncestry { + floor = int64(localHeight - MaxForkAncestry) + } + from, count, skip, max := calculateRequestSpan(remoteHeight, localHeight) - if d.mode == FullSync { - ceil = d.blockchain.CurrentBlock().NumberU64() - } else if d.mode == FastSync { - ceil = d.blockchain.CurrentFastBlock().NumberU64() - } - if ceil >= MaxForkAncestry { - floor = int64(ceil - MaxForkAncestry) - } - p.log.Debug("Looking for common ancestor", "local", ceil, "remote", height) - - // Request the topmost blocks to short circuit binary ancestor lookup - head := ceil - if head > height { - head = height - } - from := int64(head) - int64(MaxHeaderFetch) - if from < 0 { - from = 0 - } - // Span out with 15 block gaps into the future to catch bad head reports - limit := 2 * MaxHeaderFetch / 16 - count := 1 + int((int64(ceil)-from)/16) - if count > limit { - count = limit - } - go p.peer.RequestHeadersByNumber(uint64(from), count, 15, false) + p.log.Trace("Span searching for common ancestor", "count", count, "from", from, "skip", skip) + go p.peer.RequestHeadersByNumber(uint64(from), count, skip, false) // Wait for the remote response to the head fetch number, hash := uint64(0), common.Hash{} @@ -647,9 +695,10 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err return 0, errEmptyHeaderSet } // Make sure the peer's reply conforms to the request - for i := 0; i < len(headers); i++ { - if number := headers[i].Number.Int64(); number != from+int64(i)*16 { - p.log.Warn("Head headers broke chain ordering", "index", i, "requested", from+int64(i)*16, "received", number) + for i, header := range headers { + expectNumber := from + int64(i)*int64((skip+1)) + if number := header.Number.Int64(); number != expectNumber { + p.log.Warn("Head headers broke chain ordering", "index", i, "requested", expectNumber, "received", number) return 0, errInvalidChain } } @@ -657,20 +706,24 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err finished = true for i := len(headers) - 1; i >= 0; i-- { // Skip any headers that underflow/overflow our requested set - if headers[i].Number.Int64() < from || headers[i].Number.Uint64() > ceil { + if headers[i].Number.Int64() < from || headers[i].Number.Uint64() > max { continue } // Otherwise check if we already know the header or not h := headers[i].Hash() n := headers[i].Number.Uint64() - if (d.mode == FullSync && d.blockchain.HasBlock(h, n)) || (d.mode != FullSync && d.lightchain.HasHeader(h, n)) { - number, hash = n, h - // If every header is known, even future ones, the peer straight out lied about its head - if number > height && i == limit-1 { - p.log.Warn("Lied about chain head", "reported", height, "found", number) - return 0, errStallingPeer - } + var known bool + switch d.mode { + case FullSync: + known = d.blockchain.HasBlock(h, n) + case FastSync: + known = d.blockchain.HasFastBlock(h, n) + default: + known = d.lightchain.HasHeader(h, n) + } + if known { + number, hash = n, h break } } @@ -694,10 +747,12 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err return number, nil } // Ancestor not found, we need to binary search over our chain - start, end := uint64(0), head + start, end := uint64(0), remoteHeight if floor > 0 { start = uint64(floor) } + p.log.Trace("Binary searching for common ancestor", "start", start, "end", end) + for start+1 < end { // Split our chain interval in two, and request the hash to cross check check := (start + end) / 2 @@ -730,7 +785,17 @@ func (d *Downloader) findAncestor(p *peerConnection, height uint64) (uint64, err // Modify the search interval based on the response h := headers[0].Hash() n := headers[0].Number.Uint64() - if (d.mode == FullSync && !d.blockchain.HasBlock(h, n)) || (d.mode != FullSync && !d.lightchain.HasHeader(h, n)) { + + var known bool + switch d.mode { + case FullSync: + known = d.blockchain.HasBlock(h, n) + case FastSync: + known = d.blockchain.HasFastBlock(h, n) + default: + known = d.lightchain.HasHeader(h, n) + } + if !known { end = check break } diff --git a/eth/downloader/downloader_test.go b/eth/downloader/downloader_test.go index 1fe02d884..1a42965d3 100644 --- a/eth/downloader/downloader_test.go +++ b/eth/downloader/downloader_test.go @@ -20,6 +20,7 @@ import ( "errors" "fmt" "math/big" + "strings" "sync" "sync/atomic" "testing" @@ -114,6 +115,15 @@ func (dl *downloadTester) HasBlock(hash common.Hash, number uint64) bool { return dl.GetBlockByHash(hash) != nil } +// HasFastBlock checks if a block is present in the testers canonical chain. +func (dl *downloadTester) HasFastBlock(hash common.Hash, number uint64) bool { + dl.lock.RLock() + defer dl.lock.RUnlock() + + _, ok := dl.ownReceipts[hash] + return ok +} + // GetHeader retrieves a header from the testers canonical chain. func (dl *downloadTester) GetHeaderByHash(hash common.Hash) *types.Header { dl.lock.RLock() @@ -234,6 +244,7 @@ func (dl *downloadTester) InsertChain(blocks types.Blocks) (i int, err error) { dl.ownHeaders[block.Hash()] = block.Header() } dl.ownBlocks[block.Hash()] = block + dl.ownReceipts[block.Hash()] = make(types.Receipts, 0) dl.stateDb.Put(block.Root().Bytes(), []byte{0x00}) dl.ownChainTd[block.Hash()] = new(big.Int).Add(dl.ownChainTd[block.ParentHash()], block.Difficulty()) } @@ -374,28 +385,28 @@ func (dlp *downloadTesterPeer) RequestNodeData(hashes []common.Hash) error { // assertOwnChain checks if the local chain contains the correct number of items // of the various chain components. func assertOwnChain(t *testing.T, tester *downloadTester, length int) { + // Mark this method as a helper to report errors at callsite, not in here + t.Helper() + assertOwnForkedChain(t, tester, 1, []int{length}) } // assertOwnForkedChain checks if the local forked chain contains the correct // number of items of the various chain components. func assertOwnForkedChain(t *testing.T, tester *downloadTester, common int, lengths []int) { - // Initialize the counters for the first fork - headers, blocks, receipts := lengths[0], lengths[0], lengths[0]-fsMinFullBlocks + // Mark this method as a helper to report errors at callsite, not in here + t.Helper() + + // Initialize the counters for the first fork + headers, blocks, receipts := lengths[0], lengths[0], lengths[0] - if receipts < 0 { - receipts = 1 - } // Update the counters for each subsequent fork for _, length := range lengths[1:] { headers += length - common blocks += length - common - receipts += length - common - fsMinFullBlocks + receipts += length - common } - switch tester.downloader.mode { - case FullSync: - receipts = 1 - case LightSync: + if tester.downloader.mode == LightSync { blocks, receipts = 1, 1 } if hs := len(tester.ownHeaders); hs != headers { @@ -1149,7 +1160,9 @@ func testSyncProgress(t *testing.T, protocol int, mode SyncMode) { } func checkProgress(t *testing.T, d *Downloader, stage string, want ethereum.SyncProgress) { + // Mark this method as a helper to report errors at callsite, not in here t.Helper() + p := d.Progress() p.KnownStates, p.PulledStates = 0, 0 want.KnownStates, want.PulledStates = 0, 0 @@ -1479,3 +1492,78 @@ func (ftp *floodingTestPeer) RequestHeadersByNumber(from uint64, count, skip int } return nil } + +func TestRemoteHeaderRequestSpan(t *testing.T) { + testCases := []struct { + remoteHeight uint64 + localHeight uint64 + expected []int + }{ + // Remote is way higher. We should ask for the remote head and go backwards + {1500, 1000, + []int{1323, 1339, 1355, 1371, 1387, 1403, 1419, 1435, 1451, 1467, 1483, 1499}, + }, + {15000, 13006, + []int{14823, 14839, 14855, 14871, 14887, 14903, 14919, 14935, 14951, 14967, 14983, 14999}, + }, + //Remote is pretty close to us. We don't have to fetch as many + {1200, 1150, + []int{1149, 1154, 1159, 1164, 1169, 1174, 1179, 1184, 1189, 1194, 1199}, + }, + // Remote is equal to us (so on a fork with higher td) + // We should get the closest couple of ancestors + {1500, 1500, + []int{1497, 1499}, + }, + // We're higher than the remote! Odd + {1000, 1500, + []int{997, 999}, + }, + // Check some weird edgecases that it behaves somewhat rationally + {0, 1500, + []int{0, 2}, + }, + {6000000, 0, + []int{5999823, 5999839, 5999855, 5999871, 5999887, 5999903, 5999919, 5999935, 5999951, 5999967, 5999983, 5999999}, + }, + {0, 0, + []int{0, 2}, + }, + } + reqs := func(from, count, span int) []int { + var r []int + num := from + for len(r) < count { + r = append(r, num) + num += span + 1 + } + return r + } + for i, tt := range testCases { + from, count, span, max := calculateRequestSpan(tt.remoteHeight, tt.localHeight) + data := reqs(int(from), count, span) + + if max != uint64(data[len(data)-1]) { + t.Errorf("test %d: wrong last value %d != %d", i, data[len(data)-1], max) + } + failed := false + if len(data) != len(tt.expected) { + failed = true + t.Errorf("test %d: length wrong, expected %d got %d", i, len(tt.expected), len(data)) + } else { + for j, n := range data { + if n != tt.expected[j] { + failed = true + break + } + } + } + if failed { + res := strings.Replace(fmt.Sprint(data), " ", ",", -1) + exp := strings.Replace(fmt.Sprint(tt.expected), " ", ",", -1) + fmt.Printf("got: %v\n", res) + fmt.Printf("exp: %v\n", exp) + t.Errorf("test %d: wrong values", i) + } + } +}