c-kzg-4844/bindings/python/multicombs.py
2022-09-19 19:56:43 +01:00

132 lines
6.2 KiB
Python

import random, time, sys, math
def multisubset(numbers, subsets, adder=lambda x,y: x+y, zero=0):
    """Compute the sum of each subset of ``numbers`` listed in ``subsets``.

    Each entry of ``subsets`` is an iterable of indices into ``numbers``.
    Instead of summing every subset independently, each round finds the index
    pairs that co-occur in the most subsets, adds every chosen pair once,
    appends that partial sum to the working list and substitutes its new index
    into every subset that contained both members. Shared work is therefore
    paid for only once.

    Args:
        numbers: Sliceable sequence of values; it is copied, not mutated.
        subsets: Iterable of iterables of indices into ``numbers``.
        adder: Binary operation used to combine two values.
        zero: Result reported for an empty subset, and the starting
            accumulator for the final direct summation.

    Returns:
        A list whose k-th element is the ``adder``-sum of the k-th subset.
    """
    numbers = numbers[::]
    subsets = {i: set(subset) for i, subset in enumerate(subsets)}
    output = [zero] * len(subsets)
    # Every productive round merges at least one pair, which strictly shrinks
    # the total size of the remaining subsets, so the loop always terminates.
    while True:
        # Count how many subsets contain each (x, y) pair with x < y.
        # Walking a sorted copy visits every unordered pair exactly once.
        pair_count = {}
        for subset in subsets.values():
            ordered = sorted(subset)
            for pos, x in enumerate(ordered):
                for y in ordered[pos + 1:]:
                    pair_count[(x, y)] = pair_count.get((x, y), 0) + 1
        # Keep only the most frequent pairs. The cutoff trades group operation
        # count against bookkeeping overhead. It is guarded because
        # math.log(0) raises ValueError when `numbers` is empty.
        count = len(numbers)
        cutoff = count * int(math.log(count)) if count else 0
        pairs_by_count = sorted(pair_count, key=pair_count.get, reverse=True)[:cutoff]
        # Exit condition: no pair left to merge (all subsets have size <= 1,
        # or the cutoff is zero). Sum whatever remains directly.
        if not pairs_by_count:
            for key, subset in subsets.items():
                for index in subset:
                    output[key] = adder(output[key], numbers[index])
            return output
        # Merge each selected pair, skipping pairs that share an index with a
        # pair already merged this round (its counts are now stale).
        used = set()
        for maxx, maxy in pairs_by_count:
            if maxx in used or maxy in used:
                continue
            used.add(maxx)
            used.add(maxy)
            numbers.append(adder(numbers[maxx], numbers[maxy]))
            merged_index = len(numbers) - 1
            # Snapshot the items because finished subsets are deleted below.
            for key, subset in list(subsets.items()):
                if maxx in subset and maxy in subset:
                    subset.remove(maxx)
                    subset.remove(maxy)
                    if not subset:
                        # The merged value is the complete sum of this subset.
                        output[key] = numbers[merged_index]
                        del subsets[key]
                    else:
                        subset.add(merged_index)
def multisubset2(numbers, subsets, adder=lambda x,y: x+y, zero=0):
    """Compute subset sums through per-partition power-set lookup tables.

    Alternative to ``multisubset``: less optimal in group operations, but
    simpler and with much lower bookkeeping overhead.

    Args:
        numbers: Sliceable sequence of values; it is copied, not mutated.
        subsets: Sized iterable of containers of indices into ``numbers``.
        adder: Binary operation used to combine two values.
        zero: Padding value and starting accumulator for every sum.

    Returns:
        A list with one ``adder``-sum per entry of ``subsets``, in order.
    """
    # Partition width is derived from the number of subsets.
    partition_size = 1 + int(math.log(len(subsets) + 1))

    # Pad a private copy with `zero` so it splits into whole partitions.
    padded = numbers[:]
    for _ in range(-len(padded) % partition_size):
        padded.append(zero)

    # One table per partition holding the sum of every sub-selection, e.g.
    # a, b, c -> [0, a, b, a+b, c, a+c, b+c, a+b+c]. Bit j of a table index
    # says whether the partition's j-th value is part of that entry.
    power_sets = []
    for start in range(0, len(padded), partition_size):
        table = [zero]
        for value in padded[start:start + partition_size]:
            table.extend([adder(partial, value) for partial in table])
        power_sets.append(table)

    # For each subset, a single lookup per partition yields the sum of all of
    # that partition's members that belong to the subset.
    subset_sums = []
    for subset in subsets:
        total = zero
        for part_no, table in enumerate(power_sets):
            base = part_no * partition_size
            lookup = sum(
                1 << bit
                for bit in range(partition_size)
                if (base + bit) in subset
            )
            total = adder(total, table[lookup])
        subset_sums.append(total)
    return subset_sums
