Implement basic PageRank analysis (#252)

Summary:
We don’t expect the results to be of good quality right now. Rather,
this gives us a starting point from which to iterate the algorithm.

The convergence criterion also needs to be adjusted. (In particular, it
should almost certainly not be a constant.)

Test Plan:
Run `yarn start`. Select a graph, like `sourcecred/example-github`. Open
the JS console and click “Run basic PageRank”. Watch the console.

wchargin-branch: basic-pagerank
This commit is contained in:
William Chargin 2018-05-10 11:21:18 -07:00 committed by GitHub
parent 8e4668cc91
commit 61d3cb3f52
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 140 additions and 0 deletions

View File

@ -1,9 +1,12 @@
// @flow // @flow
import stringify from "json-stable-stringify";
import React from "react"; import React from "react";
import {StyleSheet, css} from "aphrodite/no-important"; import {StyleSheet, css} from "aphrodite/no-important";
import type {PagerankResult} from "./basicPagerank";
import {Graph} from "../../core/graph"; import {Graph} from "../../core/graph";
import basicPagerank from "./basicPagerank";
type Props = {}; type Props = {};
type State = { type State = {
@ -62,6 +65,19 @@ export default class App extends React.Component<Props, State> {
) : ( ) : (
<p>Graph not loaded.</p> <p>Graph not loaded.</p>
)} )}
<button
disabled={graph == null}
onClick={() => {
setTimeout(() => {
if (graph != null) {
const pagerankResult = basicPagerank(graph);
this.analyzePagerankResult(pagerankResult);
}
}, 0);
}}
>
Run basic PageRank (results in console)
</button>
</div> </div>
</div> </div>
); );
@ -88,6 +104,34 @@ export default class App extends React.Component<Props, State> {
console.error("Error while fetching:", e); console.error("Error while fetching:", e);
}); });
} }
analyzePagerankResult(pagerankResult: PagerankResult) {
const addressKey = ({pluginName, type}) => stringify({pluginName, type});
const addressesByKey = {};
pagerankResult.getAll().forEach(({address}) => {
if (addressesByKey[addressKey(address)] === undefined) {
addressesByKey[addressKey(address)] = [];
}
addressesByKey[addressKey(address)].push(address);
});
Object.keys(addressesByKey).forEach((key) => {
addressesByKey[key] = addressesByKey[key]
.slice()
.sort((x, y) => {
const px = pagerankResult.get(x).probability;
const py = pagerankResult.get(y).probability;
return px - py;
})
.reverse();
const {pluginName, type} = JSON.parse(key);
console.log(`%c${type} (${pluginName})`, "font-weight: bold");
addressesByKey[key].slice(0, 5).forEach((address) => {
const score = pagerankResult.get(address).probability;
const name = address.id;
console.log(` - [${score.toString()}] ${name}`);
});
});
}
} }
const styles = StyleSheet.create({ const styles = StyleSheet.create({

View File

@ -0,0 +1,96 @@
// @flow
import * as tf from "@tensorflow/tfjs-core";
import type {Address} from "../../core/address";
import {AddressMap} from "../../core/address";
import {Graph} from "../../core/graph";
export type PagerankResult = AddressMap<{|
+address: Address,
+probability: number,
|}>;
export default function basicPagerank(graph: Graph<any, any>): PagerankResult {
return tf.tidy(() => {
const {nodes, markovChain} = graphToMarkovChain(graph);
const stationaryDistribution = findStationaryDistribution(markovChain);
const stationaryDistributionRaw = stationaryDistribution.dataSync();
const result = new AddressMap();
nodes.forEach((node, i) => {
result.add({
address: node.address,
probability: stationaryDistributionRaw[i],
});
});
return result;
});
}
function graphToMarkovChain(graph: Graph<any, any>) {
const nodes = graph.nodes(); // for canonical ordering
const addressToIndex = new AddressMap();
nodes.forEach(({address}, index) => {
addressToIndex.add({address, index});
});
const buffer = tf.buffer([nodes.length, nodes.length]);
graph.edges().forEach(({src, dst, address}) => {
if (graph.node(src) == null) {
console.warn("Edge has dangling src:", address, src);
return;
}
if (graph.node(dst) == null) {
console.warn("Edge has dangling dst:", address, dst);
return;
}
const u = addressToIndex.get(src).index;
const v = addressToIndex.get(dst).index;
buffer.set(1, u, v);
buffer.set(1, v, u);
});
return {
nodes,
markovChain: tf.tidy(() => {
const dampingFactor = 1e-4;
const raw = buffer.toTensor();
const nonsingular = raw.add(tf.scalar(1e-9));
const normalized = nonsingular.div(nonsingular.sum(1));
const damped = tf.add(
normalized.mul(tf.scalar(1 - dampingFactor)),
tf.onesLike(normalized).mul(tf.scalar(dampingFactor / nodes.length))
);
return damped;
}),
};
}
function findStationaryDistribution(markovChain: $Call<tf.tensor2d>) {
const n = markovChain.shape[0];
if (markovChain.shape.length !== 2 || markovChain.shape[1] !== n) {
throw new Error(`Expected square matrix; got: ${markovChain.shape}`);
}
let r0 = tf.tidy(() => tf.ones([n, 1]).div(tf.scalar(n)));
function computeDelta(pi0, pi1) {
return tf.tidy(() => tf.max(tf.abs(pi0.sub(pi1))).dataSync()[0]);
}
let iteration = 0;
while (true) {
iteration++;
const r1 = tf.matMul(markovChain, r0);
const delta = computeDelta(r0, r1);
r0.dispose();
r0 = r1;
console.log(`[${iteration}] delta = ${delta}`);
if (delta < 1e-7) {
console.log(`[${iteration}] CONVERGED`);
return r0;
}
if (iteration >= 255) {
console.log(`[${iteration}] FAILED to converge`);
return r0;
}
}
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
// eslint-disable-next-line no-unreachable
throw new Error("Unreachable.");
}