Copy Markov chain code from V1 to V3 (#425)
Summary: This code is independent of the graph abstraction, and so is mostly copied. The only change is to the structure of the test code (we now prefer to wrap everything in a big `describe` block with an absolute path to the module under test). Test Plan: Unit tests included. wchargin-branch: v3-markov-chain
This commit is contained in:
parent
659fc51d9b
commit
faa2f8c9d0
|
@ -0,0 +1,143 @@
|
||||||
|
// @flow
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A distribution over the integers `0` through `n - 1`, where `n` is
|
||||||
|
* the length of the array. The value at index `i` is the probability of
|
||||||
|
* `i` in the distribution. The values should sum to 1.
|
||||||
|
*/
|
||||||
|
export type Distribution = Float64Array;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A representation of a sparse transition matrix that is convenient for
|
||||||
|
* computations on Markov chains.
|
||||||
|
*
|
||||||
|
* A Markov chain has nodes indexed from `0` to `n - 1`, where `n` is
|
||||||
|
* the length of the chain. The elements of the chain represent the
|
||||||
|
* incoming edges to each node. Specifically, for each node `v`, the
|
||||||
|
* in-degree of `v` equals the length of both `chain[v].neighbor` and
|
||||||
|
* `chain[v].weight`. For each `i` from `0` to the degree of `v`
|
||||||
|
* (exclusive), there is an edge to `v` from `chain[v].neighbor[i]` with
|
||||||
|
* weight `chain[v].weight[i]`.
|
||||||
|
*
|
||||||
|
* In other words, `chain[v]` is a sparse-vector representation of
|
||||||
|
* column `v` of the transition matrix of the Markov chain.
|
||||||
|
*/
|
||||||
|
export type SparseMarkovChain = $ReadOnlyArray<{|
|
||||||
|
+neighbor: Uint32Array,
|
||||||
|
+weight: Float64Array,
|
||||||
|
|}>;
|
||||||
|
|
||||||
|
export function sparseMarkovChainFromTransitionMatrix(
|
||||||
|
matrix: $ReadOnlyArray<$ReadOnlyArray<number>>
|
||||||
|
): SparseMarkovChain {
|
||||||
|
const n = matrix.length;
|
||||||
|
matrix.forEach((row, i) => {
|
||||||
|
if (row.length !== n) {
|
||||||
|
throw new Error(
|
||||||
|
`expected rows to have length ${n}, but row ${i} has ${row.length}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
matrix.forEach((row, i) => {
|
||||||
|
row.forEach((value, j) => {
|
||||||
|
if (isNaN(value) || !isFinite(value) || value < 0) {
|
||||||
|
throw new Error(
|
||||||
|
`expected positive real entries, but [${i}][${j}] is ${value}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
});
|
||||||
|
matrix.forEach((row, i) => {
|
||||||
|
const rowsum = row.reduce((a, b) => a + b, 0);
|
||||||
|
if (Math.abs(rowsum - 1) > 1e-6) {
|
||||||
|
throw new Error(
|
||||||
|
`expected rows to sum to 1, but row ${i} sums to ${rowsum}`
|
||||||
|
);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
return matrix.map((_, j) => {
|
||||||
|
const column = matrix
|
||||||
|
.map((row, i) => [i, row[j]])
|
||||||
|
.filter(([_, p]) => p > 0);
|
||||||
|
return {
|
||||||
|
neighbor: new Uint32Array(column.map(([i, _]) => i)),
|
||||||
|
weight: new Float64Array(column.map(([_, p]) => p)),
|
||||||
|
};
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
export function uniformDistribution(n: number): Distribution {
|
||||||
|
if (isNaN(n) || !isFinite(n) || n !== Math.floor(n) || n <= 0) {
|
||||||
|
throw new Error("expected positive integer, but got: " + n);
|
||||||
|
}
|
||||||
|
return new Float64Array(n).fill(1 / n);
|
||||||
|
}
|
||||||
|
|
||||||
|
export function sparseMarkovChainAction(
|
||||||
|
chain: SparseMarkovChain,
|
||||||
|
pi: Distribution
|
||||||
|
): Distribution {
|
||||||
|
const result = new Float64Array(pi.length);
|
||||||
|
chain.forEach(({neighbor, weight}, dst) => {
|
||||||
|
const inDegree = neighbor.length; // (also `weight.length`)
|
||||||
|
let probability = 0;
|
||||||
|
for (let i = 0; i < inDegree; i++) {
|
||||||
|
const src = neighbor[i];
|
||||||
|
probability += pi[src] * weight[i];
|
||||||
|
}
|
||||||
|
result[dst] = probability;
|
||||||
|
});
|
||||||
|
return result;
|
||||||
|
}
|
||||||
|
|
||||||
|
export function findStationaryDistribution(
|
||||||
|
chain: SparseMarkovChain,
|
||||||
|
options?: {|
|
||||||
|
+verbose?: boolean,
|
||||||
|
+convergenceThreshold?: number,
|
||||||
|
+maxIterations?: number,
|
||||||
|
|}
|
||||||
|
): Distribution {
|
||||||
|
const fullOptions = {
|
||||||
|
verbose: false,
|
||||||
|
convergenceThreshold: 1e-7,
|
||||||
|
maxIterations: 255,
|
||||||
|
...(options || {}),
|
||||||
|
};
|
||||||
|
let r0 = uniformDistribution(chain.length);
|
||||||
|
function computeDelta(pi0, pi1) {
|
||||||
|
let maxDelta = -Infinity;
|
||||||
|
// Here, we assume that `pi0.nodeOrder` and `pi1.nodeOrder` are the
|
||||||
|
// same (i.e., there has been no permutation).
|
||||||
|
pi0.forEach((x, i) => {
|
||||||
|
const delta = Math.abs(x - pi1[i]);
|
||||||
|
maxDelta = Math.max(delta, maxDelta);
|
||||||
|
});
|
||||||
|
return maxDelta;
|
||||||
|
}
|
||||||
|
let iteration = 0;
|
||||||
|
while (true) {
|
||||||
|
iteration++;
|
||||||
|
const r1 = sparseMarkovChainAction(chain, r0);
|
||||||
|
const delta = computeDelta(r0, r1);
|
||||||
|
r0 = r1;
|
||||||
|
if (fullOptions.verbose) {
|
||||||
|
console.log(`[${iteration}] delta = ${delta}`);
|
||||||
|
}
|
||||||
|
if (delta < fullOptions.convergenceThreshold) {
|
||||||
|
if (fullOptions.verbose) {
|
||||||
|
console.log(`[${iteration}] CONVERGED`);
|
||||||
|
}
|
||||||
|
return r0;
|
||||||
|
}
|
||||||
|
if (iteration >= fullOptions.maxIterations) {
|
||||||
|
if (fullOptions.verbose) {
|
||||||
|
console.log(`[${iteration}] FAILED to converge`);
|
||||||
|
}
|
||||||
|
return r0;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// ESLint knows that this next line is unreachable, but Flow doesn't. :-)
|
||||||
|
// eslint-disable-next-line no-unreachable
|
||||||
|
throw new Error("Unreachable.");
|
||||||
|
}
|
|
@ -0,0 +1,186 @@
|
||||||
|
// @flow
|
||||||
|
|
||||||
|
import type {Distribution, SparseMarkovChain} from "./markovChain";
|
||||||
|
import {
|
||||||
|
findStationaryDistribution,
|
||||||
|
sparseMarkovChainAction,
|
||||||
|
sparseMarkovChainFromTransitionMatrix,
|
||||||
|
uniformDistribution,
|
||||||
|
} from "./markovChain";
|
||||||
|
|
||||||
|
describe("core/attribution/markovChain", () => {
|
||||||
|
describe("sparseMarkovChainFromTransitionMatrix", () => {
|
||||||
|
it("works for a simple matrix", () => {
|
||||||
|
const matrix = [[1, 0, 0], [0.25, 0, 0.75], [0.25, 0.75, 0]];
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
|
||||||
|
const expected = [
|
||||||
|
{
|
||||||
|
neighbor: new Uint32Array([0, 1, 2]),
|
||||||
|
weight: new Float64Array([1, 0.25, 0.25]),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
neighbor: new Uint32Array([2]),
|
||||||
|
weight: new Float64Array([0.75]),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
neighbor: new Uint32Array([1]),
|
||||||
|
weight: new Float64Array([0.75]),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
expect(chain).toEqual(expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("works for the 1-by-1 identity matrix", () => {
|
||||||
|
const matrix = [[1]];
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
|
||||||
|
const expected = [
|
||||||
|
{
|
||||||
|
neighbor: new Uint32Array([0]),
|
||||||
|
weight: new Float64Array([1]),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
expect(chain).toEqual(expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("works for the 0-by-0 identity matrix", () => {
|
||||||
|
const matrix = [];
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
|
||||||
|
const expected = [];
|
||||||
|
expect(chain).toEqual(expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("rejects a ragged matrix", () => {
|
||||||
|
const matrix = [[1], [0.5, 0.5]];
|
||||||
|
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||||
|
/length/
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("rejects a matrix with negative entries", () => {
|
||||||
|
const matrix = [[1, 0, 0], [-0.5, 0.75, 0.75], [0, 0, 1]];
|
||||||
|
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||||
|
/positive real.*-0.5/
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("rejects a matrix with NaN entries", () => {
|
||||||
|
const matrix = [[NaN]];
|
||||||
|
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||||
|
/positive real.*NaN/
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("rejects a matrix with infinite entries", () => {
|
||||||
|
const matrix = [[Infinity]];
|
||||||
|
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||||
|
/positive real.*Infinity/
|
||||||
|
);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("rejects a non-stochastic matrix", () => {
|
||||||
|
const matrix = [[1, 0], [0.125, 0.625]];
|
||||||
|
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||||
|
/sums to 0.75/
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("uniformDistribution", () => {
|
||||||
|
it("computes the uniform distribution with domain of size 1", () => {
|
||||||
|
const pi = uniformDistribution(1);
|
||||||
|
expect(pi).toEqual(new Float64Array([1]));
|
||||||
|
});
|
||||||
|
it("computes the uniform distribution with domain of size 4", () => {
|
||||||
|
const pi = uniformDistribution(4);
|
||||||
|
expect(pi).toEqual(new Float64Array([0.25, 0.25, 0.25, 0.25]));
|
||||||
|
});
|
||||||
|
[0, -1, Infinity, NaN, 3.5, '"beluga"', null, undefined].forEach((bad) => {
|
||||||
|
it(`fails when given domain ${String(bad)}`, () => {
|
||||||
|
expect(() => uniformDistribution((bad: any))).toThrow(
|
||||||
|
"positive integer"
|
||||||
|
);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
describe("sparseMarkovChainAction", () => {
|
||||||
|
it("acts properly on a nontrivial chain", () => {
|
||||||
|
// Note: this test case uses only real numbers that are exactly
|
||||||
|
// representable as floating point numbers.
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix([
|
||||||
|
[1, 0, 0],
|
||||||
|
[0.25, 0, 0.75],
|
||||||
|
[0.25, 0.75, 0],
|
||||||
|
]);
|
||||||
|
const pi0 = new Float64Array([0.125, 0.375, 0.625]);
|
||||||
|
const pi1 = sparseMarkovChainAction(chain, pi0);
|
||||||
|
// The expected value is given by `pi0 * A`, where `A` is the
|
||||||
|
// transition matrix. In Octave:
|
||||||
|
// >> A = [ 1 0 0; 0.25 0 0.75 ; 0.25 0.75 0 ];
|
||||||
|
// >> pi0 = [ 0.125 0.375 0.625 ];
|
||||||
|
// >> pi1 = pi0 * A;
|
||||||
|
// >> disp(pi1)
|
||||||
|
// 0.37500 0.46875 0.28125
|
||||||
|
const expected = new Float64Array([0.375, 0.46875, 0.28125]);
|
||||||
|
expect(pi1).toEqual(expected);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
|
||||||
|
function expectAllClose(
|
||||||
|
actual: Float64Array,
|
||||||
|
expected: Float64Array,
|
||||||
|
epsilon: number = 1e-6
|
||||||
|
): void {
|
||||||
|
expect(actual).toHaveLength(expected.length);
|
||||||
|
for (let i = 0; i < expected.length; i++) {
|
||||||
|
if (Math.abs(actual[i] - expected[i]) >= epsilon) {
|
||||||
|
expect(actual).toEqual(expected); // will fail
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
function expectStationary(chain: SparseMarkovChain, pi: Distribution): void {
|
||||||
|
expectAllClose(sparseMarkovChainAction(chain, pi), pi);
|
||||||
|
}
|
||||||
|
|
||||||
|
describe("findStationaryDistribution", () => {
|
||||||
|
it("finds an all-accumulating stationary distribution", () => {
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix([
|
||||||
|
[1, 0, 0],
|
||||||
|
[0.25, 0, 0.75],
|
||||||
|
[0.25, 0.75, 0],
|
||||||
|
]);
|
||||||
|
const pi = findStationaryDistribution(chain);
|
||||||
|
expectStationary(chain, pi);
|
||||||
|
const expected = new Float64Array([1, 0, 0]);
|
||||||
|
expectAllClose(pi, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("finds a non-degenerate stationary distribution", () => {
|
||||||
|
// Node 0 is the "center"; nodes 1 through 4 are "satellites". A
|
||||||
|
// satellite transitions to the center with probability 0.5, or to a
|
||||||
|
// cyclically adjacent satellite with probability 0.25 each. The
|
||||||
|
// center transitions to a uniformly random satellite.
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix([
|
||||||
|
[0, 0.25, 0.25, 0.25, 0.25],
|
||||||
|
[0.5, 0, 0.25, 0, 0.25],
|
||||||
|
[0.5, 0.25, 0, 0.25, 0],
|
||||||
|
[0.5, 0, 0.25, 0, 0.25],
|
||||||
|
[0.5, 0.25, 0, 0.25, 0],
|
||||||
|
]);
|
||||||
|
const pi = findStationaryDistribution(chain);
|
||||||
|
expectStationary(chain, pi);
|
||||||
|
const expected = new Float64Array([1 / 3, 1 / 6, 1 / 6, 1 / 6, 1 / 6]);
|
||||||
|
expectAllClose(pi, expected);
|
||||||
|
});
|
||||||
|
|
||||||
|
it("finds the stationary distribution of a periodic chain", () => {
|
||||||
|
const chain = sparseMarkovChainFromTransitionMatrix([[0, 1], [1, 0]]);
|
||||||
|
const pi = findStationaryDistribution(chain);
|
||||||
|
expectStationary(chain, pi);
|
||||||
|
const expected = new Float64Array([0.5, 0.5]);
|
||||||
|
expectAllClose(pi, expected);
|
||||||
|
});
|
||||||
|
});
|
||||||
|
});
|
Loading…
Reference in New Issue