diff --git a/src/app/credExplorer/basicPagerank.js b/src/app/credExplorer/basicPagerank.js index d08178f..55dbc3b 100644 --- a/src/app/credExplorer/basicPagerank.js +++ b/src/app/credExplorer/basicPagerank.js @@ -9,10 +9,7 @@ import type { Distribution, SparseMarkovChain, } from "../../core/attribution/markovChain"; -import { - sparseMarkovChainAction, - uniformDistribution, -} from "../../core/attribution/markovChain"; +import {findStationaryDistribution} from "../../core/attribution/markovChain"; export type PagerankResult = AddressMap<{| +address: Address, @@ -119,53 +116,6 @@ export function graphToOrderedSparseMarkovChain( ); } -function findStationaryDistribution( - chain: SparseMarkovChain, - options?: {| - +verbose?: boolean, - +convergenceThreshold?: number, - +maxIterations?: number, - |} -): Distribution { - const fullOptions = { - verbose: false, - convergenceThreshold: 1e-7, - maxIterations: 255, - ...(options || {}), - }; - let r0 = uniformDistribution(chain.length); - function computeDelta(pi0, pi1) { - // Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the - // same (i.e., there has been no permutation). - return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i]))); - } - let iteration = 0; - while (true) { - iteration++; - const r1 = sparseMarkovChainAction(chain, r0); - const delta = computeDelta(r0, r1); - r0 = r1; - if (fullOptions.verbose) { - console.log(`[${iteration}] delta = ${delta}`); - } - if (delta < fullOptions.convergenceThreshold) { - if (fullOptions.verbose) { - console.log(`[${iteration}] CONVERGED`); - } - return r0; - } - if (iteration >= fullOptions.maxIterations) { - if (fullOptions.verbose) { - console.log(`[${iteration}] FAILED to converge`); - } - return r0; - } - } - // ESLint knows that this next line is unreachable, but Flow doesn't. :-) - // eslint-disable-next-line no-unreachable - throw new Error("Unreachable."); -} - function distributionToPagerankResult( nodeOrder: $ReadOnlyArray
, pi: Distribution diff --git a/src/core/attribution/markovChain.js b/src/core/attribution/markovChain.js index a798790..812012b 100644 --- a/src/core/attribution/markovChain.js +++ b/src/core/attribution/markovChain.js @@ -89,3 +89,50 @@ export function sparseMarkovChainAction( }); return result; } + +export function findStationaryDistribution( + chain: SparseMarkovChain, + options?: {| + +verbose?: boolean, + +convergenceThreshold?: number, + +maxIterations?: number, + |} +): Distribution { + const fullOptions = { + verbose: false, + convergenceThreshold: 1e-7, + maxIterations: 255, + ...(options || {}), + }; + let r0 = uniformDistribution(chain.length); + function computeDelta(pi0, pi1) { + // Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the + // same (i.e., there has been no permutation). + return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i]))); + } + let iteration = 0; + while (true) { + iteration++; + const r1 = sparseMarkovChainAction(chain, r0); + const delta = computeDelta(r0, r1); + r0 = r1; + if (fullOptions.verbose) { + console.log(`[${iteration}] delta = ${delta}`); + } + if (delta < fullOptions.convergenceThreshold) { + if (fullOptions.verbose) { + console.log(`[${iteration}] CONVERGED`); + } + return r0; + } + if (iteration >= fullOptions.maxIterations) { + if (fullOptions.verbose) { + console.log(`[${iteration}] FAILED to converge`); + } + return r0; + } + } + // ESLint knows that this next line is unreachable, but Flow doesn't. :-) + // eslint-disable-next-line no-unreachable + throw new Error("Unreachable."); +} diff --git a/src/core/attribution/markovChain.test.js b/src/core/attribution/markovChain.test.js index 168e64f..83ee4bf 100644 --- a/src/core/attribution/markovChain.test.js +++ b/src/core/attribution/markovChain.test.js @@ -1,6 +1,8 @@ // @flow +import type {Distribution, SparseMarkovChain} from "./markovChain"; import { + findStationaryDistribution, sparseMarkovChainAction, sparseMarkovChainFromTransitionMatrix, uniformDistribution, @@ -120,3 +122,61 @@ describe("sparseMarkovChainAction", () => { expect(pi1).toEqual(expected); }); }); + +function expectAllClose( + actual: Float64Array, + expected: Float64Array, + epsilon: number = 1e-6 +): void { + expect(actual).toHaveLength(expected.length); + for (let i = 0; i < expected.length; i++) { + if (Math.abs(actual[i] - expected[i]) >= epsilon) { + expect(actual).toEqual(expected); // will fail + return; + } + } +} + +function expectStationary(chain: SparseMarkovChain, pi: Distribution): void { + expectAllClose(sparseMarkovChainAction(chain, pi), pi); +} + +describe("findStationaryDistribution", () => { + it("finds an all-accumulating stationary distribution", () => { + const chain = sparseMarkovChainFromTransitionMatrix([ + [1, 0, 0], + [0.25, 0, 0.75], + [0.25, 0.75, 0], + ]); + const pi = findStationaryDistribution(chain); + expectStationary(chain, pi); + const expected = new Float64Array([1, 0, 0]); + expectAllClose(pi, expected); + }); + + it("finds a non-degenerate stationary distribution", () => { + // Node 0 is the "center"; nodes 1 through 4 are "satellites". A + // satellite transitions to the center with probability 0.5, or to a + // cyclically adjacent satellite with probability 0.25 each. The + // center transitions to a uniformly random satellite. + const chain = sparseMarkovChainFromTransitionMatrix([ + [0, 0.25, 0.25, 0.25, 0.25], + [0.5, 0, 0.25, 0, 0.25], + [0.5, 0.25, 0, 0.25, 0], + [0.5, 0, 0.25, 0, 0.25], + [0.5, 0.25, 0, 0.25, 0], + ]); + const pi = findStationaryDistribution(chain); + expectStationary(chain, pi); + const expected = new Float64Array([1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]); + expectAllClose(pi, expected); + }); + + it("finds the stationary distribution of a periodic chain", () => { + const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]); + const pi = findStationaryDistribution(chain); + expectStationary(chain, pi); + const expected = new Float64Array([0.5, 0.5]); + expectAllClose(pi, expected); + }); +});