refactor args to findStationaryDistribution (#1130)

In [#1128: Add support for seed vectors][#1128], we significantly
increase the number of arguments to
markovChain.findStationaryDistribution. To clean up the invocations, I
added a follow-on PR (#1129) which converts findStationaryDistribution to
use a `PagerankParams` object instead.

However, I think it will be cleaner to land the PagerankParams refactor
before adding new features in #1128, so I'm making this PR as
pre-cleanup.

Test plan: This is a trivial refactor. `yarn test` passes.

[#1128]: https://github.com/sourcecred/sourcecred/pull/1128
This commit is contained in:
Dandelion Mané 2019-04-21 14:00:30 +03:00 committed by GitHub
parent 6dd58a9c67
commit a8a3f4fc3a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 76 additions and 22 deletions

View File

@ -19,7 +19,11 @@ import {
import {scoreByConstantTotal} from "./nodeScore"; import {scoreByConstantTotal} from "./nodeScore";
import {findStationaryDistribution} from "../core/attribution/markovChain"; import {
findStationaryDistribution,
type PagerankParams,
type PagerankOptions as CorePagerankOptions,
} from "../core/attribution/markovChain";
export type {NodeDistribution} from "../core/attribution/graphToMarkovChain"; export type {NodeDistribution} from "../core/attribution/graphToMarkovChain";
export type {PagerankNodeDecomposition} from "./pagerankNodeDecomposition"; export type {PagerankNodeDecomposition} from "./pagerankNodeDecomposition";
@ -63,12 +67,17 @@ export async function pagerank(
fullOptions.selfLoopWeight fullOptions.selfLoopWeight
); );
const osmc = createOrderedSparseMarkovChain(connections); const osmc = createOrderedSparseMarkovChain(connections);
const distributionResult = await findStationaryDistribution(osmc.chain, { const params: PagerankParams = {chain: osmc.chain};
const coreOptions: CorePagerankOptions = {
verbose: fullOptions.verbose, verbose: fullOptions.verbose,
convergenceThreshold: fullOptions.convergenceThreshold, convergenceThreshold: fullOptions.convergenceThreshold,
maxIterations: fullOptions.maxIterations, maxIterations: fullOptions.maxIterations,
yieldAfterMs: 30, yieldAfterMs: 30,
}); };
const distributionResult = await findStationaryDistribution(
params,
coreOptions
);
const pi = distributionToNodeDistribution( const pi = distributionToNodeDistribution(
osmc.nodeOrder, osmc.nodeOrder,
distributionResult.pi distributionResult.pi

View File

@ -6,7 +6,10 @@ import {
createConnections, createConnections,
createOrderedSparseMarkovChain, createOrderedSparseMarkovChain,
} from "../core/attribution/graphToMarkovChain"; } from "../core/attribution/graphToMarkovChain";
import {findStationaryDistribution} from "../core/attribution/markovChain"; import {
findStationaryDistribution,
type PagerankParams,
} from "../core/attribution/markovChain";
import { import {
decompose, decompose,
type PagerankNodeDecomposition, type PagerankNodeDecomposition,
@ -131,7 +134,8 @@ describe("analysis/pagerankNodeDecomposition", () => {
const edgeWeight = () => ({toWeight: 6.0, froWeight: 3.0}); const edgeWeight = () => ({toWeight: 6.0, froWeight: 3.0});
const connections = createConnections(g, edgeWeight, 1.0); const connections = createConnections(g, edgeWeight, 1.0);
const osmc = createOrderedSparseMarkovChain(connections); const osmc = createOrderedSparseMarkovChain(connections);
const distributionResult = await findStationaryDistribution(osmc.chain, { const params: PagerankParams = {chain: osmc.chain};
const distributionResult = await findStationaryDistribution(params, {
verbose: false, verbose: false,
convergenceThreshold: 1e-6, convergenceThreshold: 1e-6,
maxIterations: 255, maxIterations: 255,
@ -151,7 +155,8 @@ describe("analysis/pagerankNodeDecomposition", () => {
const edgeWeight = () => ({toWeight: 6.0, froWeight: 3.0}); const edgeWeight = () => ({toWeight: 6.0, froWeight: 3.0});
const connections = createConnections(g, edgeWeight, 1.0); const connections = createConnections(g, edgeWeight, 1.0);
const osmc = createOrderedSparseMarkovChain(connections); const osmc = createOrderedSparseMarkovChain(connections);
const distributionResult = await findStationaryDistribution(osmc.chain, { const params: PagerankParams = {chain: osmc.chain};
const distributionResult = await findStationaryDistribution(params, {
verbose: false, verbose: false,
convergenceThreshold: 1e-6, convergenceThreshold: 1e-6,
maxIterations: 255, maxIterations: 255,

View File

@ -7,6 +7,36 @@
*/ */
export type Distribution = Float64Array; export type Distribution = Float64Array;
/**
 * The data inputs to running PageRank.
 *
 * These are kept separate from PagerankOptions because, within a given
 * context, every call to findStationaryDistribution (or other PageRank
 * functions) is expected to have different PagerankParams but will often
 * share the same PagerankOptions.
 */
export type PagerankParams = {|
// The Markov chain whose stationary distribution is to be found.
+chain: SparseMarkovChain,
|};
/**
 * PagerankOptions allows the user to tweak PageRank's behavior, especially
 * around convergence.
 */
export type PagerankOptions = {|
// Causes runtime information to get logged to console.
+verbose: boolean,
// A distribution is considered stationary if the action of the Markov
// chain on the distribution does not change any component by more than
// `convergenceThreshold` in absolute value.
+convergenceThreshold: number,
// At most `maxIterations` Markov chain steps will be run.
+maxIterations: number,
// To prevent locking the rest of the application, PageRank will yield
// control after this many milliseconds, allowing UI updates, etc.
+yieldAfterMs: number,
|};
export type StationaryDistributionResult = {| export type StationaryDistributionResult = {|
// The final distribution after attempting to find the stationary distribution // The final distribution after attempting to find the stationary distribution
// of the Markov chain. // of the Markov chain.
@ -134,7 +164,7 @@ export function computeDelta(pi0: Distribution, pi1: Distribution) {
} }
function* findStationaryDistributionGenerator( function* findStationaryDistributionGenerator(
chain: SparseMarkovChain, params: PagerankParams,
options: {| options: {|
+verbose: boolean, +verbose: boolean,
// A distribution is considered stationary if the action of the Markov // A distribution is considered stationary if the action of the Markov
@ -145,6 +175,7 @@ function* findStationaryDistributionGenerator(
+maxIterations: number, +maxIterations: number,
|} |}
): Generator<void, StationaryDistributionResult, void> { ): Generator<void, StationaryDistributionResult, void> {
const {chain} = params;
let pi = uniformDistribution(chain.length); let pi = uniformDistribution(chain.length);
let scratch = new Float64Array(pi.length); let scratch = new Float64Array(pi.length);
@ -186,15 +217,10 @@ function* findStationaryDistributionGenerator(
} }
export function findStationaryDistribution( export function findStationaryDistribution(
chain: SparseMarkovChain, params: PagerankParams,
options: {| options: PagerankOptions
+verbose: boolean,
+convergenceThreshold: number,
+maxIterations: number,
+yieldAfterMs: number,
|}
): Promise<StationaryDistributionResult> { ): Promise<StationaryDistributionResult> {
let gen = findStationaryDistributionGenerator(chain, { let gen = findStationaryDistributionGenerator(params, {
verbose: options.verbose, verbose: options.verbose,
convergenceThreshold: options.convergenceThreshold, convergenceThreshold: options.convergenceThreshold,
maxIterations: options.maxIterations, maxIterations: options.maxIterations,

View File

@ -8,6 +8,7 @@ import {
uniformDistribution, uniformDistribution,
computeDelta, computeDelta,
type StationaryDistributionResult, type StationaryDistributionResult,
type PagerankParams,
} from "./markovChain"; } from "./markovChain";
describe("core/attribution/markovChain", () => { describe("core/attribution/markovChain", () => {
@ -158,7 +159,8 @@ describe("core/attribution/markovChain", () => {
[0.25, 0, 0.75], [0.25, 0, 0.75],
[0.25, 0.75, 0], [0.25, 0.75, 0],
]); ]);
const result = await findStationaryDistribution(chain, { const params: PagerankParams = {chain};
const result = await findStationaryDistribution(params, {
maxIterations: 255, maxIterations: 255,
convergenceThreshold: 1e-7, convergenceThreshold: 1e-7,
verbose: false, verbose: false,
@ -184,7 +186,8 @@ describe("core/attribution/markovChain", () => {
[0.5, 0, 0.25, 0, 0.25], [0.5, 0, 0.25, 0, 0.25],
[0.5, 0.25, 0, 0.25, 0], [0.5, 0.25, 0, 0.25, 0],
]); ]);
const result = await findStationaryDistribution(chain, { const params: PagerankParams = {chain};
const result = await findStationaryDistribution(params, {
maxIterations: 255, maxIterations: 255,
convergenceThreshold: 1e-7, convergenceThreshold: 1e-7,
verbose: false, verbose: false,
@ -201,7 +204,8 @@ describe("core/attribution/markovChain", () => {
it("finds the stationary distribution of a periodic chain", async () => { it("finds the stationary distribution of a periodic chain", async () => {
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]); const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]);
const result = await findStationaryDistribution(chain, { const params: PagerankParams = {chain};
const result = await findStationaryDistribution(params, {
maxIterations: 255, maxIterations: 255,
convergenceThreshold: 1e-7, convergenceThreshold: 1e-7,
verbose: false, verbose: false,
@ -218,7 +222,8 @@ describe("core/attribution/markovChain", () => {
it("returns initial distribution if maxIterations===0", async () => { it("returns initial distribution if maxIterations===0", async () => {
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [0, 1]]); const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [0, 1]]);
const result = await findStationaryDistribution(chain, { const params: PagerankParams = {chain};
const result = await findStationaryDistribution(params, {
verbose: false, verbose: false,
convergenceThreshold: 1e-7, convergenceThreshold: 1e-7,
maxIterations: 0, maxIterations: 0,

View File

@ -21,7 +21,11 @@ import {
createOrderedSparseMarkovChain, createOrderedSparseMarkovChain,
type EdgeWeight, type EdgeWeight,
} from "./attribution/graphToMarkovChain"; } from "./attribution/graphToMarkovChain";
import {findStationaryDistribution} from "../core/attribution/markovChain"; import {
findStationaryDistribution,
type PagerankParams,
type PagerankOptions,
} from "../core/attribution/markovChain";
import * as NullUtil from "../util/null"; import * as NullUtil from "../util/null";
export {Direction} from "./graph"; export {Direction} from "./graph";
@ -421,12 +425,17 @@ export class PagerankGraph {
this._syntheticLoopWeight this._syntheticLoopWeight
); );
const osmc = createOrderedSparseMarkovChain(connections); const osmc = createOrderedSparseMarkovChain(connections);
const distributionResult = await findStationaryDistribution(osmc.chain, { const params: PagerankParams = {chain: osmc.chain};
const coreOptions: PagerankOptions = {
verbose: false, verbose: false,
convergenceThreshold: options.convergenceThreshold, convergenceThreshold: options.convergenceThreshold,
maxIterations: options.maxIterations, maxIterations: options.maxIterations,
yieldAfterMs: 30, yieldAfterMs: 30,
}); };
const distributionResult = await findStationaryDistribution(
params,
coreOptions
);
this._scores = distributionToNodeDistribution( this._scores = distributionToNodeDistribution(
osmc.nodeOrder, osmc.nodeOrder,
distributionResult.pi distributionResult.pi