diff --git a/src/core/pagerankGraph.js b/src/core/pagerankGraph.js index c50815c..beaa831 100644 --- a/src/core/pagerankGraph.js +++ b/src/core/pagerankGraph.js @@ -24,7 +24,7 @@ import { import { findStationaryDistribution, type PagerankParams, - type PagerankOptions, + type PagerankOptions as CorePagerankOptions, } from "../core/attribution/markovChain"; import * as NullUtil from "../util/null"; @@ -70,12 +70,14 @@ export opaque type PagerankGraphJSON = Compatible<{| /** * Options to control how PageRank runs and when it stops */ -export type PagerankConvergenceOptions = {| +export type PagerankOptions = {| // Maximum number of iterations before we give up on PageRank Convergence - +maxIterations: number, + // Defaults to DEFAULT_MAX_ITERATIONS if not provided. + +maxIterations?: number, // PageRank will stop running once the diff between the previous iteration - // and the latest is less than this threshold - +convergenceThreshold: number, + // and the latest is less than this threshold. + // Defaults to DEFAULT_CONVERGENCE_THRESHOLD if not provided. + +convergenceThreshold?: number, |}; export type PagerankConvergenceReport = {| @@ -90,6 +92,13 @@ export const DEFAULT_SYNTHETIC_LOOP_WEIGHT = 1e-3; export const DEFAULT_MAX_ITERATIONS = 255; export const DEFAULT_CONVERGENCE_THRESHOLD = 1e-7; +function defaultOptions(): PagerankOptions { + return { + maxIterations: DEFAULT_MAX_ITERATIONS, + convergenceThreshold: DEFAULT_CONVERGENCE_THRESHOLD, + }; +} + const COMPAT_INFO = {type: "sourcecred/pagerankGraph", version: "0.1.0"}; /** @@ -414,9 +423,13 @@ export class PagerankGraph { * scratch every time `runPagerank` is called. */ async runPagerank( - options: PagerankConvergenceOptions + options?: PagerankOptions ): Promise { this._verifyGraphNotModified(); + const fullOptions = { + ...defaultOptions(), + ...(options || {}), + }; const edgeEvaluator = (x: Edge) => NullUtil.get(this._edgeWeights.get(x.address)); const connections = createConnections( @@ -426,10 +439,10 @@ export class PagerankGraph { ); const osmc = createOrderedSparseMarkovChain(connections); const params: PagerankParams = {chain: osmc.chain}; - const coreOptions: PagerankOptions = { + const coreOptions: CorePagerankOptions = { verbose: false, - convergenceThreshold: options.convergenceThreshold, - maxIterations: options.maxIterations, + convergenceThreshold: fullOptions.convergenceThreshold, + maxIterations: fullOptions.maxIterations, yieldAfterMs: 30, }; const distributionResult = await findStationaryDistribution( diff --git a/src/core/pagerankGraph.test.js b/src/core/pagerankGraph.test.js index 7bba87d..9f76e97 100644 --- a/src/core/pagerankGraph.test.js +++ b/src/core/pagerankGraph.test.js @@ -9,7 +9,12 @@ import { type Edge, type EdgesOptions, } from "./graph"; -import {PagerankGraph, Direction} from "./pagerankGraph"; +import { + PagerankGraph, + Direction, + DEFAULT_MAX_ITERATIONS, + DEFAULT_CONVERGENCE_THRESHOLD, +} from "./pagerankGraph"; import {advancedGraph} from "./graphTestUtil"; import * as NullUtil from "../util/null"; @@ -486,6 +491,20 @@ describe("core/pagerankGraph", () => { expect(total).toBeCloseTo(1); } + it("runs PageRank with default options if not specified", () => { + const pg1 = examplePagerankGraph(); + const pg2 = examplePagerankGraph(); + const pg3 = examplePagerankGraph(); + pg1.runPagerank(); + pg2.runPagerank({}); + pg3.runPagerank({ + maxIterations: DEFAULT_MAX_ITERATIONS, + convergenceThreshold: DEFAULT_CONVERGENCE_THRESHOLD, + }); + expect(pg1.equals(pg2)).toBe(true); + expect(pg1.equals(pg3)).toBe(true); + }); + it("promise rejects if the graph was modified", async () => { const pg = examplePagerankGraph(); pg.graph().addNode(NodeAddress.empty);