Add `Trie.getLast` (#646)
When using Tries, we often want the last matching entry for the given path, and to throw an error if one is not available. By adding this method to the API, we avoid a lot of unnecessary repetition in the code base. Test plan: Unit tests pass. As this touches the untested WeightConfig, I've also manually tested the weight config behavior.
This commit is contained in:
parent
00e2a67477
commit
f1c5d3756d
|
@ -53,43 +53,19 @@ export class StaticAdapterSet {
|
|||
}
|
||||
|
||||
adapterMatchingNode(x: NodeAddressT): StaticPluginAdapter {
|
||||
const adapters = this._adapterNodeTrie.get(x);
|
||||
if (adapters.length === 0) {
|
||||
throw new Error(
|
||||
"Invariant violation: Fallback adapter matches all nodes"
|
||||
);
|
||||
}
|
||||
return adapters[adapters.length - 1];
|
||||
return this._adapterNodeTrie.getLast(x);
|
||||
}
|
||||
|
||||
adapterMatchingEdge(x: EdgeAddressT): StaticPluginAdapter {
|
||||
const adapters = this._adapterEdgeTrie.get(x);
|
||||
if (adapters.length === 0) {
|
||||
throw new Error(
|
||||
"Invariant violation: Fallback adapter matches all edges"
|
||||
);
|
||||
}
|
||||
return adapters[adapters.length - 1];
|
||||
return this._adapterEdgeTrie.getLast(x);
|
||||
}
|
||||
|
||||
typeMatchingNode(x: NodeAddressT): NodeType {
|
||||
const types = this._typeNodeTrie.get(x);
|
||||
if (types.length === 0) {
|
||||
throw new Error(
|
||||
"Invariant violation: Fallback adapter's type matches all nodes"
|
||||
);
|
||||
}
|
||||
return types[types.length - 1];
|
||||
return this._typeNodeTrie.getLast(x);
|
||||
}
|
||||
|
||||
typeMatchingEdge(x: EdgeAddressT): EdgeType {
|
||||
const types = this._typeEdgeTrie.get(x);
|
||||
if (types.length === 0) {
|
||||
throw new Error(
|
||||
"Invariant violation: Fallback adapter's type matches all edges"
|
||||
);
|
||||
}
|
||||
return types[types.length - 1];
|
||||
return this._typeEdgeTrie.getLast(x);
|
||||
}
|
||||
|
||||
load(repo: Repo): Promise<DynamicAdapterSet> {
|
||||
|
@ -120,23 +96,11 @@ export class DynamicAdapterSet {
|
|||
}
|
||||
|
||||
adapterMatchingNode(x: NodeAddressT): DynamicPluginAdapter {
|
||||
const adapters = this._adapterNodeTrie.get(x);
|
||||
if (adapters.length === 0) {
|
||||
throw new Error(
|
||||
"Invariant violation: Fallback adapter matches all nodes"
|
||||
);
|
||||
}
|
||||
return adapters[adapters.length - 1];
|
||||
return this._adapterNodeTrie.getLast(x);
|
||||
}
|
||||
|
||||
adapterMatchingEdge(x: EdgeAddressT): DynamicPluginAdapter {
|
||||
const adapters = this._adapterEdgeTrie.get(x);
|
||||
if (adapters.length === 0) {
|
||||
throw new Error(
|
||||
"Invariant violation: Fallback adapter matches all edges"
|
||||
);
|
||||
}
|
||||
return adapters[adapters.length - 1];
|
||||
return this._adapterEdgeTrie.getLast(x);
|
||||
}
|
||||
|
||||
adapters(): $ReadOnlyArray<DynamicPluginAdapter> {
|
||||
|
|
|
@ -19,10 +19,7 @@ export function byEdgeType(
|
|||
trie.add(weightedPrefix.prefix, weightedPrefix);
|
||||
}
|
||||
return function evaluator(edge: Edge) {
|
||||
const matchingPrefixes = trie.get(edge.address);
|
||||
const {weight, directionality} = matchingPrefixes[
|
||||
matchingPrefixes.length - 1
|
||||
];
|
||||
const {weight, directionality} = trie.getLast(edge.address);
|
||||
return {
|
||||
toWeight: directionality * weight,
|
||||
froWeight: (1 - directionality) * weight,
|
||||
|
@ -42,10 +39,8 @@ export function byNodeType(
|
|||
trie.add(weightedPrefix.prefix, weightedPrefix);
|
||||
}
|
||||
return function evaluator(edge: Edge) {
|
||||
const srcPrefixes = trie.get(edge.src);
|
||||
const srcDatum = srcPrefixes[srcPrefixes.length - 1];
|
||||
const dstPrefixes = trie.get(edge.dst);
|
||||
const dstDatum = dstPrefixes[dstPrefixes.length - 1];
|
||||
const srcDatum = trie.getLast(edge.src);
|
||||
const dstDatum = trie.getLast(edge.dst);
|
||||
|
||||
const baseResult = base(edge);
|
||||
return {
|
||||
|
|
|
@ -1,7 +1,6 @@
|
|||
// @flow
|
||||
|
||||
import sortBy from "lodash.sortby";
|
||||
import {NodeAddress, edgeToString} from "../../../core/graph";
|
||||
import {NodeTrie, EdgeTrie} from "../../../core/trie";
|
||||
import type {NodeType, EdgeType} from "../../adapters/pluginAdapter";
|
||||
import type {ScoredConnection} from "../../../core/attribution/pagerankNodeDecomposition";
|
||||
|
@ -52,13 +51,7 @@ export function aggregateByNodeType(
|
|||
}
|
||||
const nodeTypeToConnections = new Map();
|
||||
for (const x of xs) {
|
||||
const types = typeTrie.get(x.source);
|
||||
if (types.length === 0) {
|
||||
throw new Error(
|
||||
`No matching NodeType for ${NodeAddress.toString(x.source)}`
|
||||
);
|
||||
}
|
||||
const type = types[types.length - 1];
|
||||
const type = typeTrie.getLast(x.source);
|
||||
const connections = nodeTypeToConnections.get(type) || [];
|
||||
if (connections.length === 0) {
|
||||
nodeTypeToConnections.set(type, connections);
|
||||
|
@ -114,11 +107,7 @@ export function aggregateByConnectionType(
|
|||
throw new Error((x.connection.adjacency.type: empty));
|
||||
}
|
||||
const edge = x.connection.adjacency.edge;
|
||||
const types = typeTrie.get(edge.address);
|
||||
if (types.length === 0) {
|
||||
throw new Error(`No matching EdgeType for edge ${edgeToString(edge)}`);
|
||||
}
|
||||
const type = types[types.length - 1];
|
||||
const type = typeTrie.getLast(edge.address);
|
||||
const connections = relevantMap.get(type) || [];
|
||||
if (connections.length === 0) {
|
||||
relevantMap.set(type, connections);
|
||||
|
|
|
@ -207,7 +207,7 @@ describe("app/credExplorer/aggregate", () => {
|
|||
it("errors if any connection has no matching type", () => {
|
||||
const {scoredConnectionsArray} = example();
|
||||
const shouldFail = () => aggregateByNodeType(scoredConnectionsArray, []);
|
||||
expect(shouldFail).toThrowError("No matching NodeType");
|
||||
expect(shouldFail).toThrowError("no matching entry");
|
||||
});
|
||||
it("sorts the aggregations by total score", () => {
|
||||
let lastSeenScore = Infinity;
|
||||
|
@ -324,7 +324,7 @@ describe("app/credExplorer/aggregate", () => {
|
|||
const {scoredConnectionsArray, nodeTypesArray} = example();
|
||||
const shouldFail = () =>
|
||||
aggregateByConnectionType(scoredConnectionsArray, nodeTypesArray, []);
|
||||
expect(shouldFail).toThrowError("No matching EdgeType");
|
||||
expect(shouldFail).toThrowError("no matching entry");
|
||||
});
|
||||
it("sorts the aggregations by total score", () => {
|
||||
let lastSeenScore = Infinity;
|
||||
|
|
|
@ -84,6 +84,18 @@ class BaseTrie<K, V> {
|
|||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get the last stored value `v` in the path to key `k`.
|
||||
* Throws an error if no value is available.
|
||||
*/
|
||||
getLast(k: K): V {
|
||||
const path = this.get(k);
|
||||
if (path.length === 0) {
|
||||
throw new Error("Tried to getLast, but no matching entry existed");
|
||||
}
|
||||
return path[path.length - 1];
|
||||
}
|
||||
}
|
||||
|
||||
export class NodeTrie<V> extends BaseTrie<NodeAddressT, V> {
|
||||
|
|
|
@ -82,6 +82,18 @@ describe("core/trie", () => {
|
|||
expect(x.get(fooBarZod)).toEqual([0, 2, 3]);
|
||||
});
|
||||
|
||||
it("getLast gets the last available value", () => {
|
||||
const x = new NodeTrie()
|
||||
.add(foo, 2)
|
||||
.add(fooBar, 3)
|
||||
.add(empty, 0);
|
||||
expect(x.getLast(fooBarZod)).toEqual(3);
|
||||
});
|
||||
|
||||
it("getLast throws an error if no value is available", () => {
|
||||
expect(() => new NodeTrie().getLast(foo)).toThrowError("no matching entry");
|
||||
});
|
||||
|
||||
it("overwriting a value is illegal", () => {
|
||||
expect(() =>
|
||||
new NodeTrie()
|
||||
|
|
Loading…
Reference in New Issue