Extract `findStationaryDistribution` (#277)
Test Plan: Unit tests added. Run `yarn test`. wchargin-branch: extract-findStationaryDistribution
This commit is contained in:
parent
9d7f9f78cd
commit
115d7f3921
|
@ -9,10 +9,7 @@ import type {
|
||||||
Distribution,
|
Distribution,
|
||||||
SparseMarkovChain,
|
SparseMarkovChain,
|
||||||
} from "../../core/attribution/markovChain";
|
} from "../../core/attribution/markovChain";
|
||||||
import {
|
import {findStationaryDistribution} from "../../core/attribution/markovChain";
|
||||||
sparseMarkovChainAction,
|
|
||||||
uniformDistribution,
|
|
||||||
} from "../../core/attribution/markovChain";
|
|
||||||
|
|
||||||
export type PagerankResult = AddressMap<{|
|
export type PagerankResult = AddressMap<{|
|
||||||
+address: Address,
|
+address: Address,
|
||||||
|
@ -119,53 +116,6 @@ export function graphToOrderedSparseMarkovChain(
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
function findStationaryDistribution(
|
|
||||||
chain: SparseMarkovChain,
|
|
||||||
options?: {|
|
|
||||||
+verbose?: boolean,
|
|
||||||
+convergenceThreshold?: number,
|
|
||||||
+maxIterations?: number,
|
|
||||||
|}
|
|
||||||
): Distribution {
|
|
||||||
const fullOptions = {
|
|
||||||
verbose: false,
|
|
||||||
convergenceThreshold: 1e-7,
|
|
||||||
maxIterations: 255,
|
|
||||||
...(options || {}),
|
|
||||||
};
|
|
||||||
let r0 = uniformDistribution(chain.length);
|
|
||||||
function computeDelta(pi0, pi1) {
|
|
||||||
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
|
|
||||||
// same (i.e., there has been no permutation).
|
|
||||||
return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i])));
|
|
||||||
}
|
|
||||||
let iteration = 0;
|
|
||||||
while (true) {
|
|
||||||
iteration++;
|
|
||||||
const r1 = sparseMarkovChainAction(chain, r0);
|
|
||||||
const delta = computeDelta(r0, r1);
|
|
||||||
r0 = r1;
|
|
||||||
if (fullOptions.verbose) {
|
|
||||||
console.log(`[${iteration}] delta = ${delta}`);
|
|
||||||
}
|
|
||||||
if (delta < fullOptions.convergenceThreshold) {
|
|
||||||
if (fullOptions.verbose) {
|
|
||||||
console.log(`[${iteration}] CONVERGED`);
|
|
||||||
}
|
|
||||||
return r0;
|
|
||||||
}
|
|
||||||
if (iteration >= fullOptions.maxIterations) {
|
|
||||||
if (fullOptions.verbose) {
|
|
||||||
console.log(`[${iteration}] FAILED to converge`);
|
|
||||||
}
|
|
||||||
return r0;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
|
|
||||||
// eslint-disable-next-line no-unreachable
|
|
||||||
throw new Error("Unreachable.");
|
|
||||||
}
|
|
||||||
|
|
||||||
function distributionToPagerankResult(
|
function distributionToPagerankResult(
|
||||||
nodeOrder: $ReadOnlyArray<Address>,
|
nodeOrder: $ReadOnlyArray<Address>,
|
||||||
pi: Distribution
|
pi: Distribution
|
||||||
|
|
|
@ -89,3 +89,50 @@ export function sparseMarkovChainAction(
|
||||||
});
|
});
|
||||||
return result;
|
return result;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
export function findStationaryDistribution(
|
||||||
|
chain: SparseMarkovChain,
|
||||||
|
options?: {|
|
||||||
|
+verbose?: boolean,
|
||||||
|
+convergenceThreshold?: number,
|
||||||
|
+maxIterations?: number,
|
||||||
|
|}
|
||||||
|
): Distribution {
|
||||||
|
const fullOptions = {
|
||||||
|
verbose: false,
|
||||||
|
convergenceThreshold: 1e-7,
|
||||||
|
maxIterations: 255,
|
||||||
|
...(options || {}),
|
||||||
|
};
|
||||||
|
let r0 = uniformDistribution(chain.length);
|
||||||
|
function computeDelta(pi0, pi1) {
|
||||||
|
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
|
||||||
|
// same (i.e., there has been no permutation).
|
||||||
|
return Math.max(...pi0.map((x, i) => Math.abs(x - pi1[i])));
|
||||||
|
}
|
||||||
|
let iteration = 0;
|
||||||
|
while (true) {
|
||||||
|
iteration++;
|
||||||
|
const r1 = sparseMarkovChainAction(chain, r0);
|
||||||
|
const delta = computeDelta(r0, r1);
|
||||||
|
r0 = r1;
|
||||||
|
if (fullOptions.verbose) {
|
||||||
|
console.log(`[${iteration}] delta = ${delta}`);
|
||||||
|
}
|
||||||
|
if (delta < fullOptions.convergenceThreshold) {
|
||||||
|
if (fullOptions.verbose) {
|
||||||
|
console.log(`[${iteration}] CONVERGED`);
|
||||||
|
}
|
||||||
|
return r0;
|
||||||
|
}
|
||||||
|
if (iteration >= fullOptions.maxIterations) {
|
||||||
|
if (fullOptions.verbose) {
|
||||||
|
console.log(`[${iteration}] FAILED to converge`);
|
||||||
|
}
|
||||||
|
return r0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
|
||||||
|
// eslint-disable-next-line no-unreachable
|
||||||
|
throw new Error("Unreachable.");
|
||||||
|
}
|
||||||
|
|
|
@ -1,6 +1,8 @@
|
||||||
// @flow
|
// @flow
|
||||||
|
|
||||||
|
import type {Distribution, SparseMarkovChain} from "./markovChain";
|
||||||
import {
|
import {
|
||||||
|
findStationaryDistribution,
|
||||||
sparseMarkovChainAction,
|
sparseMarkovChainAction,
|
||||||
sparseMarkovChainFromTransitionMatrix,
|
sparseMarkovChainFromTransitionMatrix,
|
||||||
uniformDistribution,
|
uniformDistribution,
|
||||||
|
@ -120,3 +122,61 @@ describe("sparseMarkovChainAction", () => {
|
||||||
expect(pi1).toEqual(expected);
|
expect(pi1).toEqual(expected);
|
||||||
});
|
});
|
||||||
});
|
});
|
||||||
|
|
||||||
|
function expectAllClose(
|
||||||
|
actual: Float64Array,
|
||||||
|
expected: Float64Array,
|
||||||
|
epsilon: number = 1e-6
|
||||||
|
): void {
|
||||||
|
expect(actual).toHaveLength(expected.length);
|
||||||
|
for (let i = 0; i < expected.length; i++) {
|
||||||
|
if (Math.abs(actual[i] - expected[i]) >= epsilon) {
|
||||||
|
expect(actual).toEqual(expected); // will fail
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function expectStationary(chain: SparseMarkovChain, pi: Distribution): void {
|
||||||
|
expectAllClose(sparseMarkovChainAction(chain, pi), pi);
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("findStationaryDistribution", () => {
|
||||||
|
it("finds an all-accumulating stationary distribution", () => {
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix([
|
||||||
|
[1, 0, 0],
|
||||||
|
[0.25, 0, 0.75],
|
||||||
|
[0.25, 0.75, 0],
|
||||||
|
]);
|
||||||
|
const pi = findStationaryDistribution(chain);
|
||||||
|
expectStationary(chain, pi);
|
||||||
|
const expected = new Float64Array([1, 0, 0]);
|
||||||
|
expectAllClose(pi, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("finds a non-degenerate stationary distribution", () => {
|
||||||
|
// Node 0 is the "center"; nodes 1 through 4 are "satellites". A
|
||||||
|
// satellite transitions to the center with probability 0.5, or to a
|
||||||
|
// cyclically adjacent satellite with probability 0.25 each. The
|
||||||
|
// center transitions to a uniformly random satellite.
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix([
|
||||||
|
[0, 0.25, 0.25, 0.25, 0.25],
|
||||||
|
[0.5, 0, 0.25, 0, 0.25],
|
||||||
|
[0.5, 0.25, 0, 0.25, 0],
|
||||||
|
[0.5, 0, 0.25, 0, 0.25],
|
||||||
|
[0.5, 0.25, 0, 0.25, 0],
|
||||||
|
]);
|
||||||
|
const pi = findStationaryDistribution(chain);
|
||||||
|
expectStationary(chain, pi);
|
||||||
|
const expected = new Float64Array([1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]);
|
||||||
|
expectAllClose(pi, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("finds the stationary distribution of a periodic chain", () => {
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]);
|
||||||
|
const pi = findStationaryDistribution(chain);
|
||||||
|
expectStationary(chain, pi);
|
||||||
|
const expected = new Float64Array([0.5, 0.5]);
|
||||||
|
expectAllClose(pi, expected);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
Loading…
Reference in New Issue