Use `SparseMarkovChain` in `basicPagerank` (#273)

Summary:
This commit slightly reorganizes the internals of `basicPagerank` to use
the `SparseMarkovChain` type from the `markovChain` module.

Test Plan:
Behavior of `yarn start` is unchanged.

wchargin-branch: use-sparsemarkovchain
This commit is contained in:
William Chargin 2018-05-11 21:28:58 -07:00 committed by GitHub
parent e5472752ac
commit 017fbd774a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 45 additions and 50 deletions

View File

@ -2,7 +2,7 @@
exports[`graphToMarkovChain is correct for a trivial one-node chain 1`] = ` exports[`graphToMarkovChain is correct for a trivial one-node chain 1`] = `
Object { Object {
"inNeighbors": Array [ "chain": Array [
Object { Object {
"neighbor": Uint32Array [ "neighbor": Uint32Array [
0, 0,

View File

@ -5,10 +5,11 @@ import type {Edge} from "../../core/graph";
import {AddressMap} from "../../core/address"; import {AddressMap} from "../../core/address";
import {Graph} from "../../core/graph"; import {Graph} from "../../core/graph";
export type Distribution = {| import type {
+nodeOrder: $ReadOnlyArray<Address>, Distribution,
+data: Float64Array, SparseMarkovChain,
|}; } from "../../core/attribution/markovChain";
export type PagerankResult = AddressMap<{| export type PagerankResult = AddressMap<{|
+address: Address, +address: Address,
+probability: number, +probability: number,
@ -22,18 +23,15 @@ type AddressMapMarkovChain = AddressMap<{|
|}>, |}>,
|}>; |}>;
type TypedArrayMarkovChain = {| type OrderedSparseMarkovChain = {|
+nodeOrder: $ReadOnlyArray<Address>, +nodeOrder: $ReadOnlyArray<Address>,
+inNeighbors: $ReadOnlyArray<{| +chain: SparseMarkovChain,
+neighbor: Uint32Array,
+weight: Float64Array,
|}>,
|}; |};
export default function basicPagerank(graph: Graph<any, any>): PagerankResult { export default function basicPagerank(graph: Graph<any, any>): PagerankResult {
return distributionToPagerankResult( const {nodeOrder, chain} = graphToOrderedSparseMarkovChain(graph);
findStationaryDistribution(graphToTypedArrayMarkovChain(graph)) const pi = findStationaryDistribution(chain);
); return distributionToPagerankResult(nodeOrder, pi);
} }
function edgeWeight( function edgeWeight(
@ -85,20 +83,20 @@ function graphToAddressMapMarkovChain(
return result; return result;
} }
function addressMapMarkovChainToTypedArrayMarkovChain( function addressMapMarkovChainToOrderedSparseMarkovChain(
mc: AddressMapMarkovChain chain: AddressMapMarkovChain
): TypedArrayMarkovChain { ): OrderedSparseMarkovChain {
// The node ordering is arbitrary, but must be made canonical: calls // The node ordering is arbitrary, but must be made canonical: calls
// to `graph.nodes()` are not guaranteed to be stable. // to `graph.nodes()` are not guaranteed to be stable.
const nodeOrder = mc.getAll().map(({address}) => address); const nodeOrder = chain.getAll().map(({address}) => address);
const addressToIndex = new AddressMap(); const addressToIndex = new AddressMap();
nodeOrder.forEach((address, index) => { nodeOrder.forEach((address, index) => {
addressToIndex.add({address, index}); addressToIndex.add({address, index});
}); });
return { return {
nodeOrder, nodeOrder,
inNeighbors: nodeOrder.map((address) => { chain: nodeOrder.map((address) => {
const theseNeighbors = mc.get(address).inNeighbors.getAll(); const theseNeighbors = chain.get(address).inNeighbors.getAll();
return { return {
neighbor: new Uint32Array( neighbor: new Uint32Array(
theseNeighbors.map(({address}) => addressToIndex.get(address).index) theseNeighbors.map(({address}) => addressToIndex.get(address).index)
@ -109,52 +107,46 @@ function addressMapMarkovChainToTypedArrayMarkovChain(
}; };
} }
export function graphToTypedArrayMarkovChain( export function graphToOrderedSparseMarkovChain(
graph: Graph<any, any> graph: Graph<any, any>
): TypedArrayMarkovChain { ): OrderedSparseMarkovChain {
return addressMapMarkovChainToTypedArrayMarkovChain( return addressMapMarkovChainToOrderedSparseMarkovChain(
graphToAddressMapMarkovChain(graph) graphToAddressMapMarkovChain(graph)
); );
} }
function markovChainAction( function sparseMarkovChainAction(
mc: TypedArrayMarkovChain, chain: SparseMarkovChain,
pi: Distribution pi: Distribution
): Distribution { ): Distribution {
const data = new Float64Array(pi.data.length); const result = new Float64Array(pi.length);
for (let dst = 0; dst < mc.nodeOrder.length; dst++) { chain.forEach(({neighbor, weight}, dst) => {
const theseNeighbors = mc.inNeighbors[dst]; const inDegree = neighbor.length; // (also `weight.length`)
const inDegree = theseNeighbors.neighbor.length;
let probability = 0; let probability = 0;
for (let srcIndex = 0; srcIndex < inDegree; srcIndex++) { for (let i = 0; i < inDegree; i++) {
const src = theseNeighbors.neighbor[srcIndex]; const src = neighbor[i];
probability += pi.data[src] * theseNeighbors.weight[srcIndex]; probability += pi[src] * weight[i];
} }
data[dst] = probability; result[dst] = probability;
} });
return {nodeOrder: pi.nodeOrder, data}; return result;
} }
function uniformDistribution(nodeOrder: $ReadOnlyArray<Address>): Distribution { function uniformDistribution(n: number): Distribution {
return { return new Float64Array(n).fill(1 / n);
nodeOrder,
data: new Float64Array(
Array(nodeOrder.length).fill(1.0 / nodeOrder.length)
),
};
} }
function findStationaryDistribution(mc: TypedArrayMarkovChain): Distribution { function findStationaryDistribution(chain: SparseMarkovChain): Distribution {
let r0 = uniformDistribution(mc.nodeOrder); let r0 = uniformDistribution(chain.length);
function computeDelta(pi0, pi1) { function computeDelta(pi0, pi1) {
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the // Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
// same (i.e., there has been no permutation). // same (i.e., there has been no permutation).
return Math.max(...pi0.data.map((x, i) => Math.abs(x - pi1.data[i]))); return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i])));
} }
let iteration = 0; let iteration = 0;
while (true) { while (true) {
iteration++; iteration++;
const r1 = markovChainAction(mc, r0); const r1 = sparseMarkovChainAction(chain, r0);
const delta = computeDelta(r0, r1); const delta = computeDelta(r0, r1);
r0 = r1; r0 = r1;
console.log(`[${iteration}] delta = ${delta}`); console.log(`[${iteration}] delta = ${delta}`);
@ -172,10 +164,13 @@ function findStationaryDistribution(mc: TypedArrayMarkovChain): Distribution {
throw new Error("Unreachable."); throw new Error("Unreachable.");
} }
function distributionToPagerankResult(pi: Distribution): PagerankResult { function distributionToPagerankResult(
nodeOrder: $ReadOnlyArray<Address>,
pi: Distribution
): PagerankResult {
const result = new AddressMap(); const result = new AddressMap();
pi.nodeOrder.forEach((address, i) => { nodeOrder.forEach((address, i) => {
const probability = pi.data[i]; const probability = pi[i];
result.add({address, probability}); result.add({address, probability});
}); });
return result; return result;

View File

@ -1,7 +1,7 @@
// @flow // @flow
import {Graph} from "../../core/graph"; import {Graph} from "../../core/graph";
import {graphToTypedArrayMarkovChain} from "./basicPagerank"; import {graphToOrderedSparseMarkovChain} from "./basicPagerank";
describe("graphToMarkovChain", () => { describe("graphToMarkovChain", () => {
it("is correct for a trivial one-node chain", () => { it("is correct for a trivial one-node chain", () => {
@ -14,6 +14,6 @@ describe("graphToMarkovChain", () => {
}, },
payload: "yes", payload: "yes",
}); });
expect(graphToTypedArrayMarkovChain(g)).toMatchSnapshot(); expect(graphToOrderedSparseMarkovChain(g)).toMatchSnapshot();
}); });
}); });