Allow converting transition matrix to sparse chain (#272)
Summary: This function is mostly useful for easily describing Markov chains in test cases. Test Plan: Unit tests added. Run `yarn test`. wchargin-branch: sparseMarkovChainFromTransitionMatrix
This commit is contained in:
parent
3bd449d1c3
commit
e5472752ac
|
@ -26,3 +26,42 @@ export type SparseMarkovChain = $ReadOnlyArray<{|
|
|||
+neighbor: Uint32Array,
|
||||
+weight: Float64Array,
|
||||
|}>;
|
||||
|
||||
export function sparseMarkovChainFromTransitionMatrix(
|
||||
matrix: $ReadOnlyArray<$ReadOnlyArray<number>>
|
||||
): SparseMarkovChain {
|
||||
const n = matrix.length;
|
||||
matrix.forEach((row, i) => {
|
||||
if (row.length !== n) {
|
||||
throw new Error(
|
||||
`expected rows to have length ${n}, but row ${i} has ${row.length}`
|
||||
);
|
||||
}
|
||||
});
|
||||
matrix.forEach((row, i) => {
|
||||
row.forEach((value, j) => {
|
||||
if (isNaN(value) || !isFinite(value) || value < 0) {
|
||||
throw new Error(
|
||||
`expected positive real entries, but [${i}][${j}] is ${value}`
|
||||
);
|
||||
}
|
||||
});
|
||||
});
|
||||
matrix.forEach((row, i) => {
|
||||
const rowsum = row.reduce((a, b) => a + b, 0);
|
||||
if (Math.abs(rowsum - 1) > 1e-6) {
|
||||
throw new Error(
|
||||
`expected rows to sum to 1, but row ${i} sums to ${rowsum}`
|
||||
);
|
||||
}
|
||||
});
|
||||
return matrix.map((_, j) => {
|
||||
const column = matrix
|
||||
.map((row, i) => [i, row[j]])
|
||||
.filter(([_, p]) => p > 0);
|
||||
return {
|
||||
neighbor: new Uint32Array(column.map(([i, _]) => i)),
|
||||
weight: new Float64Array(column.map(([_, p]) => p)),
|
||||
};
|
||||
});
|
||||
}
|
||||
|
|
|
@ -0,0 +1,79 @@
|
|||
// @flow
|
||||
|
||||
import {sparseMarkovChainFromTransitionMatrix} from "./markovChain";
|
||||
|
||||
describe("sparseMarkovChainFromTransitionMatrix", () => {
|
||||
it("works for a simple matrix", () => {
|
||||
const matrix = [[1, 0, 0], [0.25, 0, 0.75], [0.25, 0.75, 0]];
|
||||
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
|
||||
const expected = [
|
||||
{
|
||||
neighbor: new Uint32Array([0, 1, 2]),
|
||||
weight: new Float64Array([1, 0.25, 0.25]),
|
||||
},
|
||||
{
|
||||
neighbor: new Uint32Array([2]),
|
||||
weight: new Float64Array([0.75]),
|
||||
},
|
||||
{
|
||||
neighbor: new Uint32Array([1]),
|
||||
weight: new Float64Array([0.75]),
|
||||
},
|
||||
];
|
||||
expect(chain).toEqual(expected);
|
||||
});
|
||||
|
||||
it("works for the 1-by-1 identity matrix", () => {
|
||||
const matrix = [[1]];
|
||||
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
|
||||
const expected = [
|
||||
{
|
||||
neighbor: new Uint32Array([0]),
|
||||
weight: new Float64Array([1]),
|
||||
},
|
||||
];
|
||||
expect(chain).toEqual(expected);
|
||||
});
|
||||
|
||||
it("works for the 0-by-0 identity matrix", () => {
|
||||
const matrix = [];
|
||||
const chain = sparseMarkovChainFromTransitionMatrix(matrix);
|
||||
const expected = [];
|
||||
expect(chain).toEqual(expected);
|
||||
});
|
||||
|
||||
it("rejects a ragged matrix", () => {
|
||||
const matrix = [[1], [0.5, 0.5]];
|
||||
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||
/length/
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects a matrix with negative entries", () => {
|
||||
const matrix = [[1, 0, 0], [-0.5, 0.75, 0.75], [0, 0, 1]];
|
||||
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||
/positive real.*-0.5/
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects a matrix with NaN entries", () => {
|
||||
const matrix = [[NaN]];
|
||||
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||
/positive real.*NaN/
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects a matrix with infinite entries", () => {
|
||||
const matrix = [[Infinity]];
|
||||
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||
/positive real.*Infinity/
|
||||
);
|
||||
});
|
||||
|
||||
it("rejects a non-stochastic matrix", () => {
|
||||
const matrix = [[1, 0], [0.125, 0.625]];
|
||||
expect(() => sparseMarkovChainFromTransitionMatrix(matrix)).toThrow(
|
||||
/sums to 0.75/
|
||||
);
|
||||
});
|
||||
});
|
Loading…
Reference in New Issue