Require all findStationaryDistribution options (#453)

I'm planning to make a `pagerank.js` module that is a clean entry point
for all the graph-pagerank-related code, so it will be cleaner to expose all
the default options there.

Test plan: travis

Paired with @wchargin
This commit is contained in:
Dandelion Mané 2018-06-29 14:04:15 -07:00 committed by GitHub
parent a5608dd7c8
commit 5c93085430
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 29 additions and 19 deletions

View File

@ -92,18 +92,12 @@ export function sparseMarkovChainAction(
export function findStationaryDistribution(
chain: SparseMarkovChain,
options?: {|
+verbose?: boolean,
+convergenceThreshold?: number,
+maxIterations?: number,
options: {|
+verbose: boolean,
+convergenceThreshold: number,
+maxIterations: number,
|}
): Distribution {
const fullOptions = {
verbose: false,
convergenceThreshold: 1e-7,
maxIterations: 255,
...(options || {}),
};
let r0 = uniformDistribution(chain.length);
function computeDelta(pi0, pi1) {
let maxDelta = -Infinity;
@ -117,8 +111,8 @@ export function findStationaryDistribution(
}
let iteration = 0;
while (true) {
if (iteration >= fullOptions.maxIterations) {
if (fullOptions.verbose) {
if (iteration >= options.maxIterations) {
if (options.verbose) {
console.log(`[${iteration}] FAILED to converge`);
}
return r0;
@ -127,11 +121,11 @@ export function findStationaryDistribution(
const r1 = sparseMarkovChainAction(chain, r0);
const delta = computeDelta(r0, r1);
r0 = r1;
if (fullOptions.verbose) {
if (options.verbose) {
console.log(`[${iteration}] delta = ${delta}`);
}
if (delta < fullOptions.convergenceThreshold) {
if (fullOptions.verbose) {
if (delta < options.convergenceThreshold) {
if (options.verbose) {
console.log(`[${iteration}] CONVERGED`);
}
return r0;

View File

@ -151,7 +151,11 @@ describe("core/attribution/markovChain", () => {
[0.25, 0, 0.75],
[0.25, 0.75, 0],
]);
const pi = findStationaryDistribution(chain);
const pi = findStationaryDistribution(chain, {
maxIterations: 255,
convergenceThreshold: 1e-7,
verbose: false,
});
expectStationary(chain, pi);
const expected = new Float64Array([1, 0, 0]);
expectAllClose(pi, expected);
@ -169,7 +173,11 @@ describe("core/attribution/markovChain", () => {
[0.5, 0, 0.25, 0, 0.25],
[0.5, 0.25, 0, 0.25, 0],
]);
const pi = findStationaryDistribution(chain);
const pi = findStationaryDistribution(chain, {
maxIterations: 255,
convergenceThreshold: 1e-7,
verbose: false,
});
expectStationary(chain, pi);
const expected = new Float64Array([1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]);
expectAllClose(pi, expected);
@ -177,7 +185,11 @@ describe("core/attribution/markovChain", () => {
it("finds the stationary distribution of a periodic chain", () => {
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]);
const pi = findStationaryDistribution(chain);
const pi = findStationaryDistribution(chain, {
maxIterations: 255,
convergenceThreshold: 1e-7,
verbose: false,
});
expectStationary(chain, pi);
const expected = new Float64Array([0.5, 0.5]);
expectAllClose(pi, expected);
@ -185,7 +197,11 @@ describe("core/attribution/markovChain", () => {
it("returns initial distribution if maxIterations===0", () => {
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [0, 1]]);
const pi = findStationaryDistribution(chain, {maxIterations: 0});
const pi = findStationaryDistribution(chain, {
verbose: false,
convergenceThreshold: 1e-7,
maxIterations: 0,
});
const expected = new Float64Array([0.5, 0.5]);
expect(pi).toEqual(expected);
});