Extract `findStationaryDistribution` (#277)

Test Plan:
Unit tests added. Run `yarn test`.

wchargin-branch: extract-findStationaryDistribution
This commit is contained in:
William Chargin 2018-05-11 21:56:52 -07:00 committed by GitHub
parent 9d7f9f78cd
commit 115d7f3921
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 108 additions and 51 deletions

View File

@ -9,10 +9,7 @@ import type {
Distribution, Distribution,
SparseMarkovChain, SparseMarkovChain,
} from "../../core/attribution/markovChain"; } from "../../core/attribution/markovChain";
import { import {findStationaryDistribution} from "../../core/attribution/markovChain";
sparseMarkovChainAction,
uniformDistribution,
} from "../../core/attribution/markovChain";
export type PagerankResult = AddressMap<{| export type PagerankResult = AddressMap<{|
+address: Address, +address: Address,
@ -119,53 +116,6 @@ export function graphToOrderedSparseMarkovChain(
); );
} }
function findStationaryDistribution(
chain: SparseMarkovChain,
options?: {|
+verbose?: boolean,
+convergenceThreshold?: number,
+maxIterations?: number,
|}
): Distribution {
const fullOptions = {
verbose: false,
convergenceThreshold: 1e-7,
maxIterations: 255,
...(options || {}),
};
let r0 = uniformDistribution(chain.length);
function computeDelta(pi0, pi1) {
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
// same (i.e., there has been no permutation).
return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i])));
}
let iteration = 0;
while (true) {
iteration++;
const r1 = sparseMarkovChainAction(chain, r0);
const delta = computeDelta(r0, r1);
r0 = r1;
if (fullOptions.verbose) {
console.log(`[${iteration}] delta = ${delta}`);
}
if (delta < fullOptions.convergenceThreshold) {
if (fullOptions.verbose) {
console.log(`[${iteration}] CONVERGED`);
}
return r0;
}
if (iteration >= fullOptions.maxIterations) {
if (fullOptions.verbose) {
console.log(`[${iteration}] FAILED to converge`);
}
return r0;
}
}
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
// eslint-disable-next-line no-unreachable
throw new Error("Unreachable.");
}
function distributionToPagerankResult( function distributionToPagerankResult(
nodeOrder: $ReadOnlyArray<Address>, nodeOrder: $ReadOnlyArray<Address>,
pi: Distribution pi: Distribution

View File

@ -89,3 +89,50 @@ export function sparseMarkovChainAction(
}); });
return result; return result;
} }
export function findStationaryDistribution(
chain: SparseMarkovChain,
options?: {|
+verbose?: boolean,
+convergenceThreshold?: number,
+maxIterations?: number,
|}
): Distribution {
const fullOptions = {
verbose: false,
convergenceThreshold: 1e-7,
maxIterations: 255,
...(options || {}),
};
let r0 = uniformDistribution(chain.length);
function computeDelta(pi0, pi1) {
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
// same (i.e., there has been no permutation).
return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i])));
}
let iteration = 0;
while (true) {
iteration++;
const r1 = sparseMarkovChainAction(chain, r0);
const delta = computeDelta(r0, r1);
r0 = r1;
if (fullOptions.verbose) {
console.log(`[${iteration}] delta = ${delta}`);
}
if (delta < fullOptions.convergenceThreshold) {
if (fullOptions.verbose) {
console.log(`[${iteration}] CONVERGED`);
}
return r0;
}
if (iteration >= fullOptions.maxIterations) {
if (fullOptions.verbose) {
console.log(`[${iteration}] FAILED to converge`);
}
return r0;
}
}
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
// eslint-disable-next-line no-unreachable
throw new Error("Unreachable.");
}

View File

@ -1,6 +1,8 @@
// @flow // @flow
import type {Distribution, SparseMarkovChain} from "./markovChain";
import { import {
findStationaryDistribution,
sparseMarkovChainAction, sparseMarkovChainAction,
sparseMarkovChainFromTransitionMatrix, sparseMarkovChainFromTransitionMatrix,
uniformDistribution, uniformDistribution,
@ -120,3 +122,61 @@ describe("sparseMarkovChainAction", () => {
expect(pi1).toEqual(expected); expect(pi1).toEqual(expected);
}); });
}); });
function expectAllClose(
actual: Float64Array,
expected: Float64Array,
epsilon: number = 1e-6
): void {
expect(actual).toHaveLength(expected.length);
for (let i = 0; i < expected.length; i++) {
if (Math.abs(actual[i] - expected[i]) >= epsilon) {
expect(actual).toEqual(expected); // will fail
return;
}
}
}
function expectStationary(chain: SparseMarkovChain, pi: Distribution): void {
expectAllClose(sparseMarkovChainAction(chain, pi), pi);
}
describe("findStationaryDistribution", () => {
it("finds an all-accumulating stationary distribution", () => {
const chain = sparseMarkovChainFromTransitionMatrix([
[1, 0, 0],
[0.25, 0, 0.75],
[0.25, 0.75, 0],
]);
const pi = findStationaryDistribution(chain);
expectStationary(chain, pi);
const expected = new Float64Array([1, 0, 0]);
expectAllClose(pi, expected);
});
it("finds a non-degenerate stationary distribution", () => {
// Node 0 is the "center"; nodes 1 through 4 are "satellites". A
// satellite transitions to the center with probability 0.5, or to a
// cyclically adjacent satellite with probability 0.25 each. The
// center transitions to a uniformly random satellite.
const chain = sparseMarkovChainFromTransitionMatrix([
[0, 0.25, 0.25, 0.25, 0.25],
[0.5, 0, 0.25, 0, 0.25],
[0.5, 0.25, 0, 0.25, 0],
[0.5, 0, 0.25, 0, 0.25],
[0.5, 0.25, 0, 0.25, 0],
]);
const pi = findStationaryDistribution(chain);
expectStationary(chain, pi);
const expected = new Float64Array([1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]);
expectAllClose(pi, expected);
});
it("finds the stationary distribution of a periodic chain", () => {
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]);
const pi = findStationaryDistribution(chain);
expectStationary(chain, pi);
const expected = new Float64Array([0.5, 0.5]);
expectAllClose(pi, expected);
});
});