324 lines
10 KiB
C++
324 lines
10 KiB
C++
// (C) 2023 Doug Hoyte. MIT license
|
|
|
|
#pragma once
|
|
|
|
#include <string.h>
|
|
|
|
#include <string>
|
|
#include <string_view>
|
|
#include <vector>
|
|
#include <deque>
|
|
#include <unordered_set>
|
|
#include <limits>
|
|
#include <algorithm>
|
|
#include <stdexcept>
|
|
#include <optional>
|
|
#include <bit>
|
|
|
|
#include "negentropy/encoding.h"
|
|
#include "negentropy/types.h"
|
|
#include "negentropy/storage/base.h"
|
|
|
|
|
|
namespace negentropy {
|
|
|
|
// Version byte written as the first byte of every message produced by this
// header. Incoming version bytes are accepted as well-formed when they lie in
// 0x60..0x6F; only this exact value is actually supported.
inline constexpr uint64_t PROTOCOL_VERSION = 0x61; // Version 1

// Largest uint64_t value. Used as the "end of everything" timestamp for the
// final range bound, and as the saturation value when decoding timestamps.
inline constexpr uint64_t MAX_U64 = std::numeric_limits<uint64_t>::max();

// Exception type thrown for every error raised by this header.
using err = std::runtime_error;
|
|
|
|
|
|
|
|
template<typename StorageImpl>
|
|
struct Negentropy {
|
|
StorageImpl &storage;
|
|
uint64_t frameSizeLimit;
|
|
|
|
bool isInitiator = false;
|
|
|
|
uint64_t lastTimestampIn = 0;
|
|
uint64_t lastTimestampOut = 0;
|
|
|
|
Negentropy(StorageImpl &storage, uint64_t frameSizeLimit = 0) : storage(storage), frameSizeLimit(frameSizeLimit) {
|
|
if (frameSizeLimit != 0 && frameSizeLimit < 4096) throw negentropy::err("frameSizeLimit too small");
|
|
}
|
|
|
|
std::string initiate() {
|
|
if (isInitiator) throw negentropy::err("already initiated");
|
|
isInitiator = true;
|
|
|
|
std::string output;
|
|
output.push_back(PROTOCOL_VERSION);
|
|
|
|
output += splitRange(0, storage.size(), Bound(MAX_U64));
|
|
|
|
return output;
|
|
}
|
|
|
|
void setInitiator() {
|
|
isInitiator = true;
|
|
}
|
|
|
|
std::string reconcile(std::string_view query) {
|
|
if (isInitiator) throw negentropy::err("initiator not asking for have/need IDs");
|
|
|
|
std::vector<std::string> haveIds, needIds;
|
|
return reconcileAux(query, haveIds, needIds);
|
|
}
|
|
|
|
std::optional<std::string> reconcile(std::string_view query, std::vector<std::string> &haveIds, std::vector<std::string> &needIds) {
|
|
if (!isInitiator) throw negentropy::err("non-initiator asking for have/need IDs");
|
|
|
|
auto output = reconcileAux(query, haveIds, needIds);
|
|
if (output.size() == 1) return std::nullopt;
|
|
return output;
|
|
}
|
|
|
|
private:
|
|
std::string reconcileAux(std::string_view query, std::vector<std::string> &haveIds, std::vector<std::string> &needIds) {
|
|
lastTimestampIn = lastTimestampOut = 0; // reset for each message
|
|
|
|
std::string fullOutput;
|
|
fullOutput.push_back(PROTOCOL_VERSION);
|
|
|
|
auto protocolVersion = getByte(query);
|
|
if (protocolVersion < 0x60 || protocolVersion > 0x6F) throw negentropy::err("invalid negentropy protocol version byte");
|
|
if (protocolVersion != PROTOCOL_VERSION) {
|
|
if (isInitiator) throw negentropy::err(std::string("unsupported negentropy protocol version requested") + std::to_string(protocolVersion - 0x60));
|
|
else return fullOutput;
|
|
}
|
|
|
|
uint64_t storageSize = storage.size();
|
|
Bound prevBound;
|
|
size_t prevIndex = 0;
|
|
bool skip = false;
|
|
|
|
while (query.size()) {
|
|
std::string o;
|
|
|
|
auto doSkip = [&]{
|
|
if (skip) {
|
|
skip = false;
|
|
o += encodeBound(prevBound);
|
|
o += encodeVarInt(uint64_t(Mode::Skip));
|
|
}
|
|
};
|
|
|
|
auto currBound = decodeBound(query);
|
|
auto mode = Mode(decodeVarInt(query));
|
|
|
|
auto lower = prevIndex;
|
|
auto upper = storage.findLowerBound(prevIndex, storageSize, currBound);
|
|
|
|
if (mode == Mode::Skip) {
|
|
skip = true;
|
|
} else if (mode == Mode::Fingerprint) {
|
|
auto theirFingerprint = getBytes(query, FINGERPRINT_SIZE);
|
|
auto ourFingerprint = storage.fingerprint(lower, upper);
|
|
|
|
if (theirFingerprint != ourFingerprint.sv()) {
|
|
doSkip();
|
|
o += splitRange(lower, upper, currBound);
|
|
} else {
|
|
skip = true;
|
|
}
|
|
} else if (mode == Mode::IdList) {
|
|
auto numIds = decodeVarInt(query);
|
|
|
|
std::unordered_set<std::string> theirElems;
|
|
for (uint64_t i = 0; i < numIds; i++) {
|
|
auto e = getBytes(query, ID_SIZE);
|
|
theirElems.insert(e);
|
|
}
|
|
|
|
storage.iterate(lower, upper, [&](const Item &item, size_t){
|
|
auto k = std::string(item.getId());
|
|
|
|
if (theirElems.find(k) == theirElems.end()) {
|
|
// ID exists on our side, but not their side
|
|
if (isInitiator) haveIds.emplace_back(k);
|
|
} else {
|
|
// ID exists on both sides
|
|
theirElems.erase(k);
|
|
}
|
|
|
|
return true;
|
|
});
|
|
|
|
if (isInitiator) {
|
|
skip = true;
|
|
|
|
for (const auto &k : theirElems) {
|
|
// ID exists on their side, but not our side
|
|
needIds.emplace_back(k);
|
|
}
|
|
} else {
|
|
doSkip();
|
|
|
|
std::string responseIds;
|
|
uint64_t numResponseIds = 0;
|
|
Bound endBound = currBound;
|
|
|
|
storage.iterate(lower, upper, [&](const Item &item, size_t index){
|
|
if (exceededFrameSizeLimit(fullOutput.size() + responseIds.size())) {
|
|
endBound = Bound(item);
|
|
upper = index; // shrink upper so that remaining range gets correct fingerprint
|
|
return false;
|
|
}
|
|
|
|
responseIds += item.getId();
|
|
numResponseIds++;
|
|
return true;
|
|
});
|
|
|
|
o += encodeBound(endBound);
|
|
o += encodeVarInt(uint64_t(Mode::IdList));
|
|
o += encodeVarInt(numResponseIds);
|
|
o += responseIds;
|
|
|
|
fullOutput += o;
|
|
o.clear();
|
|
}
|
|
} else {
|
|
throw negentropy::err("unexpected mode");
|
|
}
|
|
|
|
if (exceededFrameSizeLimit(fullOutput.size() + o.size())) {
|
|
// frameSizeLimit exceeded: Stop range processing and return a fingerprint for the remaining range
|
|
auto remainingFingerprint = storage.fingerprint(upper, storageSize);
|
|
|
|
fullOutput += encodeBound(Bound(MAX_U64));
|
|
fullOutput += encodeVarInt(uint64_t(Mode::Fingerprint));
|
|
fullOutput += remainingFingerprint.sv();
|
|
break;
|
|
} else {
|
|
fullOutput += o;
|
|
}
|
|
|
|
prevIndex = upper;
|
|
prevBound = currBound;
|
|
}
|
|
|
|
return fullOutput;
|
|
}
|
|
|
|
std::string splitRange(size_t lower, size_t upper, const Bound &upperBound) {
|
|
std::string o;
|
|
|
|
uint64_t numElems = upper - lower;
|
|
const uint64_t buckets = 16;
|
|
|
|
if (numElems < buckets * 2) {
|
|
o += encodeBound(upperBound);
|
|
o += encodeVarInt(uint64_t(Mode::IdList));
|
|
|
|
o += encodeVarInt(numElems);
|
|
storage.iterate(lower, upper, [&](const Item &item, size_t){
|
|
o += item.getId();
|
|
return true;
|
|
});
|
|
} else {
|
|
uint64_t itemsPerBucket = numElems / buckets;
|
|
uint64_t bucketsWithExtra = numElems % buckets;
|
|
auto curr = lower;
|
|
|
|
for (uint64_t i = 0; i < buckets; i++) {
|
|
auto bucketSize = itemsPerBucket + (i < bucketsWithExtra ? 1 : 0);
|
|
auto ourFingerprint = storage.fingerprint(curr, curr + bucketSize);
|
|
curr += bucketSize;
|
|
|
|
Bound nextBound;
|
|
|
|
if (curr == upper) {
|
|
nextBound = upperBound;
|
|
} else {
|
|
Item prevItem, currItem;
|
|
|
|
storage.iterate(curr - 1, curr + 1, [&](const Item &item, size_t index){
|
|
if (index == curr - 1) prevItem = item;
|
|
else currItem = item;
|
|
return true;
|
|
});
|
|
|
|
nextBound = getMinimalBound(prevItem, currItem);
|
|
}
|
|
|
|
o += encodeBound(nextBound);
|
|
o += encodeVarInt(uint64_t(Mode::Fingerprint));
|
|
o += ourFingerprint.sv();
|
|
}
|
|
}
|
|
|
|
return o;
|
|
}
|
|
|
|
bool exceededFrameSizeLimit(size_t n) {
|
|
return frameSizeLimit && n > frameSizeLimit - 200;
|
|
}
|
|
|
|
// Decoding
|
|
|
|
uint64_t decodeTimestampIn(std::string_view &encoded) {
|
|
uint64_t timestamp = decodeVarInt(encoded);
|
|
timestamp = timestamp == 0 ? MAX_U64 : timestamp - 1;
|
|
timestamp += lastTimestampIn;
|
|
if (timestamp < lastTimestampIn) timestamp = MAX_U64; // saturate
|
|
lastTimestampIn = timestamp;
|
|
return timestamp;
|
|
}
|
|
|
|
Bound decodeBound(std::string_view &encoded) {
|
|
auto timestamp = decodeTimestampIn(encoded);
|
|
auto len = decodeVarInt(encoded);
|
|
return Bound(timestamp, getBytes(encoded, len));
|
|
}
|
|
|
|
// Encoding
|
|
|
|
std::string encodeTimestampOut(uint64_t timestamp) {
|
|
if (timestamp == MAX_U64) {
|
|
lastTimestampOut = MAX_U64;
|
|
return encodeVarInt(0);
|
|
}
|
|
|
|
uint64_t temp = timestamp;
|
|
timestamp -= lastTimestampOut;
|
|
lastTimestampOut = temp;
|
|
return encodeVarInt(timestamp + 1);
|
|
};
|
|
|
|
std::string encodeBound(const Bound &bound) {
|
|
std::string output;
|
|
|
|
output += encodeTimestampOut(bound.item.timestamp);
|
|
output += encodeVarInt(bound.idLen);
|
|
output += bound.item.getId().substr(0, bound.idLen);
|
|
|
|
return output;
|
|
};
|
|
|
|
Bound getMinimalBound(const Item &prev, const Item &curr) {
|
|
if (curr.timestamp != prev.timestamp) {
|
|
return Bound(curr.timestamp);
|
|
} else {
|
|
uint64_t sharedPrefixBytes = 0;
|
|
auto currKey = curr.getId();
|
|
auto prevKey = prev.getId();
|
|
|
|
for (uint64_t i = 0; i < ID_SIZE; i++) {
|
|
if (currKey[i] != prevKey[i]) break;
|
|
sharedPrefixBytes++;
|
|
}
|
|
|
|
return Bound(curr.timestamp, currKey.substr(0, sharedPrefixBytes + 1));
|
|
}
|
|
}
|
|
};
|
|
|
|
|
|
}
|
|
|
|
|
|
template<typename T>
|
|
using Negentropy = negentropy::Negentropy<T>;
|