diff --git a/src/v3/core/attribution/markovChain.js b/src/v3/core/attribution/markovChain.js index f87e900..2f1c24f 100644 --- a/src/v3/core/attribution/markovChain.js +++ b/src/v3/core/attribution/markovChain.js @@ -92,18 +92,12 @@ export function sparseMarkovChainAction( export function findStationaryDistribution( chain: SparseMarkovChain, - options?: {| - +verbose?: boolean, - +convergenceThreshold?: number, - +maxIterations?: number, + options: {| + +verbose: boolean, + +convergenceThreshold: number, + +maxIterations: number, |} ): Distribution { - const fullOptions = { - verbose: false, - convergenceThreshold: 1e-7, - maxIterations: 255, - ...(options || {}), - }; let r0 = uniformDistribution(chain.length); function computeDelta(pi0, pi1) { let maxDelta = -Infinity; @@ -117,8 +111,8 @@ export function findStationaryDistribution( } let iteration = 0; while (true) { - if (iteration >= fullOptions.maxIterations) { - if (fullOptions.verbose) { + if (iteration >= options.maxIterations) { + if (options.verbose) { console.log(`[${iteration}] FAILED to converge`); } return r0; @@ -127,11 +121,11 @@ export function findStationaryDistribution( const r1 = sparseMarkovChainAction(chain, r0); const delta = computeDelta(r0, r1); r0 = r1; - if (fullOptions.verbose) { + if (options.verbose) { console.log(`[${iteration}] delta = ${delta}`); } - if (delta < fullOptions.convergenceThreshold) { - if (fullOptions.verbose) { + if (delta < options.convergenceThreshold) { + if (options.verbose) { console.log(`[${iteration}] CONVERGED`); } return r0; diff --git a/src/v3/core/attribution/markovChain.test.js b/src/v3/core/attribution/markovChain.test.js index 500742c..679be43 100644 --- a/src/v3/core/attribution/markovChain.test.js +++ b/src/v3/core/attribution/markovChain.test.js @@ -151,7 +151,11 @@ describe("core/attribution/markovChain", () => { [0.25, 0, 0.75], [0.25, 0.75, 0], ]); - const pi = findStationaryDistribution(chain); + const pi = findStationaryDistribution(chain, { + maxIterations: 255, + convergenceThreshold: 1e-7, + verbose: false, + }); expectStationary(chain, pi); const expected = new Float64Array([1, 0, 0]); expectAllClose(pi, expected); @@ -169,7 +173,11 @@ describe("core/attribution/markovChain", () => { [0.5, 0, 0.25, 0, 0.25], [0.5, 0.25, 0, 0.25, 0], ]); - const pi = findStationaryDistribution(chain); + const pi = findStationaryDistribution(chain, { + maxIterations: 255, + convergenceThreshold: 1e-7, + verbose: false, + }); expectStationary(chain, pi); const expected = new Float64Array([1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]); expectAllClose(pi, expected); @@ -177,7 +185,11 @@ describe("core/attribution/markovChain", () => { it("finds the stationary distribution of a periodic chain", () => { const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]); - const pi = findStationaryDistribution(chain); + const pi = findStationaryDistribution(chain, { + maxIterations: 255, + convergenceThreshold: 1e-7, + verbose: false, + }); expectStationary(chain, pi); const expected = new Float64Array([0.5, 0.5]); expectAllClose(pi, expected); @@ -185,7 +197,11 @@ describe("core/attribution/markovChain", () => { it("returns initial distribution if maxIterations===0", () => { const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [0, 1]]); - const pi = findStationaryDistribution(chain, {maxIterations: 0}); + const pi = findStationaryDistribution(chain, { + verbose: false, + convergenceThreshold: 1e-7, + maxIterations: 0, + }); const expected = new Float64Array([0.5, 0.5]); expect(pi).toEqual(expected); });