def lincomb(numbers, factors, adder=lambda x,y: x+y, zero=0):
    """Evaluate ``numbers[0]*factors[0] + numbers[1]*factors[1] + ...``.

    The linear combination is reduced to a multi-subset problem (solved by
    ``multisubset2``) so that only ``adder`` is needed, never multiplication.

    Args:
        numbers: Sequence of values to combine.
        factors: Sequence of non-negative integer multipliers, indexed in
            parallel with ``numbers``.
        adder: Binary operation used to combine two values.
        zero: Starting accumulator; it is doubled and added to like any other
            value, so it should behave as the identity for ``adder``.

    Returns:
        The linear combination, built purely from ``adder`` calls.

    Raises:
        ValueError: If any factor is negative. Bit tests on a negative int
            see its two's-complement form and would yield a wrong result.
    """
    if any(f < 0 for f in factors):
        raise ValueError("lincomb requires non-negative factors")
    # Maximum bit length of a factor; this is how many subsets are needed.
    maxbitlen = max((len(bin(f)) - 2 for f in factors), default=0)
    # The j-th subset holds the indices whose factor has bit j set. Bit
    # positions 0..maxbitlen-1 cover every set bit, so no extra (always
    # empty) subset is generated for position `maxbitlen`.
    subsets = [
        {i for i in range(len(numbers)) if factors[i] & (1 << j)}
        for j in range(maxbitlen)
    ]
    subset_sums = multisubset2(numbers, subsets, adder=adder, zero=zero)
    # Example: a value V with factor 6 (bits 0, 1, 1 from least significant)
    # is absent from subset 0 and present in subsets 1 and 2. Weighting the
    # subset sums by 1, 2 and 4 therefore counts V 0 + 2 + 4 = 6 times, and
    # the same holds for every value:
    #     subset_0_sum + 2 * subset_1_sum + 4 * subset_2_sum
    # This is evaluated Horner-style as
    #     ((subset_2_sum * 2) + subset_1_sum) * 2 + subset_0_sum
    # which costs one doubling and one addition per bit position.
    o = zero
    for subset_sum in reversed(subset_sums):
        o = adder(adder(o, o), subset_sum)
    return o
# Tests go here
def make_mock_adder():
    """Build an addition function that tallies its non-trivial calls.

    Returns:
        A pair ``(adder, counter)``. ``adder(x, y)`` returns ``x + y``.
        ``counter`` is a one-element list whose single entry is incremented
        each time ``adder`` is called with two truthy operands, so additions
        involving a falsy value such as 0 are not counted.
    """
    tally = [0]

    def adder(x, y):
        both_truthy = bool(x) and bool(y)
        if both_truthy:
            tally[0] += 1
        return x + y

    return adder, tally
def test_multisubset(numcount, setcount):
    """Check ``multisubset`` against direct summation on random input.

    Draws ``numcount`` random integers below 10**20 and ``setcount`` random
    subsets (each index included by a coin flip), then asserts that every
    returned sum equals the plainly computed sum of that subset.
    """
    numbers = [random.randrange(10**20) for _ in range(numcount)]
    subsets = [
        {idx for idx in range(numcount) if random.randrange(2)}
        for _ in range(setcount)
    ]
    adder, counter = make_mock_adder()
    results = multisubset(numbers, subsets, adder=adder)
    for got, members in zip(results, subsets):
        expected = sum([numbers[x] for x in members])
        assert got == expected
def test_lincomb(numcount, bitlength=256):
    """Check ``lincomb`` on random input and print operation-count figures.

    Draws ``numcount`` random values below 10**20 and the same number of
    random factors below ``2**bitlength``, asserts that ``lincomb`` matches
    the directly computed sum of products, then prints a naive count
    (``bitlength * numcount`` plus the number of set bits across all factors),
    an optimized count (``bitlength * 2`` plus the additions tallied by the
    mock adder) and their ratio.
    """
    numbers = [random.randrange(10**20) for _ in range(numcount)]
    factors = [random.randrange(2**bitlength) for _ in range(numcount)]
    adder, counter = make_mock_adder()
    result = lincomb(numbers, factors, adder=adder)
    expected = sum([n * f for n, f in zip(numbers, factors)])
    assert result == expected
    total_ones = sum(bin(f).count('1') for f in factors)
    naive_ops = bitlength * numcount + total_ones
    optimized_ops = bitlength * 2 + counter[0]
    print("Naive operation count: %d" % naive_ops)
    print("Optimized operation count: %d" % optimized_ops)
    print("Optimization factor: %.2f" % (naive_ops / optimized_ops))
if __name__ == '__main__':
    # Script entry point: the optional first command-line argument sets how
    # many numbers to combine; without it, 80 are used.
    if len(sys.argv) > 1:
        numcount = int(sys.argv[1])
    else:
        numcount = 80
    test_lincomb(numcount)