Add `Trie.getLast` (#646)
When using Tries, we often want the last matching entry for the given path, and to throw an error if one is not available. By adding this method to the API, we avoid a lot of unnecessary repetition in the code base. Test plan: Unit tests pass. As this touches the untested WeightConfig, I've also manually tested the weight config behavior.
This commit is contained in:
parent
00e2a67477
commit
f1c5d3756d
|
@ -53,43 +53,19 @@ export class StaticAdapterSet {
|
||||||
}
|
}
|
||||||
|
|
||||||
adapterMatchingNode(x: NodeAddressT): StaticPluginAdapter {
|
adapterMatchingNode(x: NodeAddressT): StaticPluginAdapter {
|
||||||
const adapters = this._adapterNodeTrie.get(x);
|
return this._adapterNodeTrie.getLast(x);
|
||||||
if (adapters.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
"Invariant violation: Fallback adapter matches all nodes"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return adapters[adapters.length - 1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
adapterMatchingEdge(x: EdgeAddressT): StaticPluginAdapter {
|
adapterMatchingEdge(x: EdgeAddressT): StaticPluginAdapter {
|
||||||
const adapters = this._adapterEdgeTrie.get(x);
|
return this._adapterEdgeTrie.getLast(x);
|
||||||
if (adapters.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
"Invariant violation: Fallback adapter matches all edges"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return adapters[adapters.length - 1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
typeMatchingNode(x: NodeAddressT): NodeType {
|
typeMatchingNode(x: NodeAddressT): NodeType {
|
||||||
const types = this._typeNodeTrie.get(x);
|
return this._typeNodeTrie.getLast(x);
|
||||||
if (types.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
"Invariant violation: Fallback adapter's type matches all nodes"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return types[types.length - 1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
typeMatchingEdge(x: EdgeAddressT): EdgeType {
|
typeMatchingEdge(x: EdgeAddressT): EdgeType {
|
||||||
const types = this._typeEdgeTrie.get(x);
|
return this._typeEdgeTrie.getLast(x);
|
||||||
if (types.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
"Invariant violation: Fallback adapter's type matches all edges"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return types[types.length - 1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
load(repo: Repo): Promise<DynamicAdapterSet> {
|
load(repo: Repo): Promise<DynamicAdapterSet> {
|
||||||
|
@ -120,23 +96,11 @@ export class DynamicAdapterSet {
|
||||||
}
|
}
|
||||||
|
|
||||||
adapterMatchingNode(x: NodeAddressT): DynamicPluginAdapter {
|
adapterMatchingNode(x: NodeAddressT): DynamicPluginAdapter {
|
||||||
const adapters = this._adapterNodeTrie.get(x);
|
return this._adapterNodeTrie.getLast(x);
|
||||||
if (adapters.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
"Invariant violation: Fallback adapter matches all nodes"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return adapters[adapters.length - 1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
adapterMatchingEdge(x: EdgeAddressT): DynamicPluginAdapter {
|
adapterMatchingEdge(x: EdgeAddressT): DynamicPluginAdapter {
|
||||||
const adapters = this._adapterEdgeTrie.get(x);
|
return this._adapterEdgeTrie.getLast(x);
|
||||||
if (adapters.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
"Invariant violation: Fallback adapter matches all edges"
|
|
||||||
);
|
|
||||||
}
|
|
||||||
return adapters[adapters.length - 1];
|
|
||||||
}
|
}
|
||||||
|
|
||||||
adapters(): $ReadOnlyArray<DynamicPluginAdapter> {
|
adapters(): $ReadOnlyArray<DynamicPluginAdapter> {
|
||||||
|
|
|
@ -19,10 +19,7 @@ export function byEdgeType(
|
||||||
trie.add(weightedPrefix.prefix, weightedPrefix);
|
trie.add(weightedPrefix.prefix, weightedPrefix);
|
||||||
}
|
}
|
||||||
return function evaluator(edge: Edge) {
|
return function evaluator(edge: Edge) {
|
||||||
const matchingPrefixes = trie.get(edge.address);
|
const {weight, directionality} = trie.getLast(edge.address);
|
||||||
const {weight, directionality} = matchingPrefixes[
|
|
||||||
matchingPrefixes.length - 1
|
|
||||||
];
|
|
||||||
return {
|
return {
|
||||||
toWeight: directionality * weight,
|
toWeight: directionality * weight,
|
||||||
froWeight: (1 - directionality) * weight,
|
froWeight: (1 - directionality) * weight,
|
||||||
|
@ -42,10 +39,8 @@ export function byNodeType(
|
||||||
trie.add(weightedPrefix.prefix, weightedPrefix);
|
trie.add(weightedPrefix.prefix, weightedPrefix);
|
||||||
}
|
}
|
||||||
return function evaluator(edge: Edge) {
|
return function evaluator(edge: Edge) {
|
||||||
const srcPrefixes = trie.get(edge.src);
|
const srcDatum = trie.getLast(edge.src);
|
||||||
const srcDatum = srcPrefixes[srcPrefixes.length - 1];
|
const dstDatum = trie.getLast(edge.dst);
|
||||||
const dstPrefixes = trie.get(edge.dst);
|
|
||||||
const dstDatum = dstPrefixes[dstPrefixes.length - 1];
|
|
||||||
|
|
||||||
const baseResult = base(edge);
|
const baseResult = base(edge);
|
||||||
return {
|
return {
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
// @flow
|
// @flow
|
||||||
|
|
||||||
import sortBy from "lodash.sortby";
|
import sortBy from "lodash.sortby";
|
||||||
import {NodeAddress, edgeToString} from "../../../core/graph";
|
|
||||||
import {NodeTrie, EdgeTrie} from "../../../core/trie";
|
import {NodeTrie, EdgeTrie} from "../../../core/trie";
|
||||||
import type {NodeType, EdgeType} from "../../adapters/pluginAdapter";
|
import type {NodeType, EdgeType} from "../../adapters/pluginAdapter";
|
||||||
import type {ScoredConnection} from "../../../core/attribution/pagerankNodeDecomposition";
|
import type {ScoredConnection} from "../../../core/attribution/pagerankNodeDecomposition";
|
||||||
|
@ -52,13 +51,7 @@ export function aggregateByNodeType(
|
||||||
}
|
}
|
||||||
const nodeTypeToConnections = new Map();
|
const nodeTypeToConnections = new Map();
|
||||||
for (const x of xs) {
|
for (const x of xs) {
|
||||||
const types = typeTrie.get(x.source);
|
const type = typeTrie.getLast(x.source);
|
||||||
if (types.length === 0) {
|
|
||||||
throw new Error(
|
|
||||||
`No matching NodeType for ${NodeAddress.toString(x.source)}`
|
|
||||||
);
|
|
||||||
}
|
|
||||||
const type = types[types.length - 1];
|
|
||||||
const connections = nodeTypeToConnections.get(type) || [];
|
const connections = nodeTypeToConnections.get(type) || [];
|
||||||
if (connections.length === 0) {
|
if (connections.length === 0) {
|
||||||
nodeTypeToConnections.set(type, connections);
|
nodeTypeToConnections.set(type, connections);
|
||||||
|
@ -114,11 +107,7 @@ export function aggregateByConnectionType(
|
||||||
throw new Error((x.connection.adjacency.type: empty));
|
throw new Error((x.connection.adjacency.type: empty));
|
||||||
}
|
}
|
||||||
const edge = x.connection.adjacency.edge;
|
const edge = x.connection.adjacency.edge;
|
||||||
const types = typeTrie.get(edge.address);
|
const type = typeTrie.getLast(edge.address);
|
||||||
if (types.length === 0) {
|
|
||||||
throw new Error(`No matching EdgeType for edge ${edgeToString(edge)}`);
|
|
||||||
}
|
|
||||||
const type = types[types.length - 1];
|
|
||||||
const connections = relevantMap.get(type) || [];
|
const connections = relevantMap.get(type) || [];
|
||||||
if (connections.length === 0) {
|
if (connections.length === 0) {
|
||||||
relevantMap.set(type, connections);
|
relevantMap.set(type, connections);
|
||||||
|
|
|
@ -207,7 +207,7 @@ describe("app/credExplorer/aggregate", () => {
|
||||||
it("errors if any connection has no matching type", () => {
|
it("errors if any connection has no matching type", () => {
|
||||||
const {scoredConnectionsArray} = example();
|
const {scoredConnectionsArray} = example();
|
||||||
const shouldFail = () => aggregateByNodeType(scoredConnectionsArray, []);
|
const shouldFail = () => aggregateByNodeType(scoredConnectionsArray, []);
|
||||||
expect(shouldFail).toThrowError("No matching NodeType");
|
expect(shouldFail).toThrowError("no matching entry");
|
||||||
});
|
});
|
||||||
it("sorts the aggregations by total score", () => {
|
it("sorts the aggregations by total score", () => {
|
||||||
let lastSeenScore = Infinity;
|
let lastSeenScore = Infinity;
|
||||||
|
@ -324,7 +324,7 @@ describe("app/credExplorer/aggregate", () => {
|
||||||
const {scoredConnectionsArray, nodeTypesArray} = example();
|
const {scoredConnectionsArray, nodeTypesArray} = example();
|
||||||
const shouldFail = () =>
|
const shouldFail = () =>
|
||||||
aggregateByConnectionType(scoredConnectionsArray, nodeTypesArray, []);
|
aggregateByConnectionType(scoredConnectionsArray, nodeTypesArray, []);
|
||||||
expect(shouldFail).toThrowError("No matching EdgeType");
|
expect(shouldFail).toThrowError("no matching entry");
|
||||||
});
|
});
|
||||||
it("sorts the aggregations by total score", () => {
|
it("sorts the aggregations by total score", () => {
|
||||||
let lastSeenScore = Infinity;
|
let lastSeenScore = Infinity;
|
||||||
|
|
|
@ -84,6 +84,18 @@ class BaseTrie<K, V> {
|
||||||
}
|
}
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Get the last stored value `v` in the path to key `k`.
|
||||||
|
* Throws an error if no value is available.
|
||||||
|
*/
|
||||||
|
getLast(k: K): V {
|
||||||
|
const path = this.get(k);
|
||||||
|
if (path.length === 0) {
|
||||||
|
throw new Error("Tried to getLast, but no matching entry existed");
|
||||||
|
}
|
||||||
|
return path[path.length - 1];
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
export class NodeTrie<V> extends BaseTrie<NodeAddressT, V> {
|
export class NodeTrie<V> extends BaseTrie<NodeAddressT, V> {
|
||||||
|
|
|
@ -82,6 +82,18 @@ describe("core/trie", () => {
|
||||||
expect(x.get(fooBarZod)).toEqual([0, 2, 3]);
|
expect(x.get(fooBarZod)).toEqual([0, 2, 3]);
|
||||||
});
|
});
|
||||||
|
|
||||||
|
it("getLast gets the last available value", () => {
|
||||||
|
const x = new NodeTrie()
|
||||||
|
.add(foo, 2)
|
||||||
|
.add(fooBar, 3)
|
||||||
|
.add(empty, 0);
|
||||||
|
expect(x.getLast(fooBarZod)).toEqual(3);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("getLast throws an error if no value is available", () => {
|
||||||
|
expect(() => new NodeTrie().getLast(foo)).toThrowError("no matching entry");
|
||||||
|
});
|
||||||
|
|
||||||
it("overwriting a value is illegal", () => {
|
it("overwriting a value is illegal", () => {
|
||||||
expect(() =>
|
expect(() =>
|
||||||
new NodeTrie()
|
new NodeTrie()
|
||||||
|
|
Loading…
Reference in New Issue