Cleaned up research repo slightly
This commit is contained in:
parent
7e3482c532
commit
ba690c0307
Binary file not shown.
|
@ -1,139 +0,0 @@
|
|||
import time
|
||||
|
||||
def legendre_symbol(a, p):
|
||||
"""
|
||||
Legendre symbol
|
||||
Define if a is a quadratic residue modulo odd prime
|
||||
http://en.wikipedia.org/wiki/Legendre_symbol
|
||||
"""
|
||||
ls = pow(a, (p - 1)/2, p)
|
||||
if ls == p - 1:
|
||||
return -1
|
||||
return ls
|
||||
|
||||
def prime_mod_sqrt(a, p):
|
||||
"""
|
||||
Square root modulo prime number
|
||||
Solve the equation
|
||||
x^2 = a mod p
|
||||
and return list of x solution
|
||||
http://en.wikipedia.org/wiki/Tonelli-Shanks_algorithm
|
||||
"""
|
||||
a %= p
|
||||
|
||||
# Simple case
|
||||
if a == 0:
|
||||
return [0]
|
||||
if p == 2:
|
||||
return [a]
|
||||
|
||||
# Check solution existence on odd prime
|
||||
if legendre_symbol(a, p) != 1:
|
||||
return []
|
||||
|
||||
# Simple case
|
||||
if p % 4 == 3:
|
||||
x = pow(a, (p + 1)/4, p)
|
||||
return [x, p-x]
|
||||
|
||||
# Factor p-1 on the form q * 2^s (with Q odd)
|
||||
q, s = p - 1, 0
|
||||
while q % 2 == 0:
|
||||
s += 1
|
||||
q //= 2
|
||||
|
||||
# Select a z which is a quadratic non resudue modulo p
|
||||
z = 1
|
||||
while legendre_symbol(z, p) != -1:
|
||||
z += 1
|
||||
c = pow(z, q, p)
|
||||
|
||||
# Search for a solution
|
||||
x = pow(a, (q + 1)/2, p)
|
||||
t = pow(a, q, p)
|
||||
m = s
|
||||
while t != 1:
|
||||
# Find the lowest i such that t^(2^i) = 1
|
||||
i, e = 0, 2
|
||||
for i in xrange(1, m):
|
||||
if pow(t, e, p) == 1:
|
||||
break
|
||||
e *= 2
|
||||
|
||||
# Update next value to iterate
|
||||
b = pow(c, 2**(m - i - 1), p)
|
||||
x = (x * b) % p
|
||||
t = (t * b * b) % p
|
||||
c = (b * b) % p
|
||||
m = i
|
||||
|
||||
return [x, p-x]
|
||||
|
||||
def inv(a, n):
|
||||
if a == 0:
|
||||
return 0
|
||||
lm, hm = 1, 0
|
||||
low, high = a % n, n
|
||||
while low > 1:
|
||||
r = high//low
|
||||
nm, new = hm-lm*r, high-low*r
|
||||
lm, low, hm, high = nm, new, lm, low
|
||||
return lm % n
|
||||
|
||||
# Pre-compute (i) a list of primes
|
||||
# (ii) a list of -1 legendre bases for each prime
|
||||
# (iii) the inverse for each base
|
||||
LENPRIMES = 1000
|
||||
primes = []
|
||||
r = 2**31 - 1
|
||||
for i in range(LENPRIMES):
|
||||
r += 2
|
||||
while pow(2, r, r) != 2: r += 2
|
||||
primes.append(r)
|
||||
bases = [None] * LENPRIMES
|
||||
invbases = [None] * LENPRIMES
|
||||
for i in range(LENPRIMES):
|
||||
b = 2
|
||||
while legendre_symbol(b, primes[i]) == 1:
|
||||
b += 1
|
||||
bases[i] = b
|
||||
invbases[i] = inv(b, primes[i])
|
||||
|
||||
# Compute the PoW
|
||||
def forward(val, rounds=10**6):
|
||||
t1 = time.time()
|
||||
for i in range(rounds):
|
||||
# Select a prime
|
||||
p = primes[i % LENPRIMES]
|
||||
# Make sure the value we're working on is a
|
||||
# quadratic residue. If it's not, do a spooky
|
||||
# transform (ie. multiply by a known
|
||||
# non-residue) to make sure that it is
|
||||
if legendre_symbol(val, p) != 1:
|
||||
val = (val * invbases[i % LENPRIMES]) % p
|
||||
mul_by_base = 1
|
||||
else:
|
||||
mul_by_base = 0
|
||||
# Take advantage of the fact that two square
|
||||
# roots exist to hide whether or not the spooky
|
||||
# transform was done in the result so that we
|
||||
# can invert it when verifying
|
||||
val = sorted(prime_mod_sqrt(val, p))[mul_by_base]
|
||||
print time.time() - t1
|
||||
return val
|
||||
|
||||
def backward(val, rounds=10**6):
|
||||
t1 = time.time()
|
||||
for i in range(rounds-1, -1, -1):
|
||||
# Select a prime
|
||||
p = primes[i % LENPRIMES]
|
||||
# Extract the info about whether or not the
|
||||
# spooky transform was done
|
||||
mul_by_base = val * 2 > p
|
||||
# Square the value (ie. invert the square root)
|
||||
val = pow(val, 2, p)
|
||||
# Undo the spooky transform if needed
|
||||
if mul_by_base:
|
||||
val = (val * bases[i % LENPRIMES]) % p
|
||||
print time.time() - t1
|
||||
return val
|
|
@ -1,195 +0,0 @@
|
|||
from ethereum.casper_utils import RandaoManager, get_skips_and_block_making_time, \
|
||||
generate_validation_code, call_casper, sign_block, check_skips, get_timestamp, \
|
||||
get_casper_ct, get_dunkle_candidates, \
|
||||
make_withdrawal_signature
|
||||
from ethereum.utils import sha3, hash32, privtoaddr, ecsign, zpad, encode_int32, \
|
||||
big_endian_to_int
|
||||
from ethereum.transaction_queue import TransactionQueue
|
||||
from ethereum.block_creation import make_head_candidate
|
||||
from ethereum.block import Block
|
||||
from ethereum.transactions import Transaction
|
||||
from ethereum.chain import Chain
|
||||
import networksim
|
||||
import rlp
|
||||
import random
|
||||
|
||||
CHECK_FOR_UNCLES_BACK = 8
|
||||
|
||||
global_block_counter = 0
|
||||
|
||||
casper_ct = get_casper_ct()
|
||||
|
||||
class ChildRequest(rlp.Serializable):
|
||||
fields = [
|
||||
('prevhash', hash32)
|
||||
]
|
||||
|
||||
def __init__(self, prevhash):
|
||||
self.prevhash = prevhash
|
||||
|
||||
@property
|
||||
def hash(self):
|
||||
return sha3(self.prevhash + '::salt:jhfqou213nry138o2r124124')
|
||||
|
||||
ids = []
|
||||
|
||||
class Validator():
|
||||
def __init__(self, genesis, key, network, env, time_offset=5):
|
||||
# Create a chain object
|
||||
self.chain = Chain(genesis, env=env)
|
||||
# Create a transaction queue
|
||||
self.txqueue = TransactionQueue()
|
||||
# Use the validator's time as the chain's time
|
||||
self.chain.time = lambda: self.get_timestamp()
|
||||
# My private key
|
||||
self.key = key
|
||||
# My address
|
||||
self.address = privtoaddr(key)
|
||||
# My randao
|
||||
self.randao = RandaoManager(sha3(self.key))
|
||||
# Pointer to the test p2p network
|
||||
self.network = network
|
||||
# Record of objects already received and processed
|
||||
self.received_objects = {}
|
||||
# The minimum eligible timestamp given a particular number of skips
|
||||
self.next_skip_count = 0
|
||||
self.next_skip_timestamp = 0
|
||||
# Is this validator active?
|
||||
self.active = False
|
||||
# Code that verifies signatures from this validator
|
||||
self.validation_code = generate_validation_code(privtoaddr(key))
|
||||
# Validation code hash
|
||||
self.vchash = sha3(self.validation_code)
|
||||
# Parents that this validator has already built a block on
|
||||
self.used_parents = {}
|
||||
# This validator's clock offset (for testing purposes)
|
||||
self.time_offset = random.randrange(time_offset) - (time_offset // 2)
|
||||
# Determine the epoch length
|
||||
self.epoch_length = self.call_casper('getEpochLength')
|
||||
# My minimum gas price
|
||||
self.mingasprice = 20 * 10**9
|
||||
# Give this validator a unique ID
|
||||
self.id = len(ids)
|
||||
ids.append(self.id)
|
||||
self.update_activity_status()
|
||||
self.cached_head = self.chain.head_hash
|
||||
|
||||
def call_casper(self, fun, args=[]):
|
||||
return call_casper(self.chain.state, fun, args)
|
||||
|
||||
def update_activity_status(self):
|
||||
start_epoch = self.call_casper('getStartEpoch', [self.vchash])
|
||||
now_epoch = self.call_casper('getEpoch')
|
||||
end_epoch = self.call_casper('getEndEpoch', [self.vchash])
|
||||
if start_epoch <= now_epoch < end_epoch:
|
||||
self.active = True
|
||||
self.next_skip_count = 0
|
||||
self.next_skip_timestamp = get_timestamp(self.chain, self.next_skip_count)
|
||||
print 'In current validator set'
|
||||
else:
|
||||
self.active = False
|
||||
|
||||
def get_timestamp(self):
|
||||
return int(self.network.time * 0.01) + self.time_offset
|
||||
|
||||
def on_receive(self, obj):
|
||||
if isinstance(obj, list):
|
||||
for _obj in obj:
|
||||
self.on_receive(_obj)
|
||||
return
|
||||
if obj.hash in self.received_objects:
|
||||
return
|
||||
if isinstance(obj, Block):
|
||||
print 'Receiving block', obj
|
||||
assert obj.hash not in self.chain
|
||||
block_success = self.chain.add_block(obj)
|
||||
self.network.broadcast(self, obj)
|
||||
self.network.broadcast(self, ChildRequest(obj.header.hash))
|
||||
self.update_head()
|
||||
elif isinstance(obj, Transaction):
|
||||
print 'Receiving transaction', obj
|
||||
if obj.gasprice >= self.mingasprice:
|
||||
self.txqueue.add_transaction(obj)
|
||||
print 'Added transaction, txqueue size %d' % len(self.txqueue.txs)
|
||||
self.network.broadcast(self, obj)
|
||||
else:
|
||||
print 'Gasprice too low'
|
||||
self.received_objects[obj.hash] = True
|
||||
for x in self.chain.get_chain():
|
||||
assert x.hash in self.received_objects
|
||||
|
||||
def tick(self):
|
||||
# Try to create a block
|
||||
# Conditions:
|
||||
# (i) you are an active validator,
|
||||
# (ii) you have not yet made a block with this parent
|
||||
if self.active and self.chain.head_hash not in self.used_parents:
|
||||
t = self.get_timestamp()
|
||||
# Is it early enough to create the block?
|
||||
if t >= self.next_skip_timestamp and (not self.chain.head or t > self.chain.head.header.timestamp):
|
||||
# Wrong validator; in this case, just wait for the next skip count
|
||||
if not check_skips(self.chain, self.vchash, self.next_skip_count):
|
||||
self.next_skip_count += 1
|
||||
self.next_skip_timestamp = get_timestamp(self.chain, self.next_skip_count)
|
||||
# print 'Incrementing proposed timestamp for block %d to %d' % \
|
||||
# (self.chain.head.header.number + 1 if self.chain.head else 0, self.next_skip_timestamp)
|
||||
return
|
||||
self.used_parents[self.chain.head_hash] = True
|
||||
# Simulated 15% chance of validator failure to make a block
|
||||
if random.random() > 0.999:
|
||||
print 'Simulating validator failure, block %d not created' % (self.chain.head.header.number + 1 if self.chain.head else 0)
|
||||
return
|
||||
# Make the block
|
||||
s1 = self.chain.state.trie.root_hash
|
||||
pre_dunkle_count = self.call_casper('getTotalDunklesIncluded')
|
||||
dunkle_txs = get_dunkle_candidates(self.chain, self.chain.state)
|
||||
blk = make_head_candidate(self.chain, self.txqueue)
|
||||
randao = self.randao.get_parent(self.call_casper('getRandao', [self.vchash]))
|
||||
blk = sign_block(blk, self.key, randao, self.vchash, self.next_skip_count)
|
||||
# Make sure it's valid
|
||||
global global_block_counter
|
||||
global_block_counter += 1
|
||||
for dtx in dunkle_txs:
|
||||
assert dtx in blk.transactions, (dtx, blk.transactions)
|
||||
print 'made block with timestamp %d and %d dunkles' % (blk.timestamp, len(dunkle_txs))
|
||||
s2 = self.chain.state.trie.root_hash
|
||||
assert s1 == s2
|
||||
assert blk.timestamp >= self.next_skip_timestamp
|
||||
assert self.chain.add_block(blk)
|
||||
self.update_head()
|
||||
post_dunkle_count = self.call_casper('getTotalDunklesIncluded')
|
||||
assert post_dunkle_count - pre_dunkle_count == len(dunkle_txs)
|
||||
self.received_objects[blk.hash] = True
|
||||
print 'Validator %d making block %d (%s)' % (self.id, blk.header.number, blk.header.hash[:8].encode('hex'))
|
||||
self.network.broadcast(self, blk)
|
||||
# Sometimes we received blocks too early or out of order;
|
||||
# run an occasional loop that processes these
|
||||
if random.random() < 0.02:
|
||||
self.chain.process_time_queue()
|
||||
self.chain.process_parent_queue()
|
||||
self.update_head()
|
||||
|
||||
def update_head(self):
|
||||
if self.cached_head == self.chain.head_hash:
|
||||
return
|
||||
self.cached_head = self.chain.head_hash
|
||||
if self.chain.state.block_number % self.epoch_length == 0:
|
||||
self.update_activity_status()
|
||||
if self.active:
|
||||
self.next_skip_count = 0
|
||||
self.next_skip_timestamp = get_timestamp(self.chain, self.next_skip_count)
|
||||
print 'Head changed: %s, will attempt creating a block at %d' % (self.chain.head_hash.encode('hex'), self.next_skip_timestamp)
|
||||
|
||||
def withdraw(self, gasprice=20 * 10**9):
|
||||
sigdata = make_withdrawal_signature(self.key)
|
||||
txdata = casper_ct.encode('startWithdrawal', [self.vchash, sigdata])
|
||||
tx = Transaction(self.chain.state.get_nonce(self.address), gasprice, 650000, self.chain.config['CASPER_ADDR'], 0, txdata).sign(self.key)
|
||||
self.txqueue.add_transaction(tx, force=True)
|
||||
self.network.broadcast(self, tx)
|
||||
print 'Withdrawing!'
|
||||
|
||||
def deposit(self, gasprice=20 * 10**9):
|
||||
assert value * 10**18 >= self.chain.state.get_balance(self.address) + gasprice * 1000000
|
||||
tx = Transaction(self.chain.state.get_nonce(self.address) * 10**18, gasprice, 1000000,
|
||||
casper_config['CASPER_ADDR'], value * 10**18,
|
||||
ct.encode('deposit', [self.validation_code, self.randao.get(9999)]))
|
|
@ -1,37 +0,0 @@
|
|||
import random, sys
|
||||
|
||||
|
||||
def normal_distribution(mean, standev):
|
||||
def f():
|
||||
return int(random.normalvariate(mean, standev))
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def exponential_distribution(mean):
|
||||
def f():
|
||||
total = 0
|
||||
while 1:
|
||||
total += 1
|
||||
if not random.randrange(32):
|
||||
break
|
||||
return int(total * 0.03125 * mean)
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def convolve(*args):
|
||||
def f():
|
||||
total = 0
|
||||
for arg in args:
|
||||
total += arg()
|
||||
return total
|
||||
|
||||
return f
|
||||
|
||||
|
||||
def transform(dist, xformer):
|
||||
def f():
|
||||
return xformer(dist())
|
||||
|
||||
return f
|
|
@ -1,75 +0,0 @@
|
|||
from distributions import transform, normal_distribution
|
||||
import random
|
||||
|
||||
|
||||
class NetworkSimulator():
|
||||
|
||||
def __init__(self, latency=50):
|
||||
self.agents = []
|
||||
self.latency_distribution_sample = transform(normal_distribution(latency, (latency * 2) // 5), lambda x: max(x, 0))
|
||||
self.time = 0
|
||||
self.objqueue = {}
|
||||
self.peers = {}
|
||||
self.reliability = 0.9
|
||||
|
||||
def generate_peers(self, num_peers=5):
|
||||
self.peers = {}
|
||||
for a in self.agents:
|
||||
p = []
|
||||
while len(p) <= num_peers // 2:
|
||||
p.append(random.choice(self.agents))
|
||||
if p[-1] == a:
|
||||
p.pop()
|
||||
self.peers[a.id] = self.peers.get(a.id, []) + p
|
||||
for peer in p:
|
||||
self.peers[peer.id] = self.peers.get(peer.id, []) + [a]
|
||||
|
||||
def tick(self):
|
||||
if self.time in self.objqueue:
|
||||
for recipient, obj in self.objqueue[self.time]:
|
||||
if random.random() < self.reliability:
|
||||
recipient.on_receive(obj)
|
||||
del self.objqueue[self.time]
|
||||
for a in self.agents:
|
||||
a.tick()
|
||||
self.time += 1
|
||||
|
||||
def run(self, steps):
|
||||
for i in range(steps):
|
||||
self.tick()
|
||||
|
||||
def broadcast(self, sender, obj):
|
||||
for p in self.peers[sender.id]:
|
||||
recv_time = self.time + self.latency_distribution_sample()
|
||||
if recv_time not in self.objqueue:
|
||||
self.objqueue[recv_time] = []
|
||||
self.objqueue[recv_time].append((p, obj))
|
||||
|
||||
def direct_send(self, to_id, obj):
|
||||
for a in self.agents:
|
||||
if a.id == to_id:
|
||||
recv_time = self.time + self.latency_distribution_sample()
|
||||
if recv_time not in self.objqueue:
|
||||
self.objqueue[recv_time] = []
|
||||
self.objqueue[recv_time].append((a, obj))
|
||||
|
||||
def knock_offline_random(self, n):
|
||||
ko = {}
|
||||
while len(ko) < n:
|
||||
c = random.choice(self.agents)
|
||||
ko[c.id] = c
|
||||
for c in ko.values():
|
||||
self.peers[c.id] = []
|
||||
for a in self.agents:
|
||||
self.peers[a.id] = [x for x in self.peers[a.id] if x.id not in ko]
|
||||
|
||||
def partition(self):
|
||||
a = {}
|
||||
while len(a) < len(self.agents) / 2:
|
||||
c = random.choice(self.agents)
|
||||
a[c.id] = c
|
||||
for c in self.agents:
|
||||
if c.id in a:
|
||||
self.peers[c.id] = [x for x in self.peers[c.id] if x.id in a]
|
||||
else:
|
||||
self.peers[c.id] = [x for x in self.peers[c.id] if x.id not in a]
|
|
@ -1,180 +0,0 @@
|
|||
# The purpose of this script is to test selfish-mining-like strategies
|
||||
# in the randao-based single-chain Casper.
|
||||
import random
|
||||
|
||||
# Reward for mining a block with nonzero skips
|
||||
NON_PRIMARY_REWARD = 0.5
|
||||
# Penalty for mining a dunkle
|
||||
DUNKLE_PENALTY = 0.75
|
||||
# Penalty to a main-chain block which has a dunkle as a sister
|
||||
DUNKLE_SISTER_PENALTY = 0.375
|
||||
|
||||
# Attacker stake power (out of 100). Try setting this value to any
|
||||
# amount, even values above 50!
|
||||
attacker_share = 60
|
||||
|
||||
# A simulated Casper randao
|
||||
def randao_successor(parent, index):
|
||||
return (((parent ^ 53) + index) ** 3) % (10**20 - 11)
|
||||
|
||||
# We categorize "scenarios" by seeing how far ahead we get a chain
|
||||
# of 0-skips from each "path"; if the path itself isn't clear then
|
||||
# return zero
|
||||
heads_of_interest = ['', '1']
|
||||
# Only scan down this far
|
||||
scandepth = 4
|
||||
# eg. (0, 2) means "going straight from the current randao, you
|
||||
# can descend zero, but if you make a one-skip block, from there
|
||||
# you get two 0-skips in a row"
|
||||
scenarios = [None, (0, 1), (0, 2), (0, 3), (0, 4)]
|
||||
# For each scenario, this is the corresponding "path" to go down
|
||||
paths = ['', '10', '100', '100', '1000']
|
||||
|
||||
# Determine the scenario ID (zero is catch-all) from a chain
|
||||
def extract_scenario(chain):
|
||||
chain = chain.copy()
|
||||
o = []
|
||||
for h in heads_of_interest:
|
||||
# Make sure that we can descend down "the path"
|
||||
succeed = True
|
||||
for step in h:
|
||||
if not chain.can_i_extend(int(step)):
|
||||
succeed = False
|
||||
break
|
||||
chain.extend_me(int(step))
|
||||
if not succeed:
|
||||
o.append(0)
|
||||
else:
|
||||
# See how far down we can go
|
||||
i = 0
|
||||
while chain.can_i_extend(0) and i < scandepth:
|
||||
i += 1
|
||||
chain.extend_me(0)
|
||||
o.append(i)
|
||||
if tuple(o) in scenarios:
|
||||
return scenarios.index(tuple(o))
|
||||
else:
|
||||
return 0
|
||||
|
||||
# Class to represent simulated chains
|
||||
class Chain():
|
||||
def __init__(self, randao=0, time=0, length=0, me=0, them=0):
|
||||
self.randao = randao
|
||||
self.time = time
|
||||
self.length = length
|
||||
self.me = me
|
||||
self.them = them
|
||||
|
||||
def copy(self):
|
||||
return Chain(self.randao, self.time, self.length, self.me, self.them)
|
||||
|
||||
def can_i_extend(self, skips):
|
||||
return randao_successor(self.randao, skips) % 100 < attacker_share
|
||||
|
||||
def can_they_extend(self, skips):
|
||||
return randao_successor(self.randao, skips) % 100 >= attacker_share
|
||||
|
||||
def extend_me(self, skips):
|
||||
new_randao = randao_successor(self.randao, skips)
|
||||
assert new_randao % 100 < attacker_share
|
||||
self.randao = new_randao
|
||||
self.time += skips
|
||||
self.length += 1
|
||||
self.me += NON_PRIMARY_REWARD if skips else 1
|
||||
|
||||
def extend_them(self, skips):
|
||||
new_randao = randao_successor(self.randao, skips)
|
||||
assert new_randao % 100 >= attacker_share
|
||||
self.randao = new_randao
|
||||
self.time += skips
|
||||
self.length += 1
|
||||
self.them += NON_PRIMARY_REWARD if skips else 1
|
||||
|
||||
def add_my_dunkles(self, n):
|
||||
self.me -= n * DUNKLE_PENALTY
|
||||
self.them -= n * DUNKLE_SISTER_PENALTY
|
||||
|
||||
def add_their_dunkles(self, n):
|
||||
self.them -= n * DUNKLE_PENALTY
|
||||
self.me -= n * DUNKLE_SISTER_PENALTY
|
||||
|
||||
my_total_loss = 0
|
||||
their_total_loss = 0
|
||||
|
||||
for strat_id in range(2**len(scenarios)):
|
||||
# Strategy map: scenario to 0 = publish, 1 = selfish-validate
|
||||
strategy = [0] + [((strat_id // 2**i) % 2) for i in range(1, len(scenarios))]
|
||||
# 1 = once we go through the selfish-validating "path", reveal it instantly
|
||||
# 0 = don't reveal until the "main chain" looks like it's close to catching up
|
||||
insta_reveal = strat_id % 2
|
||||
|
||||
print 'Testing strategy: %r, insta_reveal: %d' % (strategy, insta_reveal)
|
||||
|
||||
pubchain = Chain(randao=random.randrange(10**20))
|
||||
|
||||
time = 0
|
||||
while time < 100000:
|
||||
# You honestly get a block
|
||||
if pubchain.can_i_extend(0):
|
||||
pubchain.extend_me(0)
|
||||
time += 1
|
||||
continue
|
||||
e = extract_scenario(pubchain)
|
||||
if strategy[e] == 0:
|
||||
# You honestly let them get a block
|
||||
pubchain.extend_them(0)
|
||||
time += 1
|
||||
continue
|
||||
# Build up the secret chain based on the detected path
|
||||
# print 'Selfish mining along path %r' % paths[e]
|
||||
old_me = pubchain.me
|
||||
old_them = pubchain.them
|
||||
old_time = time
|
||||
secchain = pubchain.copy()
|
||||
sectime = time
|
||||
for skipz in paths[e]:
|
||||
skips = int(skipz)
|
||||
secchain.extend_me(skips)
|
||||
sectime += skips + 1
|
||||
# Public chain builds itself up in the meantime
|
||||
pubwait = 0
|
||||
while time < sectime:
|
||||
if pubchain.can_they_extend(pubwait):
|
||||
pubchain.extend_them(pubwait)
|
||||
pubwait = 0
|
||||
else:
|
||||
pubwait += 1
|
||||
time += 1
|
||||
secwait = 0
|
||||
# If the two chains have equal length, or if the secret chain is more than 1 longer, they duel
|
||||
while (secchain.length > pubchain.length + 1 or secchain.length == pubchain.length) and time < 100000 and not insta_reveal:
|
||||
if pubchain.can_they_extend(pubwait):
|
||||
pubchain.extend_them(pubwait)
|
||||
pubwait = 0
|
||||
else:
|
||||
pubwait += 1
|
||||
if secchain.can_i_extend(secwait):
|
||||
secchain.extend_me(secwait)
|
||||
secwait = 0
|
||||
else:
|
||||
secwait += 1
|
||||
time += 1
|
||||
# Secret chain is longer, takes over public chain, public chain goes in as dunkles
|
||||
if secchain.length > pubchain.length:
|
||||
pubchain_blocks = pubchain.them - old_them
|
||||
assert pubchain.me == old_me
|
||||
pubchain = secchain
|
||||
pubchain.add_their_dunkles(pubchain_blocks)
|
||||
# Public chain is longer, miner deletes secret chain so no dunkling
|
||||
else:
|
||||
pass
|
||||
# print 'Score deltas: me %.2f them %.2f, time delta %d' % (pubchain.me - old_me, pubchain.them - old_them, time - old_time)
|
||||
|
||||
my_loss = 100000 * attacker_share / 100 - pubchain.me
|
||||
their_loss = 100000 * (100 - attacker_share) / 100 - pubchain.them
|
||||
my_total_loss += my_loss
|
||||
their_total_loss += their_loss
|
||||
gf = their_loss / my_loss if my_loss > 0 else 999.99
|
||||
print 'My revenue: %d, their revenue: %d, griefing factor %.2f' % (pubchain.me, pubchain.them, gf)
|
||||
|
||||
print 'Total griefing factor: %.2f' % (their_total_loss / my_total_loss)
|
|
@ -1,71 +0,0 @@
|
|||
import networksim
|
||||
from casper import Validator
|
||||
import casper
|
||||
from ethereum.parse_genesis_declaration import mk_basic_state
|
||||
from ethereum.config import Env
|
||||
from ethereum.casper_utils import RandaoManager, generate_validation_code, call_casper, \
|
||||
get_skips_and_block_making_time, sign_block, get_contract_code, \
|
||||
casper_config, get_casper_ct, get_casper_code, get_rlp_decoder_code, \
|
||||
get_hash_without_ed_code, make_casper_genesis
|
||||
from ethereum.utils import sha3, privtoaddr
|
||||
from ethereum.transactions import Transaction
|
||||
from ethereum.state_transition import apply_transaction
|
||||
|
||||
from ethereum.slogging import LogRecorder, configure_logging, set_level
|
||||
# config_string = ':info,eth.vm.log:trace,eth.vm.op:trace,eth.vm.stack:trace,eth.vm.exit:trace,eth.pb.msg:trace,eth.pb.tx:debug'
|
||||
config_string = ':info,eth.vm.log:trace'
|
||||
configure_logging(config_string=config_string)
|
||||
|
||||
n = networksim.NetworkSimulator(latency=150)
|
||||
n.time = 2
|
||||
print 'Generating keys'
|
||||
keys = [sha3(str(i)) for i in range(20)]
|
||||
print 'Initializing randaos'
|
||||
randaos = [RandaoManager(sha3(k)) for k in keys]
|
||||
deposit_sizes = [128] * 15 + [256] * 5
|
||||
|
||||
print 'Creating genesis state'
|
||||
s = make_casper_genesis(validators=[(generate_validation_code(privtoaddr(k)), ds * 10**18, r.get(9999))
|
||||
for k, ds, r in zip(keys, deposit_sizes, randaos)],
|
||||
alloc={privtoaddr(k): {'balance': 10**18} for k in keys},
|
||||
timestamp=2,
|
||||
epoch_length=50)
|
||||
g = s.to_snapshot()
|
||||
print 'Genesis state created'
|
||||
|
||||
validators = [Validator(g, k, n, Env(config=casper_config), time_offset=4) for k in keys]
|
||||
n.agents = validators
|
||||
n.generate_peers()
|
||||
lowest_shared_height = -1
|
||||
made_101_check = 0
|
||||
|
||||
for i in range(100000):
|
||||
# print 'ticking'
|
||||
n.tick()
|
||||
if i % 100 == 0:
|
||||
print '%d ticks passed' % i
|
||||
print 'Validator heads:', [v.chain.head.header.number if v.chain.head else None for v in validators]
|
||||
print 'Total blocks created:', casper.global_block_counter
|
||||
print 'Dunkle count:', call_casper(validators[0].chain.state, 'getTotalDunklesIncluded', [])
|
||||
lowest_shared_height = min([v.chain.head.header.number if v.chain.head else -1 for v in validators])
|
||||
if lowest_shared_height >= 101 and not made_101_check:
|
||||
made_101_check = True
|
||||
print 'Checking that withdrawn validators are inactive'
|
||||
assert len([v for v in validators if v.active]) == len(validators) - 5, len([v for v in validators if v.active])
|
||||
print 'Check successful'
|
||||
break
|
||||
if i == 1:
|
||||
print 'Checking that all validators are active'
|
||||
assert len([v for v in validators if v.active]) == len(validators)
|
||||
print 'Check successful'
|
||||
if i == 2000:
|
||||
print 'Withdrawing a few validators'
|
||||
for v in validators[:5]:
|
||||
v.withdraw()
|
||||
if i == 4000:
|
||||
print 'Checking that validators have withdrawn'
|
||||
for v in validators[:5]:
|
||||
assert v.call_casper('getEndEpoch', [v.vchash]) <= 2
|
||||
for v in validators[5:]:
|
||||
assert v.call_casper('getEndEpoch', [v.vchash]) > 2
|
||||
print 'Check successful'
|
|
@ -1,56 +0,0 @@
|
|||
from ethereum.utils import safe_ord as ord
|
||||
|
||||
# 0100000101010111010000110100100101001001 -> ASCII
|
||||
def decode_bin(x):
|
||||
o = bytearray(len(x) // 8)
|
||||
for i in range(0, len(x), 8):
|
||||
v = 0
|
||||
for c in x[i:i+8]:
|
||||
v = v * 2 + c
|
||||
o[i//8] = v
|
||||
return bytes(o)
|
||||
|
||||
|
||||
# ASCII -> 0100000101010111010000110100100101001001
|
||||
def encode_bin(x):
|
||||
o = b''
|
||||
for c in x:
|
||||
c = ord(c)
|
||||
p = bytearray(8)
|
||||
for i in range(8):
|
||||
p[7-i] = c % 2
|
||||
c //= 2
|
||||
o += p
|
||||
return o
|
||||
|
||||
two_bits = [bytes([0,0]), bytes([0,1]),
|
||||
bytes([1,0]), bytes([1,1])]
|
||||
prefix00 = bytes([0,0])
|
||||
prefix100000 = bytes([1,0,0,0,0,0])
|
||||
|
||||
|
||||
# Encodes a sequence of 0s and 1s into tightly packed bytes
|
||||
def encode_bin_path(b):
|
||||
b2 = bytes((4 - len(b)) % 4) + b
|
||||
prefix = two_bits[len(b) % 4]
|
||||
if len(b2) % 8 == 4:
|
||||
return decode_bin(prefix00 + prefix + b2)
|
||||
else:
|
||||
return decode_bin(prefix100000 + prefix + b2)
|
||||
|
||||
|
||||
# Decodes bytes into a sequence of 0s and 1s
|
||||
def decode_bin_path(p):
|
||||
p = encode_bin(p)
|
||||
if p[0] == 1:
|
||||
p = p[4:]
|
||||
assert p[0:2] == prefix00
|
||||
L = two_bits.index(p[2:4])
|
||||
return p[4+((4 - L) % 4):]
|
||||
|
||||
def common_prefix_length(a, b):
|
||||
o = 0
|
||||
while o < len(a) and o < len(b) and a[o] == b[o]:
|
||||
o += 1
|
||||
return o
|
||||
|
|
@ -1,46 +0,0 @@
|
|||
# Get a Merkle proof
|
||||
def _get_branch(db, node, keypath):
|
||||
if not keypath:
|
||||
return [db.get(node)]
|
||||
L, R, nodetype = parse_node(db.get(node))
|
||||
if nodetype == KV_TYPE:
|
||||
path = encode_bin_path(L)
|
||||
if keypath[:len(L)] == L:
|
||||
return [b'\x01'+path] + _get_branch(db, R, keypath[len(L):])
|
||||
else:
|
||||
return [b'\x01'+path, db.get(R)]
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
if keypath[:1] == b0:
|
||||
return [b'\x02'+R] + _get_branch(db, L, keypath[1:])
|
||||
else:
|
||||
return [b'\x03'+L] + _get_branch(db, R, keypath[1:])
|
||||
|
||||
# Verify a Merkle proof
|
||||
def _verify_branch(branch, root, keypath, value):
|
||||
nodes = [branch[-1]]
|
||||
_keypath = b''
|
||||
for data in branch[-2::-1]:
|
||||
marker, node = data[0], data[1:]
|
||||
# it's a keypath
|
||||
if marker == 1:
|
||||
node = decode_bin_path(node)
|
||||
_keypath = node + _keypath
|
||||
nodes.insert(0, encode_kv_node(node, sha3(nodes[0])))
|
||||
# it's a right-side branch
|
||||
elif marker == 2:
|
||||
_keypath = b0 + _keypath
|
||||
nodes.insert(0, encode_branch_node(sha3(nodes[0]), node))
|
||||
# it's a left-side branch
|
||||
elif marker == 3:
|
||||
_keypath = b1 + _keypath
|
||||
nodes.insert(0, encode_branch_node(node, sha3(nodes[0])))
|
||||
else:
|
||||
raise Exception("Foo")
|
||||
L, R, nodetype = parse_node(nodes[0])
|
||||
if value:
|
||||
assert _keypath == keypath
|
||||
assert sha3(nodes[0]) == root
|
||||
db = EphemDB()
|
||||
db.kv = {sha3(node): node for node in nodes}
|
||||
assert _get(db, root, keypath) == value
|
||||
return True
|
|
@ -1,84 +0,0 @@
|
|||
from ethereum.utils import sha3, encode_hex
|
||||
from new_bintrie import parse_node, KV_TYPE, BRANCH_TYPE, LEAF_TYPE, encode_bin_path, encode_kv_node, encode_branch_node, decode_bin_path
|
||||
|
||||
KV_COMPRESS_TYPE = 128
|
||||
BRANCH_LEFT_TYPE = 129
|
||||
BRANCH_RIGHT_TYPE = 130
|
||||
|
||||
def compress(witness):
|
||||
parentmap = {}
|
||||
leaves = []
|
||||
for w in witness:
|
||||
L, R, nodetype = parse_node(w)
|
||||
if nodetype == LEAF_TYPE:
|
||||
leaves.append(w)
|
||||
elif nodetype == KV_TYPE:
|
||||
parentmap[R] = w
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
parentmap[L] = w
|
||||
parentmap[R] = w
|
||||
used = {}
|
||||
proof = []
|
||||
for node in leaves:
|
||||
proof.append(node)
|
||||
used[node] = True
|
||||
h = sha3(node)
|
||||
while h in parentmap:
|
||||
node = parentmap[h]
|
||||
L, R, nodetype = parse_node(node)
|
||||
if nodetype == KV_TYPE:
|
||||
proof.append(bytes([KV_COMPRESS_TYPE]) + encode_bin_path(L))
|
||||
elif nodetype == BRANCH_TYPE and L == h:
|
||||
proof.append(bytes([BRANCH_LEFT_TYPE]) + R)
|
||||
elif nodetype == BRANCH_TYPE and R == h:
|
||||
proof.append(bytes([BRANCH_RIGHT_TYPE]) + L)
|
||||
else:
|
||||
raise Exception("something is wrong")
|
||||
h = sha3(node)
|
||||
if h in used:
|
||||
proof.pop()
|
||||
break
|
||||
used[h] = True
|
||||
assert len(used) == len(proof)
|
||||
return proof
|
||||
|
||||
# Input: a serialized node
|
||||
def parse_proof_node(node):
|
||||
if node[0] == BRANCH_LEFT_TYPE:
|
||||
# Output: right child, node type
|
||||
return node[1:33], BRANCH_LEFT_TYPE
|
||||
elif node[0] == BRANCH_RIGHT_TYPE:
|
||||
# Output: left child, node type
|
||||
return node[1:33], BRANCH_RIGHT_TYPE
|
||||
elif node[0] == KV_COMPRESS_TYPE:
|
||||
# Output: keypath: child, node type
|
||||
return decode_bin_path(node[1:]), KV_COMPRESS_TYPE
|
||||
elif node[0] == LEAF_TYPE:
|
||||
# Output: None, value, node type
|
||||
return node[1:], LEAF_TYPE
|
||||
else:
|
||||
raise Exception("Bad node")
|
||||
|
||||
def expand(proof):
|
||||
witness = []
|
||||
lasthash = None
|
||||
for p in proof:
|
||||
sub, nodetype = parse_proof_node(p)
|
||||
if nodetype == LEAF_TYPE:
|
||||
witness.append(p)
|
||||
lasthash = sha3(p)
|
||||
elif nodetype == KV_COMPRESS_TYPE:
|
||||
fullnode = encode_kv_node(sub, lasthash)
|
||||
witness.append(fullnode)
|
||||
lasthash = sha3(fullnode)
|
||||
elif nodetype == BRANCH_LEFT_TYPE:
|
||||
fullnode = encode_branch_node(lasthash, sub)
|
||||
witness.append(fullnode)
|
||||
lasthash = sha3(fullnode)
|
||||
elif nodetype == BRANCH_RIGHT_TYPE:
|
||||
fullnode = encode_branch_node(sub, lasthash)
|
||||
witness.append(fullnode)
|
||||
lasthash = sha3(fullnode)
|
||||
else:
|
||||
raise Exception("Bad node")
|
||||
return witness
|
|
@ -1,322 +0,0 @@
|
|||
from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, encode_bin, decode_bin
|
||||
from ethereum.utils import sha3, encode_hex
|
||||
|
||||
class EphemDB():
|
||||
def __init__(self, kv=None):
|
||||
self.kv = kv or {}
|
||||
|
||||
def get(self, k):
|
||||
return self.kv.get(k, None)
|
||||
|
||||
def put(self, k, v):
|
||||
self.kv[k] = v
|
||||
|
||||
def delete(self, k):
|
||||
del self.kv[k]
|
||||
|
||||
KV_TYPE = 0
|
||||
BRANCH_TYPE = 1
|
||||
LEAF_TYPE = 2
|
||||
|
||||
b1 = bytes([1])
|
||||
b0 = bytes([0])
|
||||
|
||||
# Input: a serialized node
|
||||
def parse_node(node):
|
||||
if node[0] == BRANCH_TYPE:
|
||||
# Output: left child, right child, node type
|
||||
return node[1:33], node[33:], BRANCH_TYPE
|
||||
elif node[0] == KV_TYPE:
|
||||
# Output: keypath: child, node type
|
||||
return decode_bin_path(node[1:-32]), node[-32:], KV_TYPE
|
||||
elif node[0] == LEAF_TYPE:
|
||||
# Output: None, value, node type
|
||||
return None, node[1:], LEAF_TYPE
|
||||
else:
|
||||
raise Exception("Bad node")
|
||||
|
||||
# Serializes a key/value node
|
||||
def encode_kv_node(keypath, node):
|
||||
assert keypath
|
||||
assert len(node) == 32
|
||||
o = bytes([KV_TYPE]) + encode_bin_path(keypath) + node
|
||||
return o
|
||||
|
||||
# Serializes a branch node (ie. a node with 2 children)
|
||||
def encode_branch_node(left, right):
|
||||
assert len(left) == len(right) == 32
|
||||
return bytes([BRANCH_TYPE]) + left + right
|
||||
|
||||
# Serializes a leaf node
|
||||
def encode_leaf_node(value):
|
||||
return bytes([LEAF_TYPE]) + value
|
||||
|
||||
# Saves a value into the database and returns its hash
|
||||
def hash_and_save(db, node):
|
||||
h = sha3(node)
|
||||
db.put(h, node)
|
||||
return h
|
||||
|
||||
# Fetches the value with a given keypath from the given node
|
||||
def _get(db, node, keypath):
|
||||
L, R, nodetype = parse_node(db.get(node))
|
||||
# Key-value node descend
|
||||
if nodetype == LEAF_TYPE:
|
||||
return R
|
||||
elif nodetype == KV_TYPE:
|
||||
# Keypath too short
|
||||
if not keypath:
|
||||
return None
|
||||
if keypath[:len(L)] == L:
|
||||
return _get(db, R, keypath[len(L):])
|
||||
else:
|
||||
return None
|
||||
# Branch node descend
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
# Keypath too short
|
||||
if not keypath:
|
||||
return None
|
||||
if keypath[:1] == b0:
|
||||
return _get(db, L, keypath[1:])
|
||||
else:
|
||||
return _get(db, R, keypath[1:])
|
||||
|
||||
# Updates the value at the given keypath from the given node
|
||||
def _update(db, node, keypath, val):
|
||||
# Empty trie
|
||||
if not node:
|
||||
if val:
|
||||
return hash_and_save(db, encode_kv_node(keypath, hash_and_save(db, encode_leaf_node(val))))
|
||||
else:
|
||||
return b''
|
||||
L, R, nodetype = parse_node(db.get(node))
|
||||
# Node is a leaf node
|
||||
if nodetype == LEAF_TYPE:
|
||||
# Keypath must match, there should be no remaining keypath
|
||||
if keypath:
|
||||
raise Exception("Existing kv pair is being effaced because it's key is the prefix of the new key")
|
||||
return hash_and_save(db, encode_leaf_node(val)) if val else b''
|
||||
# node is a key-value node
|
||||
elif nodetype == KV_TYPE:
|
||||
# Keypath too short
|
||||
if not keypath:
|
||||
return node
|
||||
# Keypath prefixes match
|
||||
if keypath[:len(L)] == L:
|
||||
# Recurse into child
|
||||
o = _update(db, R, keypath[len(L):], val)
|
||||
# If child is empty
|
||||
if not o:
|
||||
return b''
|
||||
#print(db.get(o))
|
||||
subL, subR, subnodetype = parse_node(db.get(o))
|
||||
# If the child is a key-value node, compress together the keypaths
|
||||
# into one node
|
||||
if subnodetype == KV_TYPE:
|
||||
return hash_and_save(db, encode_kv_node(L + subL, subR))
|
||||
else:
|
||||
return hash_and_save(db, encode_kv_node(L, o)) if o else b''
|
||||
# Keypath prefixes don't match. Here we will be converting a key-value node
|
||||
# of the form (k, CHILD) into a structure of one of the following forms:
|
||||
# i. (k[:-1], (NEWCHILD, CHILD))
|
||||
# ii. (k[:-1], ((k2, NEWCHILD), CHILD))
|
||||
# iii. (k1, ((k2, CHILD), NEWCHILD))
|
||||
# iv. (k1, ((k2, CHILD), (k2', NEWCHILD))
|
||||
# v. (CHILD, NEWCHILD)
|
||||
# vi. ((k[1:], CHILD), (k', NEWCHILD))
|
||||
# vii. ((k[1:], CHILD), NEWCHILD)
|
||||
# viii (CHILD, (k[1:], NEWCHILD))
|
||||
else:
|
||||
cf = common_prefix_length(L, keypath[:len(L)])
|
||||
# New key-value pair can not contain empty value
|
||||
if not val:
|
||||
return node
|
||||
# valnode: the child node that has the new value we are adding
|
||||
# Case 1: keypath prefixes almost match, so we are in case (i), (ii), (v), (vi)
|
||||
if len(keypath) == cf + 1:
|
||||
valnode = hash_and_save(db, encode_leaf_node(val))
|
||||
# Case 2: keypath prefixes mismatch in the middle, so we need to break
|
||||
# the keypath in half. We are in case (iii), (iv), (vii), (viii)
|
||||
else:
|
||||
valnode = hash_and_save(db, encode_kv_node(keypath[cf+1:], hash_and_save(db, encode_leaf_node(val))))
|
||||
# oldnode: the child node the has the old child value
|
||||
# Case 1: (i), (iii), (v), (vi)
|
||||
if len(L) == cf + 1:
|
||||
oldnode = R
|
||||
# (ii), (iv), (vi), (viii)
|
||||
else:
|
||||
oldnode = hash_and_save(db, encode_kv_node(L[cf+1:], R))
|
||||
# Create the new branch node (because the key paths diverge, there has to
|
||||
# be some "first bit" at which they diverge, so there must be a branch
|
||||
# node somewhere)
|
||||
if keypath[cf:cf+1] == b1:
|
||||
newsub = hash_and_save(db, encode_branch_node(oldnode, valnode))
|
||||
else:
|
||||
newsub = hash_and_save(db, encode_branch_node(valnode, oldnode))
|
||||
# Case 1: keypath prefixes match in the first bit, so we still need
|
||||
# a kv node at the top
|
||||
# (i) (ii) (iii) (iv)
|
||||
if cf:
|
||||
return hash_and_save(db, encode_kv_node(L[:cf], newsub))
|
||||
# Case 2: keypath prefixes diverge in the first bit, so we replace the
|
||||
# kv node with a branch node
|
||||
# (v) (vi) (vii) (viii)
|
||||
else:
|
||||
return newsub
|
||||
# node is a branch node
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
# Keypath too short
|
||||
if not keypath:
|
||||
return node
|
||||
newL, newR = L, R
|
||||
# Which child node to update? Depends on first bit in keypath
|
||||
if keypath[:1] == b0:
|
||||
newL = _update(db, L, keypath[1:], val)
|
||||
else:
|
||||
newR = _update(db, R, keypath[1:], val)
|
||||
# Compress branch node into kv node
|
||||
if not newL or not newR:
|
||||
subL, subR, subnodetype = parse_node(db.get(newL or newR))
|
||||
first_bit = b1 if newR else b0
|
||||
# Compress (k1, (k2, NODE)) -> (k1 + k2, NODE)
|
||||
if subnodetype == KV_TYPE:
|
||||
return hash_and_save(db, encode_kv_node(first_bit + subL, subR))
|
||||
# kv node pointing to a branch node
|
||||
elif subnodetype == BRANCH_TYPE or subnodetype == LEAF_TYPE:
|
||||
return hash_and_save(db, encode_kv_node(first_bit, newL or newR))
|
||||
else:
|
||||
return hash_and_save(db, encode_branch_node(newL, newR))
|
||||
raise Exception("How did I get here?")
|
||||
|
||||
# Prints a tree, and checks that all invariants check out
|
||||
def print_and_check_invariants(db, node, prefix=b''):
|
||||
if node == b'' and prefix == b'':
|
||||
return {}
|
||||
L, R, nodetype = parse_node(db.get(node))
|
||||
if nodetype == LEAF_TYPE:
|
||||
# All keys must be 160 bits
|
||||
assert len(prefix) == 160
|
||||
return {prefix: R}
|
||||
elif nodetype == KV_TYPE:
|
||||
# (k1, (k2, node)) two nested key values nodes not allowed
|
||||
assert 0 < len(L) <= 160 - len(prefix)
|
||||
if len(L) + len(prefix) < 160:
|
||||
subL, subR, subnodetype = parse_node(db.get(R))
|
||||
assert subnodetype != KV_TYPE
|
||||
# Childre of a key node cannot be empty
|
||||
assert subR != sha3(b'')
|
||||
return print_and_check_invariants(db, R, prefix + L)
|
||||
else:
|
||||
# Children of a branch node cannot be empty
|
||||
assert L != sha3(b'') and R != sha3(b'')
|
||||
o = {}
|
||||
o.update(print_and_check_invariants(db, L, prefix + b0))
|
||||
o.update(print_and_check_invariants(db, R, prefix + b1))
|
||||
return o
|
||||
|
||||
# Pretty-print all nodes in a tree (for debugging purposes)
|
||||
def print_nodes(db, node, prefix=b''):
|
||||
if node == b'':
|
||||
print('empty node')
|
||||
return
|
||||
L, R, nodetype = parse_node(db.get(node))
|
||||
if nodetype == LEAF_TYPE:
|
||||
print('value node', encode_hex(node[:4]), R)
|
||||
elif nodetype == KV_TYPE:
|
||||
print(('kv node:', encode_hex(node[:4]), ''.join(['1' if x == 1 else '0' for x in L]), encode_hex(R[:4])))
|
||||
print_nodes(db, R, prefix + L)
|
||||
else:
|
||||
print(('branch node:', encode_hex(node[:4]), encode_hex(L[:4]), encode_hex(R[:4])))
|
||||
print_nodes(db, L, prefix + b0)
|
||||
print_nodes(db, R, prefix + b1)
|
||||
|
||||
# Get a long-format Merkle branch
|
||||
def _get_long_format_branch(db, node, keypath):
|
||||
if not keypath:
|
||||
return [db.get(node)]
|
||||
L, R, nodetype = parse_node(db.get(node))
|
||||
if nodetype == KV_TYPE:
|
||||
path = encode_bin_path(L)
|
||||
if keypath[:len(L)] == L:
|
||||
return [db.get(node)] + _get_long_format_branch(db, R, keypath[len(L):])
|
||||
else:
|
||||
return [db.get(node)]
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
if keypath[:1] == b0:
|
||||
return [db.get(node)] + _get_long_format_branch(db, L, keypath[1:])
|
||||
else:
|
||||
return [db.get(node)] + _get_long_format_branch(db, R, keypath[1:])
|
||||
|
||||
def _verify_long_format_branch(branch, root, keypath, value):
|
||||
db = EphemDB()
|
||||
db.kv = {sha3(node): node for node in branch}
|
||||
assert _get(db, root, keypath) == value
|
||||
return True
|
||||
|
||||
# Get full subtrie
|
||||
def _get_subtrie(db, node):
|
||||
dbnode = db.get(node)
|
||||
L, R, nodetype = parse_node(dbnode)
|
||||
if nodetype == KV_TYPE:
|
||||
return [dbnode] + _get_subtrie(db, R)
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
return [dbnode] + _get_subtrie(db, L) + _get_subtrie(db, R)
|
||||
elif nodetype == LEAF_TYPE:
|
||||
return [dbnode]
|
||||
|
||||
# Get witness for prefix
|
||||
def _get_prefix_witness(db, node, keypath):
|
||||
dbnode = db.get(node)
|
||||
if not keypath:
|
||||
return _get_subtrie(db, node)
|
||||
L, R, nodetype = parse_node(dbnode)
|
||||
if nodetype == KV_TYPE:
|
||||
path = encode_bin_path(L)
|
||||
if len(keypath) < len(L) and L[:len(keypath)] == keypath:
|
||||
return [dbnode] + _get_subtrie(db, R)
|
||||
if keypath[:len(L)] == L:
|
||||
return [dbnode] + _get_prefix_witness(db, R, keypath[len(L):])
|
||||
else:
|
||||
return [dbnode]
|
||||
elif nodetype == BRANCH_TYPE:
|
||||
if keypath[:1] == b0:
|
||||
return [dbnode] + _get_prefix_witness(db, L, keypath[1:])
|
||||
else:
|
||||
return [dbnode] + _get_prefix_witness(db, R, keypath[1:])
|
||||
|
||||
|
||||
# Trie wrapper class
|
||||
class Trie():
|
||||
def __init__(self, db, root):
|
||||
self.db = db
|
||||
self.root = root
|
||||
assert isinstance(self.root, bytes)
|
||||
|
||||
def get(self, key):
|
||||
#assert len(key) == 20
|
||||
return _get(self.db, self.root, encode_bin(key))
|
||||
|
||||
#def get_branch(self, key):
|
||||
# o = _get_branch(self.db, self.root, encode_bin(key))
|
||||
# assert _verify_branch(o, self.root, encode_bin(key), self.get(key))
|
||||
# return o
|
||||
|
||||
def get_long_format_branch(self, key):
|
||||
o = _get_long_format_branch(self.db, self.root, encode_bin(key))
|
||||
assert _verify_long_format_branch(o, self.root, encode_bin(key), self.get(key))
|
||||
return o
|
||||
|
||||
def get_prefix_witness(self, key):
|
||||
return _get_prefix_witness(self.db, self.root, encode_bin(key))
|
||||
|
||||
def update(self, key, value):
|
||||
#assert len(key) == 20
|
||||
self.root = _update(self.db, self.root, encode_bin(key), value)
|
||||
|
||||
def to_dict(self, hexify=False):
|
||||
o = print_and_check_invariants(self.db, self.root)
|
||||
encoder = lambda x: encode_hex(x) if hexify else x
|
||||
return {encoder(decode_bin(k)): v for k, v in o.items()}
|
||||
|
||||
def print_nodes(self):
|
||||
print_nodes(self.db, self.root)
|
|
@ -1,117 +0,0 @@
|
|||
from new_bintrie import b0, b1, KV_TYPE, BRANCH_TYPE, LEAF_TYPE, parse_node, encode_kv_node, encode_branch_node, encode_leaf_node
|
||||
from bin_utils import encode_bin_path, decode_bin_path, common_prefix_length, encode_bin, decode_bin
|
||||
from ethereum.utils import sha3 as _sha3, encode_hex
|
||||
|
||||
sha3_cache = {}
|
||||
|
||||
def sha3(x):
|
||||
if x not in sha3_cache:
|
||||
sha3_cache[x] = _sha3(x)
|
||||
return sha3_cache[x]
|
||||
|
||||
def quick_encode(nodes):
|
||||
o = b''
|
||||
for node in nodes:
|
||||
o += bytes([len(node) // 65536, len(node) // 256, len(node)]) + node
|
||||
return o
|
||||
|
||||
def quick_decode(nodedata):
|
||||
o = []
|
||||
pos = 0
|
||||
while pos < len(nodedata):
|
||||
L = nodedata[pos] * 65536 + nodedata[pos+1] * 256 + nodedata[pos+2]
|
||||
o.append(nodedata[pos+3: pos+3+L])
|
||||
pos += 3+L
|
||||
return o
|
||||
|
||||
|
||||
class WrapperDB():
|
||||
def __init__(self, parent_db):
|
||||
self.parent_db = parent_db
|
||||
self.substores = {}
|
||||
self.node_to_substore = {}
|
||||
self.new_nodes = {}
|
||||
self.parent_db_reads = 0
|
||||
self.parent_db_writes = 0
|
||||
self.printing_mode = False
|
||||
|
||||
# Loads a substore (RLP-encoded list of closeby trie nodes) from the DB
|
||||
def fetch_substore(self, key):
|
||||
substore_values = self.parent_db.get(key)
|
||||
assert substore_values is not None
|
||||
children = quick_decode(substore_values)
|
||||
self.parent_db_reads += 1
|
||||
self.substores[key] = {sha3(n): n for n in children}
|
||||
self.node_to_substore.update({sha3(n): key for n in children})
|
||||
assert key in self.node_to_substore and key in self.substores
|
||||
|
||||
def get(self, k):
|
||||
if k in self.new_nodes:
|
||||
return self.new_nodes[k]
|
||||
if k not in self.node_to_substore:
|
||||
self.fetch_substore(k)
|
||||
o = self.substores[self.node_to_substore[k]][k]
|
||||
assert sha3(o) == k
|
||||
return o
|
||||
|
||||
def put(self, k, v):
|
||||
if k not in self.new_nodes and k not in self.node_to_substore:
|
||||
self.new_nodes[k] = v
|
||||
|
||||
# Given a key, returns a collection of candidate nodes to form
|
||||
# a substore, as well as the children of that substore
|
||||
def get_substore_candidate_and_children(self, key, depth=5):
|
||||
if depth == 0:
|
||||
return [], [key]
|
||||
elif self.parent_db.get(key) is not None:
|
||||
return [], [key]
|
||||
else:
|
||||
node = self.get(key)
|
||||
L, R, nodetype = parse_node(node)
|
||||
if nodetype == BRANCH_TYPE:
|
||||
Ln, Lc = self.get_substore_candidate_and_children(L, depth-1)
|
||||
Rn, Rc = self.get_substore_candidate_and_children(R, depth-1)
|
||||
return [node] + Ln + Rn, Lc + Rc
|
||||
elif nodetype == KV_TYPE:
|
||||
Rn, Rc = self.get_substore_candidate_and_children(R, depth-1)
|
||||
return [node] + Rn, Rc
|
||||
elif nodetype == LEAF_TYPE:
|
||||
return [node], []
|
||||
|
||||
# Commits to the parent DB
|
||||
def commit(self):
|
||||
processed = {}
|
||||
assert_exists = {}
|
||||
for k, v in self.new_nodes.items():
|
||||
if k in processed:
|
||||
continue
|
||||
nodes, children = self.get_substore_candidate_and_children(k)
|
||||
if not nodes:
|
||||
continue
|
||||
assert k == sha3(nodes[0])
|
||||
for c in children:
|
||||
assert_exists[c] = True
|
||||
if c not in self.substores:
|
||||
self.fetch_substore(c)
|
||||
cvalues = list(self.substores[c].values())
|
||||
if len(quick_encode(cvalues + nodes)) < 3072:
|
||||
del self.substores[c]
|
||||
nodes.extend(cvalues)
|
||||
self.parent_db.put(k, quick_encode(nodes))
|
||||
self.parent_db_writes += 1
|
||||
self.substores[k] = {}
|
||||
for n in nodes:
|
||||
h = sha3(n)
|
||||
self.substores[k][h] = n
|
||||
self.node_to_substore[h] = k
|
||||
processed[h] = k
|
||||
for c in assert_exists:
|
||||
assert self.parent_db.get(c) is not None
|
||||
print('reads', self.parent_db_reads, 'writes', self.parent_db_writes)
|
||||
self.parent_db_reads = self.parent_db_writes = 0
|
||||
self.new_nodes = {}
|
||||
|
||||
def clear_cache(self):
|
||||
assert len(self.new_nodes) == 0
|
||||
self.substores = {}
|
||||
self.node_to_substore = {}
|
|
@ -1,52 +0,0 @@
|
|||
from new_bintrie import Trie, EphemDB, encode_bin, encode_bin_path, decode_bin_path
|
||||
from new_bintrie_aggregate import WrapperDB
|
||||
from ethereum.utils import sha3, encode_hex
|
||||
import random
|
||||
import rlp
|
||||
|
||||
def shuffle_in_place(x):
|
||||
y = x[::]
|
||||
random.shuffle(y)
|
||||
return y
|
||||
|
||||
kvpairs = [(sha3(str(i))[12:], str(i).encode('utf-8') * 5) for i in range(2000)]
|
||||
|
||||
|
||||
for path in ([], [1,0,1], [0,0,1,0], [1,0,0,1,0], [1,0,0,1,0,0,1,0], [1,0] * 8):
|
||||
assert decode_bin_path(encode_bin_path(bytes(path))) == bytes(path)
|
||||
|
||||
r1 = None
|
||||
|
||||
t = Trie(WrapperDB(EphemDB()), b'')
|
||||
for i, (k, v) in enumerate(shuffle_in_place(kvpairs)):
|
||||
#print(t.to_dict())
|
||||
t.update(k, v)
|
||||
assert t.get(k) == v
|
||||
if not i % 50:
|
||||
t.db.commit()
|
||||
assert t.db.parent_db.get(t.root) is not None
|
||||
if not i % 250:
|
||||
t.to_dict()
|
||||
print("Length of branch at %d nodes: %d" % (i, len(rlp.encode(t.get_branch(k)))))
|
||||
assert r1 is None or t.root == r1
|
||||
r1 = t.root
|
||||
t.update(kvpairs[0][0], kvpairs[0][1])
|
||||
assert t.root == r1
|
||||
print(t.get_branch(kvpairs[0][0]))
|
||||
print(t.get_branch(kvpairs[0][0][::-1]))
|
||||
print(encode_hex(t.root))
|
||||
t.db.commit()
|
||||
assert t.db.parent_db.get(t.root) is not None
|
||||
t.db.clear_cache()
|
||||
t.db.printing_mode = True
|
||||
for k, v in shuffle_in_place(kvpairs):
|
||||
t.get(k)
|
||||
t.db.clear_cache()
|
||||
print('Average DB reads: %.3f' % (t.db.parent_db_reads / len(kvpairs)))
|
||||
for k, v in shuffle_in_place(kvpairs):
|
||||
t.update(k, b'')
|
||||
if not random.randrange(100):
|
||||
t.to_dict()
|
||||
t.db.commit()
|
||||
#t.print_nodes()
|
||||
assert t.root == b''
|
|
@ -1,70 +0,0 @@
|
|||
from new_bintrie import Trie, EphemDB, encode_bin, encode_bin_path, decode_bin_path
|
||||
from ethereum.utils import sha3, encode_hex
|
||||
from compress_witness import compress, expand
|
||||
import random
|
||||
import rlp
|
||||
|
||||
def shuffle_in_place(x):
|
||||
y = x[::]
|
||||
random.shuffle(y)
|
||||
return y
|
||||
|
||||
kvpairs = [(sha3(str(i))[12:], str(i).encode('utf-8') * 5) for i in range(2000)]
|
||||
|
||||
|
||||
for path in ([], [1,0,1], [0,0,1,0], [1,0,0,1,0], [1,0,0,1,0,0,1,0], [1,0] * 8):
|
||||
assert decode_bin_path(encode_bin_path(bytes(path))) == bytes(path)
|
||||
|
||||
r1 = None
|
||||
|
||||
for _ in range(3):
|
||||
t = Trie(EphemDB(), b'')
|
||||
total_long_length, total_short_length = 0, 0
|
||||
for i, (k, v) in enumerate(shuffle_in_place(kvpairs)):
|
||||
#print(t.to_dict())
|
||||
t.update(k, v)
|
||||
assert t.get(k) == v
|
||||
if not i % 250:
|
||||
t.to_dict()
|
||||
b = t.get_long_format_branch(k)
|
||||
c = compress(b)
|
||||
b2 = expand(c)
|
||||
total_long_length += len(rlp.encode(b))
|
||||
total_short_length += len(rlp.encode(c))
|
||||
assert sorted(b2) == sorted(b), "Witness compression fails"
|
||||
if i % 50 == 49:
|
||||
print("Avg length of long-format branch at %d nodes: %d" % (i-24, total_long_length // 50))
|
||||
print("Avg length of compressed witness: %d" % (total_short_length // 50))
|
||||
total_long_length = 0
|
||||
total_short_length = 0
|
||||
print('Added 1000 values, doing checks')
|
||||
assert r1 is None or t.root == r1
|
||||
r1 = t.root
|
||||
t.update(kvpairs[0][0], kvpairs[0][1])
|
||||
assert t.root == r1
|
||||
print(encode_hex(t.root))
|
||||
print('Checking that single-key witnesses are the same as branches')
|
||||
for k, v in sorted(kvpairs):
|
||||
assert t.get_prefix_witness(k) == t.get_long_format_branch(k)
|
||||
print('Checking byte-wide witnesses')
|
||||
for _ in range(16):
|
||||
byte = random.randrange(256)
|
||||
witness = t.get_prefix_witness(bytearray([byte]))
|
||||
c = compress(witness)
|
||||
w2 = expand(c)
|
||||
assert sorted(w2) == sorted(witness), "Witness compression fails"
|
||||
print('Witness compression for prefix witnesses: %d original %d compressed' %
|
||||
(len(rlp.encode(witness)), len(rlp.encode(c))))
|
||||
subtrie = Trie(EphemDB({sha3(x): x for x in witness}), t.root)
|
||||
print('auditing byte', byte, 'with', len([k for k,v in kvpairs if k[0] == byte]), 'keys')
|
||||
for k, v in sorted(kvpairs):
|
||||
if k[0] == byte:
|
||||
assert subtrie.get(k) == v
|
||||
assert subtrie.get(bytearray([byte] + [0] * 19)) == None
|
||||
assert subtrie.get(bytearray([byte] + [255] * 19)) == None
|
||||
for k, v in shuffle_in_place(kvpairs):
|
||||
t.update(k, b'')
|
||||
if not random.randrange(100):
|
||||
t.to_dict()
|
||||
#t.print_nodes()
|
||||
assert t.root == b''
|
Loading…
Reference in New Issue