Rewrite basic PageRank without TFJS (#266)
Summary: We’re not convinced that using TFJS at this time is worth it, for two reasons. First, our matrix computations can be expressed using sparse matrices, which will improve the performance by orders of magnitude. Sparse matrices do not appear to be supported by TFJS (the layers API makes some use of them, but it is not clear that they have much support their, either). Second, having to deal with GPU memory and WebGL has already been problematic. When WebGL PageRank is running, the machine is mostly unusable, and other applications’ video output is negatively affected (!). This commit rewrites the internals of `basicPagerank.js` while retaining its end-to-end public interface. We also add a test file with a trivial test. The resulting code is not faster yet—in fact, it’s a fair amount slower. But this is because our use of `AddressMap`s puts JSON stringification on the critical path, which is obviously a bad idea. In a subsequent commit, we will rewrite the internals again to use typed arrays. Paired with @decentralion. Test Plan: The new unit test is not sufficient. Instead, run `yarn start` and re-run PageRank on SourceCred; note that the results are roughly unchanged. wchargin-branch: pagerank-without-tfjs
This commit is contained in:
parent
2a52ff85f8
commit
7e97ba6bf3
|
@ -0,0 +1,13 @@
|
|||
// Jest Snapshot v1, https://goo.gl/fbAQLP
|
||||
|
||||
exports[`graphToMarkovChain is correct for a trivial one-node chain 1`] = `
|
||||
Object {
|
||||
"{\\"id\\":\\"who are you blah blah\\",\\"pluginName\\":\\"the magnificent foo plugin\\",\\"type\\":\\"irrelevant!\\"}": Object {
|
||||
"inNeighbors": Object {
|
||||
"{\\"id\\":\\"who are you blah blah\\",\\"pluginName\\":\\"the magnificent foo plugin\\",\\"type\\":\\"irrelevant!\\"}": Object {
|
||||
"weight": 1,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
`;
|
|
@ -1,84 +1,112 @@
|
|||
// @flow
|
||||
|
||||
import * as tf from "@tensorflow/tfjs-core";
|
||||
|
||||
import type {Address} from "../../core/address";
|
||||
import type {Edge} from "../../core/graph";
|
||||
import {AddressMap} from "../../core/address";
|
||||
import {Graph} from "../../core/graph";
|
||||
|
||||
export type PagerankResult = AddressMap<{|
|
||||
export type Distribution = AddressMap<{|
|
||||
+address: Address,
|
||||
+probability: number,
|
||||
|}>;
|
||||
export type PagerankResult = Distribution;
|
||||
|
||||
type MarkovChain = AddressMap<{|
|
||||
+address: Address,
|
||||
+inNeighbors: AddressMap<{|
|
||||
+address: Address,
|
||||
+weight: number,
|
||||
|}>,
|
||||
|}>;
|
||||
|
||||
export default function basicPagerank(graph: Graph<any, any>): PagerankResult {
|
||||
return tf.tidy(() => {
|
||||
const {nodes, markovChain} = graphToMarkovChain(graph);
|
||||
const stationaryDistribution = findStationaryDistribution(markovChain);
|
||||
const stationaryDistributionRaw = stationaryDistribution.dataSync();
|
||||
const result = new AddressMap();
|
||||
nodes.forEach((node, i) => {
|
||||
result.add({
|
||||
address: node.address,
|
||||
probability: stationaryDistributionRaw[i],
|
||||
});
|
||||
return findStationaryDistribution(graphToMarkovChain(graph));
|
||||
}
|
||||
|
||||
function edgeWeight(
|
||||
_unused_edge: Edge<any>
|
||||
): {|+toWeight: number, +froWeight: number|} {
|
||||
return {toWeight: 1, froWeight: 1};
|
||||
}
|
||||
|
||||
export function graphToMarkovChain(graph: Graph<any, any>): MarkovChain {
|
||||
const result = new AddressMap();
|
||||
const unnormalizedTotalOutWeights = new AddressMap();
|
||||
|
||||
function initializeNode(address) {
|
||||
if (result.get(address) != null) {
|
||||
return;
|
||||
}
|
||||
const inNeighbors = new AddressMap();
|
||||
result.add({address, inNeighbors});
|
||||
const selfLoopEdgeWeight = 1e-3;
|
||||
unnormalizedTotalOutWeights.add({address, weight: selfLoopEdgeWeight});
|
||||
graph.neighborhood(address).forEach(({neighbor}) => {
|
||||
inNeighbors.add({address: neighbor, weight: 0});
|
||||
});
|
||||
return result;
|
||||
});
|
||||
}
|
||||
|
||||
function graphToMarkovChain(graph: Graph<any, any>) {
|
||||
const nodes = graph.nodes(); // for canonical ordering
|
||||
const addressToIndex = new AddressMap();
|
||||
nodes.forEach(({address}, index) => {
|
||||
addressToIndex.add({address, index});
|
||||
});
|
||||
const buffer = tf.buffer([nodes.length, nodes.length]);
|
||||
graph.edges().forEach(({src, dst, address}) => {
|
||||
if (graph.node(src) == null) {
|
||||
console.warn("Edge has dangling src:", address, src);
|
||||
return;
|
||||
}
|
||||
if (graph.node(dst) == null) {
|
||||
console.warn("Edge has dangling dst:", address, dst);
|
||||
return;
|
||||
}
|
||||
const u = addressToIndex.get(src).index;
|
||||
const v = addressToIndex.get(dst).index;
|
||||
buffer.set(1, u, v);
|
||||
buffer.set(1, v, u);
|
||||
});
|
||||
return {
|
||||
nodes,
|
||||
markovChain: tf.tidy(() => {
|
||||
const dampingFactor = 1e-4;
|
||||
const raw = buffer.toTensor();
|
||||
const nonsingular = raw.add(tf.scalar(1e-9));
|
||||
const normalized = nonsingular.div(nonsingular.sum(1));
|
||||
const damped = tf.add(
|
||||
normalized.mul(tf.scalar(1 - dampingFactor)),
|
||||
tf.onesLike(normalized).mul(tf.scalar(dampingFactor / nodes.length))
|
||||
);
|
||||
return damped;
|
||||
}),
|
||||
};
|
||||
}
|
||||
|
||||
function findStationaryDistribution(markovChain: $Call<tf.tensor2d>) {
|
||||
const n = markovChain.shape[0];
|
||||
if (markovChain.shape.length !== 2 || markovChain.shape[1] !== n) {
|
||||
throw new Error(`Expected square matrix; got: ${markovChain.shape}`);
|
||||
inNeighbors.add({address: address, weight: selfLoopEdgeWeight});
|
||||
}
|
||||
let r0 = tf.tidy(() => tf.ones([n, 1]).div(tf.scalar(n)));
|
||||
|
||||
graph.nodes().forEach(({address}) => {
|
||||
initializeNode(address);
|
||||
});
|
||||
graph.edges().forEach((edge) => {
|
||||
const {src, dst} = edge;
|
||||
initializeNode(src);
|
||||
initializeNode(dst);
|
||||
const {toWeight, froWeight} = edgeWeight(edge);
|
||||
result.get(dst).inNeighbors.get(src).weight += toWeight;
|
||||
result.get(src).inNeighbors.get(dst).weight += froWeight;
|
||||
unnormalizedTotalOutWeights.get(src).weight += toWeight;
|
||||
unnormalizedTotalOutWeights.get(dst).weight += froWeight;
|
||||
});
|
||||
|
||||
// Normalize.
|
||||
result.getAll().forEach(({inNeighbors}) => {
|
||||
inNeighbors.getAll().forEach((entry) => {
|
||||
entry.weight /= unnormalizedTotalOutWeights.get(entry.address).weight;
|
||||
});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
function markovChainAction(mc: MarkovChain, pi: Distribution): Distribution {
|
||||
const result = new AddressMap();
|
||||
mc.getAll().forEach(({address, inNeighbors}) => {
|
||||
let probability = 0;
|
||||
inNeighbors.getAll().forEach(({address: neighbor, weight}) => {
|
||||
probability += pi.get(neighbor).probability * weight;
|
||||
});
|
||||
result.add({address, probability});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
function uniformDistribution(addresses: $ReadOnlyArray<Address>) {
|
||||
const result = new AddressMap();
|
||||
const probability = 1.0 / addresses.length;
|
||||
addresses.forEach((address) => {
|
||||
result.add({address, probability});
|
||||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
function findStationaryDistribution(mc: MarkovChain): Distribution {
|
||||
let r0 = uniformDistribution(mc.getAll().map(({address}) => address));
|
||||
function computeDelta(pi0, pi1) {
|
||||
return tf.tidy(() => tf.max(tf.abs(pi0.sub(pi1))).dataSync()[0]);
|
||||
return Math.max(
|
||||
...pi0
|
||||
.getAll()
|
||||
.map(({address}) =>
|
||||
Math.abs(pi0.get(address).probability - pi1.get(address).probability)
|
||||
)
|
||||
);
|
||||
}
|
||||
let iteration = 0;
|
||||
while (true) {
|
||||
iteration++;
|
||||
const r1 = tf.matMul(markovChain, r0);
|
||||
const r1 = markovChainAction(mc, r0);
|
||||
const delta = computeDelta(r0, r1);
|
||||
r0.dispose();
|
||||
r0 = r1;
|
||||
console.log(`[${iteration}] delta = ${delta}`);
|
||||
if (delta < 1e-7) {
|
||||
|
|
|
@ -0,0 +1,19 @@
|
|||
// @flow
|
||||
|
||||
import {Graph} from "../../core/graph";
|
||||
import {graphToMarkovChain} from "./basicPagerank";
|
||||
|
||||
describe("graphToMarkovChain", () => {
|
||||
it("is correct for a trivial one-node chain", () => {
|
||||
const g = new Graph();
|
||||
g.addNode({
|
||||
address: {
|
||||
pluginName: "the magnificent foo plugin",
|
||||
type: "irrelevant!",
|
||||
id: "who are you blah blah",
|
||||
},
|
||||
payload: "yes",
|
||||
});
|
||||
expect(graphToMarkovChain(g)).toMatchSnapshot();
|
||||
});
|
||||
});
|
Loading…
Reference in New Issue