Extract `findStationaryDistribution` (#277)
Test Plan: Unit tests added. Run `yarn test`. wchargin-branch: extract-findStationaryDistribution
This commit is contained in:
parent
9d7f9f78cd
commit
115d7f3921
|
@ -9,10 +9,7 @@ import type {
|
|||
Distribution,
|
||||
SparseMarkovChain,
|
||||
} from "../../core/attribution/markovChain";
|
||||
import {
|
||||
sparseMarkovChainAction,
|
||||
uniformDistribution,
|
||||
} from "../../core/attribution/markovChain";
|
||||
import {findStationaryDistribution} from "../../core/attribution/markovChain";
|
||||
|
||||
export type PagerankResult = AddressMap<{|
|
||||
+address: Address,
|
||||
|
@ -119,53 +116,6 @@ export function graphToOrderedSparseMarkovChain(
|
|||
);
|
||||
}
|
||||
|
||||
function findStationaryDistribution(
|
||||
chain: SparseMarkovChain,
|
||||
options?: {|
|
||||
+verbose?: boolean,
|
||||
+convergenceThreshold?: number,
|
||||
+maxIterations?: number,
|
||||
|}
|
||||
): Distribution {
|
||||
const fullOptions = {
|
||||
verbose: false,
|
||||
convergenceThreshold: 1e-7,
|
||||
maxIterations: 255,
|
||||
...(options || {}),
|
||||
};
|
||||
let r0 = uniformDistribution(chain.length);
|
||||
function computeDelta(pi0, pi1) {
|
||||
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
|
||||
// same (i.e., there has been no permutation).
|
||||
return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i])));
|
||||
}
|
||||
let iteration = 0;
|
||||
while (true) {
|
||||
iteration++;
|
||||
const r1 = sparseMarkovChainAction(chain, r0);
|
||||
const delta = computeDelta(r0, r1);
|
||||
r0 = r1;
|
||||
if (fullOptions.verbose) {
|
||||
console.log(`[${iteration}] delta = ${delta}`);
|
||||
}
|
||||
if (delta < fullOptions.convergenceThreshold) {
|
||||
if (fullOptions.verbose) {
|
||||
console.log(`[${iteration}] CONVERGED`);
|
||||
}
|
||||
return r0;
|
||||
}
|
||||
if (iteration >= fullOptions.maxIterations) {
|
||||
if (fullOptions.verbose) {
|
||||
console.log(`[${iteration}] FAILED to converge`);
|
||||
}
|
||||
return r0;
|
||||
}
|
||||
}
|
||||
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
|
||||
// eslint-disable-next-line no-unreachable
|
||||
throw new Error("Unreachable.");
|
||||
}
|
||||
|
||||
function distributionToPagerankResult(
|
||||
nodeOrder: $ReadOnlyArray<Address>,
|
||||
pi: Distribution
|
||||
|
|
|
@ -89,3 +89,50 @@ export function sparseMarkovChainAction(
|
|||
});
|
||||
return result;
|
||||
}
|
||||
|
||||
export function findStationaryDistribution(
|
||||
chain: SparseMarkovChain,
|
||||
options?: {|
|
||||
+verbose?: boolean,
|
||||
+convergenceThreshold?: number,
|
||||
+maxIterations?: number,
|
||||
|}
|
||||
): Distribution {
|
||||
const fullOptions = {
|
||||
verbose: false,
|
||||
convergenceThreshold: 1e-7,
|
||||
maxIterations: 255,
|
||||
...(options || {}),
|
||||
};
|
||||
let r0 = uniformDistribution(chain.length);
|
||||
function computeDelta(pi0, pi1) {
|
||||
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
|
||||
// same (i.e., there has been no permutation).
|
||||
return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i])));
|
||||
}
|
||||
let iteration = 0;
|
||||
while (true) {
|
||||
iteration++;
|
||||
const r1 = sparseMarkovChainAction(chain, r0);
|
||||
const delta = computeDelta(r0, r1);
|
||||
r0 = r1;
|
||||
if (fullOptions.verbose) {
|
||||
console.log(`[${iteration}] delta = ${delta}`);
|
||||
}
|
||||
if (delta < fullOptions.convergenceThreshold) {
|
||||
if (fullOptions.verbose) {
|
||||
console.log(`[${iteration}] CONVERGED`);
|
||||
}
|
||||
return r0;
|
||||
}
|
||||
if (iteration >= fullOptions.maxIterations) {
|
||||
if (fullOptions.verbose) {
|
||||
console.log(`[${iteration}] FAILED to converge`);
|
||||
}
|
||||
return r0;
|
||||
}
|
||||
}
|
||||
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
|
||||
// eslint-disable-next-line no-unreachable
|
||||
throw new Error("Unreachable.");
|
||||
}
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
// @flow
|
||||
|
||||
import type {Distribution, SparseMarkovChain} from "./markovChain";
|
||||
import {
|
||||
findStationaryDistribution,
|
||||
sparseMarkovChainAction,
|
||||
sparseMarkovChainFromTransitionMatrix,
|
||||
uniformDistribution,
|
||||
|
@ -120,3 +122,61 @@ describe("sparseMarkovChainAction", () => {
|
|||
expect(pi1).toEqual(expected);
|
||||
});
|
||||
});
|
||||
|
||||
function expectAllClose(
|
||||
actual: Float64Array,
|
||||
expected: Float64Array,
|
||||
epsilon: number = 1e-6
|
||||
): void {
|
||||
expect(actual).toHaveLength(expected.length);
|
||||
for (let i = 0; i < expected.length; i++) {
|
||||
if (Math.abs(actual[i] - expected[i]) >= epsilon) {
|
||||
expect(actual).toEqual(expected); // will fail
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
function expectStationary(chain: SparseMarkovChain, pi: Distribution): void {
|
||||
expectAllClose(sparseMarkovChainAction(chain, pi), pi);
|
||||
}
|
||||
|
||||
describe("findStationaryDistribution", () => {
|
||||
it("finds an all-accumulating stationary distribution", () => {
|
||||
const chain = sparseMarkovChainFromTransitionMatrix([
|
||||
[1, 0, 0],
|
||||
[0.25, 0, 0.75],
|
||||
[0.25, 0.75, 0],
|
||||
]);
|
||||
const pi = findStationaryDistribution(chain);
|
||||
expectStationary(chain, pi);
|
||||
const expected = new Float64Array([1, 0, 0]);
|
||||
expectAllClose(pi, expected);
|
||||
});
|
||||
|
||||
it("finds a non-degenerate stationary distribution", () => {
|
||||
// Node 0 is the "center"; nodes 1 through 4 are "satellites". A
|
||||
// satellite transitions to the center with probability 0.5, or to a
|
||||
// cyclically adjacent satellite with probability 0.25 each. The
|
||||
// center transitions to a uniformly random satellite.
|
||||
const chain = sparseMarkovChainFromTransitionMatrix([
|
||||
[0, 0.25, 0.25, 0.25, 0.25],
|
||||
[0.5, 0, 0.25, 0, 0.25],
|
||||
[0.5, 0.25, 0, 0.25, 0],
|
||||
[0.5, 0, 0.25, 0, 0.25],
|
||||
[0.5, 0.25, 0, 0.25, 0],
|
||||
]);
|
||||
const pi = findStationaryDistribution(chain);
|
||||
expectStationary(chain, pi);
|
||||
const expected = new Float64Array([1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]);
|
||||
expectAllClose(pi, expected);
|
||||
});
|
||||
|
||||
it("finds the stationary distribution of a periodic chain", () => {
|
||||
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]);
|
||||
const pi = findStationaryDistribution(chain);
|
||||
expectStationary(chain, pi);
|
||||
const expected = new Float64Array([0.5, 0.5]);
|
||||
expectAllClose(pi, expected);
|
||||
});
|
||||
});
|
||||
|
|
Loading…
Reference in New Issue