diff --git a/src/app/credExplorer/__snapshots__/basicPagerank.test.js.snap b/src/app/credExplorer/__snapshots__/basicPagerank.test.js.snap new file mode 100644 index 0000000..7b51a23 --- /dev/null +++ b/src/app/credExplorer/__snapshots__/basicPagerank.test.js.snap @@ -0,0 +1,13 @@ +// Jest Snapshot v1, https://goo.gl/fbAQLP + +exports[`graphToMarkovChain is correct for a trivial one-node chain 1`] = ` +Object { + "{\\"id\\":\\"who are you blah blah\\",\\"pluginName\\":\\"the magnificent foo plugin\\",\\"type\\":\\"irrelevant!\\"}": Object { + "inNeighbors": Object { + "{\\"id\\":\\"who are you blah blah\\",\\"pluginName\\":\\"the magnificent foo plugin\\",\\"type\\":\\"irrelevant!\\"}": Object { + "weight": 1, + }, + }, + }, +} +`; diff --git a/src/app/credExplorer/basicPagerank.js b/src/app/credExplorer/basicPagerank.js index 76119fa..883417d 100644 --- a/src/app/credExplorer/basicPagerank.js +++ b/src/app/credExplorer/basicPagerank.js @@ -1,84 +1,112 @@ // @flow -import * as tf from "@tensorflow/tfjs-core"; - import type {Address} from "../../core/address"; +import type {Edge} from "../../core/graph"; import {AddressMap} from "../../core/address"; import {Graph} from "../../core/graph"; -export type PagerankResult = AddressMap<{| +export type Distribution = AddressMap<{| +address: Address, +probability: number, |}>; +export type PagerankResult = Distribution; + +type MarkovChain = AddressMap<{| + +address: Address, + +inNeighbors: AddressMap<{| + +address: Address, + +weight: number, + |}>, +|}>; export default function basicPagerank(graph: Graph): PagerankResult { - return tf.tidy(() => { - const {nodes, markovChain} = graphToMarkovChain(graph); - const stationaryDistribution = findStationaryDistribution(markovChain); - const stationaryDistributionRaw = stationaryDistribution.dataSync(); - const result = new AddressMap(); - nodes.forEach((node, i) => { - result.add({ - address: node.address, - probability: stationaryDistributionRaw[i], - }); + return findStationaryDistribution(graphToMarkovChain(graph)); +} + +function edgeWeight( + _unused_edge: Edge +): {|+toWeight: number, +froWeight: number|} { + return {toWeight: 1, froWeight: 1}; +} + +export function graphToMarkovChain(graph: Graph): MarkovChain { + const result = new AddressMap(); + const unnormalizedTotalOutWeights = new AddressMap(); + + function initializeNode(address) { + if (result.get(address) != null) { + return; + } + const inNeighbors = new AddressMap(); + result.add({address, inNeighbors}); + const selfLoopEdgeWeight = 1e-3; + unnormalizedTotalOutWeights.add({address, weight: selfLoopEdgeWeight}); + graph.neighborhood(address).forEach(({neighbor}) => { + inNeighbors.add({address: neighbor, weight: 0}); }); - return result; - }); -} - -function graphToMarkovChain(graph: Graph) { - const nodes = graph.nodes(); // for canonical ordering - const addressToIndex = new AddressMap(); - nodes.forEach(({address}, index) => { - addressToIndex.add({address, index}); - }); - const buffer = tf.buffer([nodes.length, nodes.length]); - graph.edges().forEach(({src, dst, address}) => { - if (graph.node(src) == null) { - console.warn("Edge has dangling src:", address, src); - return; - } - if (graph.node(dst) == null) { - console.warn("Edge has dangling dst:", address, dst); - return; - } - const u = addressToIndex.get(src).index; - const v = addressToIndex.get(dst).index; - buffer.set(1, u, v); - buffer.set(1, v, u); - }); - return { - nodes, - markovChain: tf.tidy(() => { - const dampingFactor = 1e-4; - const raw = buffer.toTensor(); - const nonsingular = raw.add(tf.scalar(1e-9)); - const normalized = nonsingular.div(nonsingular.sum(1)); - const damped = tf.add( - normalized.mul(tf.scalar(1 - dampingFactor)), - tf.onesLike(normalized).mul(tf.scalar(dampingFactor / nodes.length)) - ); - return damped; - }), - }; -} - -function findStationaryDistribution(markovChain: $Call) { - const n = markovChain.shape[0]; - if (markovChain.shape.length !== 2 || markovChain.shape[1] !== n) { - throw new Error(`Expected square matrix; got: ${markovChain.shape}`); + inNeighbors.add({address: address, weight: selfLoopEdgeWeight}); } - let r0 = tf.tidy(() => tf.ones([n, 1]).div(tf.scalar(n))); + + graph.nodes().forEach(({address}) => { + initializeNode(address); + }); + graph.edges().forEach((edge) => { + const {src, dst} = edge; + initializeNode(src); + initializeNode(dst); + const {toWeight, froWeight} = edgeWeight(edge); + result.get(dst).inNeighbors.get(src).weight += toWeight; + result.get(src).inNeighbors.get(dst).weight += froWeight; + unnormalizedTotalOutWeights.get(src).weight += toWeight; + unnormalizedTotalOutWeights.get(dst).weight += froWeight; + }); + + // Normalize. + result.getAll().forEach(({inNeighbors}) => { + inNeighbors.getAll().forEach((entry) => { + entry.weight /= unnormalizedTotalOutWeights.get(entry.address).weight; + }); + }); + return result; +} + +function markovChainAction(mc: MarkovChain, pi: Distribution): Distribution { + const result = new AddressMap(); + mc.getAll().forEach(({address, inNeighbors}) => { + let probability = 0; + inNeighbors.getAll().forEach(({address: neighbor, weight}) => { + probability += pi.get(neighbor).probability * weight; + }); + result.add({address, probability}); + }); + return result; +} + +function uniformDistribution(addresses: $ReadOnlyArray
) { + const result = new AddressMap(); + const probability = 1.0 / addresses.length; + addresses.forEach((address) => { + result.add({address, probability}); + }); + return result; +} + +function findStationaryDistribution(mc: MarkovChain): Distribution { + let r0 = uniformDistribution(mc.getAll().map(({address}) => address)); function computeDelta(pi0, pi1) { - return tf.tidy(() => tf.max(tf.abs(pi0.sub(pi1))).dataSync()[0]); + return Math.max( + ...pi0 + .getAll() + .map(({address}) => + Math.abs(pi0.get(address).probability - pi1.get(address).probability) + ) + ); } let iteration = 0; while (true) { iteration++; - const r1 = tf.matMul(markovChain, r0); + const r1 = markovChainAction(mc, r0); const delta = computeDelta(r0, r1); - r0.dispose(); r0 = r1; console.log(`[${iteration}] delta = ${delta}`); if (delta < 1e-7) { diff --git a/src/app/credExplorer/basicPagerank.test.js b/src/app/credExplorer/basicPagerank.test.js new file mode 100644 index 0000000..9a5777c --- /dev/null +++ b/src/app/credExplorer/basicPagerank.test.js @@ -0,0 +1,19 @@ +// @flow + +import {Graph} from "../../core/graph"; +import {graphToMarkovChain} from "./basicPagerank"; + +describe("graphToMarkovChain", () => { + it("is correct for a trivial one-node chain", () => { + const g = new Graph(); + g.addNode({ + address: { + pluginName: "the magnificent foo plugin", + type: "irrelevant!", + id: "who are you blah blah", + }, + payload: "yes", + }); + expect(graphToMarkovChain(g)).toMatchSnapshot(); + }); +});