Allow converting transition matrix to sparse chain (#272)

Summary:
This function is mostly useful for easily describing Markov chains in
test cases.

Test Plan:
Unit tests added. Run `yarn test`.

wchargin-branch: sparseMarkovChainFromTransitionMatrix
This commit is contained in:
William Chargin 2018-05-11 21:22:10 -07:00 committed by GitHub
parent 3bd449d1c3
commit e5472752ac
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 118 additions and 0 deletions

View File

@ -26,3 +26,42 @@ export type SparseMarkovChain = $ReadOnlyArray<{|
+neighbor: Uint32Array, +neighbor: Uint32Array,
+weight: Float64Array, +weight: Float64Array,
|}>; |}>;
export function sparseMarkovChainFromTransitionMatrix(
matrix: $ReadOnlyArray<$ReadOnlyArray<number>>
): SparseMarkovChain {
const n = matrix.length;
matrix.forEach((row, i) => {
if (row.length !== n) {
throw new Error(
`expected rows to have length ${n}, but row ${i} has ${row.length}`
);
}
});
matrix.forEach((row, i) => {
row.forEach((value, j) => {
if (isNaN(value) || !isFinite(value) || value < 0) {
throw new Error(
`expected positive real entries, but [${i}][${j}] is ${value}`
);
}
});
});
matrix.forEach((row, i) => {
const rowsum = row.reduce((a, b) => a + b, 0);
if (Math.abs(rowsum - 1) > 1e-6) {
throw new Error(
`expected rows to sum to 1, but row ${i} sums to ${rowsum}`
);
}
});
return matrix.map((_, j) => {
const column = matrix
.map((row, i) => [i, row[j]])
.filter(([_, p]) => p > 0);
return {
neighbor: new Uint32Array(column.map(([i, _]) => i)),
weight: new Float64Array(column.map(([_, p]) => p)),
};
});
}

View File

@ -0,0 +1,79 @@
// @flow
import {sparseMarkovChainFromTransitionMatrix} from "./markovChain";
describe("sparseMarkovChainFromTransitionMatrix", () => {
it("works for a simple matrix", () => {
const matrix = [[1, 0, 0], [0.25, 0, 0.75], [0.25, 0.75, 0]];
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
const expected = [
{
neighbor: new Uint32Array([0, 1, 2]),
weight: new Float64Array([1, 0.25, 0.25]),
},
{
neighbor: new Uint32Array([2]),
weight: new Float64Array([0.75]),
},
{
neighbor: new Uint32Array([1]),
weight: new Float64Array([0.75]),
},
];
expect(chain).toEqual(expected);
});
it("works for the 1-by-1 identity matrix", () => {
const matrix = [[1]];
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
const expected = [
{
neighbor: new Uint32Array([0]),
weight: new Float64Array([1]),
},
];
expect(chain).toEqual(expected);
});
it("works for the 0-by-0 identity matrix", () => {
const matrix = [];
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
const expected = [];
expect(chain).toEqual(expected);
});
it("rejects a ragged matrix", () => {
const matrix = [[1], [0.5, 0.5]];
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
/length/
);
});
it("rejects a matrix with negative entries", () => {
const matrix = [[1, 0, 0], [-0.5, 0.75, 0.75], [0, 0, 1]];
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
/positive real.*-0.5/
);
});
it("rejects a matrix with NaN entries", () => {
const matrix = [[NaN]];
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
/positive real.*NaN/
);
});
it("rejects a matrix with infinite entries", () => {
const matrix = [[Infinity]];
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
/positive real.*Infinity/
);
});
it("rejects a non-stochastic matrix", () => {
const matrix = [[1, 0], [0.125, 0.625]];
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
/sums to 0.75/
);
});
});