diff --git a/src/core/attribution/nodeScore.js b/src/core/attribution/nodeScore.js
index ccca24a..1368808 100644
--- a/src/core/attribution/nodeScore.js
+++ b/src/core/attribution/nodeScore.js
@@ -1,10 +1,11 @@
 // @flow
 
-import type {NodeAddressT} from "../graph";
+import {NodeAddress, type NodeAddressT} from "../graph";
 import type {NodeDistribution} from "./graphToMarkovChain";
 
 export type NodeScore = Map<NodeAddressT, number>;
 
+/* Normalize scores so that the maximum score has a fixed value */
 export function scoreByMaximumProbability(
   pi: NodeDistribution,
   maxScore: number
@@ -26,3 +27,44 @@ export function scoreByMaximumProbability(
   }
   return scoreMap;
 }
+
+/**
+ * Normalize scores so that a group of nodes has a fixed total score.
+ *
+ * Every probability in `pi` is multiplied by one common factor, chosen so
+ * that the nodes whose address has `nodeFilter` as a prefix sum to
+ * `totalScore`. Non-matching nodes are rescaled by the same factor, so the
+ * sum over the whole result can differ from `totalScore`.
+ *
+ * Throws if `totalScore` is not positive, or if the matching nodes sum to
+ * zero (most likely because no node matches the prefix).
+ */
+export function scoreByConstantTotal(
+  pi: NodeDistribution,
+  totalScore: number,
+  nodeFilter: NodeAddressT /* Normalizes based on nodes matching this prefix */
+): NodeScore {
+  if (totalScore <= 0) {
+    throw new Error("Invalid argument: totalScore must be > 0");
+  }
+
+  let unnormalizedTotal = 0;
+  for (const [addr, prob] of pi) {
+    if (NodeAddress.hasPrefix(addr, nodeFilter)) {
+      unnormalizedTotal += prob;
+    }
+  }
+  if (unnormalizedTotal === 0) {
+    throw new Error(
+      "Tried to normalize based on nodes with no score. " +
+        "This probably means that there were no nodes matching prefix: " +
+        NodeAddress.toString(nodeFilter)
+    );
+  }
+  const f = totalScore / unnormalizedTotal;
+  const scoreMap = new Map();
+  for (const [addr, prob] of pi) {
+    scoreMap.set(addr, prob * f);
+  }
+  return scoreMap;
+}
diff --git a/src/core/attribution/nodeScore.test.js b/src/core/attribution/nodeScore.test.js
index 3085814..c91b805 100644
--- a/src/core/attribution/nodeScore.test.js
+++ b/src/core/attribution/nodeScore.test.js
@@ -1,47 +1,112 @@
 // @flow
 
 import {NodeAddress} from "../graph";
-import {scoreByMaximumProbability} from "./nodeScore";
+import {scoreByMaximumProbability, scoreByConstantTotal} from "./nodeScore";
 describe("core/attribution/nodeScore", () => {
   const foo = NodeAddress.fromParts(["foo"]);
   const bar = NodeAddress.fromParts(["bar"]);
   const zod = NodeAddress.fromParts(["zod"]);
-  it("works on a simple case", () => {
-    const distribution = new Map();
-    distribution.set(foo, 0.5);
-    distribution.set(bar, 0.3);
-    distribution.set(zod, 0.2);
-    const result = scoreByMaximumProbability(distribution, 100);
-    expect(result.get(foo)).toEqual(100);
-    expect(result.get(bar)).toEqual(60);
-    expect(result.get(zod)).toEqual(40);
+  const foobar = NodeAddress.fromParts(["foo", "bar"]);
+  describe("scoreByMaximumProbability", () => {
+    it("works on a simple case", () => {
+      const distribution = new Map();
+      distribution.set(foo, 0.5);
+      distribution.set(bar, 0.3);
+      distribution.set(zod, 0.2);
+      const result = scoreByMaximumProbability(distribution, 100);
+      expect(result.get(foo)).toEqual(100);
+      expect(result.get(bar)).toEqual(60);
+      expect(result.get(zod)).toEqual(40);
+    });
+    it("normalizes to the maxScore argument", () => {
+      const distribution = new Map();
+      distribution.set(foo, 0.5);
+      distribution.set(bar, 0.3);
+      distribution.set(zod, 0.2);
+      const result = scoreByMaximumProbability(distribution, 1000);
+      expect(result.get(foo)).toEqual(1000);
+      expect(result.get(bar)).toEqual(600);
+      expect(result.get(zod)).toEqual(400);
+    });
+    it("handles a case with only a single node", () => {
+      const distribution = new Map();
+      distribution.set(foo, 1.0);
+      const result = scoreByMaximumProbability(distribution, 1000);
+      expect(result.get(foo)).toEqual(1000);
+    });
+    it("errors if maxScore <= 0", () => {
+      const distribution = new Map();
+      distribution.set(foo, 1.0);
+      const result = () => scoreByMaximumProbability(distribution, 0);
+      expect(result).toThrowError("Invalid argument");
+    });
+    it("throws an error rather than divide by 0", () => {
+      const distribution = new Map();
+      distribution.set(foo, 0.0);
+      const result = () => scoreByMaximumProbability(distribution, 1000);
+      expect(result).toThrowError("Invariant violation");
+    });
   });
-  it("normalizes to the maxScore argument", () => {
-    const distribution = new Map();
-    distribution.set(foo, 0.5);
-    distribution.set(bar, 0.3);
-    distribution.set(zod, 0.2);
-    const result = scoreByMaximumProbability(distribution, 1000);
-    expect(result.get(foo)).toEqual(1000);
-    expect(result.get(bar)).toEqual(600);
-    expect(result.get(zod)).toEqual(400);
-  });
-  it("handles a case with only a single node", () => {
-    const distribution = new Map();
-    distribution.set(foo, 1.0);
-    const result = scoreByMaximumProbability(distribution, 1000);
-    expect(result.get(foo)).toEqual(1000);
-  });
-  it("errors if maxScore <= 0", () => {
-    const distribution = new Map();
-    distribution.set(foo, 1.0);
-    const result = () => scoreByMaximumProbability(distribution, 0);
-    expect(result).toThrowError("Invalid argument");
-  });
-  it("throws an error rather than divide by 0", () => {
-    const distribution = new Map();
-    distribution.set(foo, 0.0);
-    const result = () => scoreByMaximumProbability(distribution, 1000);
-    expect(result).toThrowError("Invariant violation");
+  describe("scoreByConstantTotal", () => {
+    it("works on a simple case", () => {
+      const distribution = new Map();
+      distribution.set(foo, 0.5);
+      distribution.set(bar, 0.3);
+      distribution.set(zod, 0.2);
+      const result = scoreByConstantTotal(distribution, 100, NodeAddress.empty);
+      expect(result.get(foo)).toEqual(50);
+      expect(result.get(bar)).toEqual(30);
+      expect(result.get(zod)).toEqual(20);
+    });
+    it("normalizes based on the totalScore argument", () => {
+      const distribution = new Map();
+      distribution.set(foo, 0.5);
+      distribution.set(bar, 0.3);
+      distribution.set(zod, 0.2);
+      const result = scoreByConstantTotal(
+        distribution,
+        1000,
+        NodeAddress.empty
+      );
+      expect(result.get(foo)).toEqual(500);
+      expect(result.get(bar)).toEqual(300);
+      expect(result.get(zod)).toEqual(200);
+    });
+    it("normalizes based on which nodes match the filter", () => {
+      const distribution = new Map();
+      distribution.set(foo, 0.5);
+      distribution.set(foobar, 0.5);
+      distribution.set(bar, 0.3);
+      distribution.set(zod, 0.2);
+      const result = scoreByConstantTotal(distribution, 1000, foo);
+      expect(result.get(foo)).toEqual(500);
+      expect(result.get(foobar)).toEqual(500);
+      expect(result.get(bar)).toEqual(300);
+      expect(result.get(zod)).toEqual(200);
+    });
+    it("handles a case with only a single node", () => {
+      const distribution = new Map();
+      distribution.set(foo, 1.0);
+      const result = scoreByConstantTotal(
+        distribution,
+        1000,
+        NodeAddress.empty
+      );
+      expect(result.get(foo)).toEqual(1000);
+    });
+    it("errors if totalScore <= 0", () => {
+      const distribution = new Map();
+      distribution.set(foo, 1.0);
+      const result = () => scoreByConstantTotal(distribution, 0, foo);
+      expect(result).toThrowError("Invalid argument");
+    });
+    it("throws an error rather than divide by 0", () => {
+      const distribution = new Map();
+      distribution.set(foo, 1.0);
+      const result = () => scoreByConstantTotal(distribution, 1000, bar);
+      expect(result).toThrowError(
+        "Tried to normalize based on nodes with no score"
+      );
+    });
   });
 });