diff --git a/src/cli/common.js b/src/cli/common.js index 9a6aa9b..0306c28 100644 --- a/src/cli/common.js +++ b/src/cli/common.js @@ -4,6 +4,8 @@ import os from "os"; import path from "path"; import deepFreeze from "deep-freeze"; +import fs from "fs-extra"; +import {type Weights, fromJSON as weightsFromJSON} from "../analysis/weights"; import * as NullUtil from "../util/null"; @@ -27,3 +29,16 @@ export function githubToken(): string | null { export function discourseKey(): string | null { return NullUtil.orElse(process.env.SOURCECRED_DISCOURSE_KEY, null); } + +export async function loadWeights(path: string): Promise { + if (!(await fs.exists(path))) { + throw new Error("Could not find the weights file"); + } + const raw = await fs.readFile(path, "utf-8"); + const weightsJSON = JSON.parse(raw); + try { + return weightsFromJSON(weightsJSON); + } catch (e) { + throw new Error(`provided weights file is invalid:\n${e}`); + } +} diff --git a/src/cli/common.test.js b/src/cli/common.test.js index bd4c335..1ab8c6c 100644 --- a/src/cli/common.test.js +++ b/src/cli/common.test.js @@ -1,6 +1,10 @@ // @flow import path from "path"; +import tmp from "tmp"; +import fs from "fs-extra"; +import {defaultWeights, toJSON as weightsToJSON} from "../analysis/weights"; +import {NodeAddress} from "../core/graph"; import { defaultPlugins, @@ -8,6 +12,7 @@ import { sourcecredDirectory, githubToken, discourseKey, + loadWeights, } from "./common"; describe("cli/common", () => { @@ -66,4 +71,36 @@ describe("cli/common", () => { expect(discourseKey()).toBe(null); }); }); + + describe("loadWeights", () => { + function tmpWithContents(contents: mixed) { + const name = tmp.tmpNameSync(); + fs.writeFileSync(name, JSON.stringify(contents)); + return name; + } + it("works in a simple success case", async () => { + const weights = defaultWeights(); + // Make a modification, just to be sure we aren't always loading the + // default weights. + weights.nodeManualWeights.set(NodeAddress.empty, 3); + const weightsJSON = weightsToJSON(weights); + const file = tmpWithContents(weightsJSON); + const weights_ = await loadWeights(file); + expect(weights).toEqual(weights_); + }); + it("rejects if the file is not a valid weights file", () => { + const file = tmpWithContents(1234); + expect.assertions(1); + return loadWeights(file).catch((e) => + expect(e.message).toMatch("provided weights file is invalid:") + ); + }); + it("rejects if the file does not exist", () => { + const file = tmp.tmpNameSync(); + expect.assertions(1); + return loadWeights(file).catch((e) => + expect(e.message).toMatch("Could not find the weights file") + ); + }); + }); }); diff --git a/src/cli/load.js b/src/cli/load.js index 83455ea..9ebb0cc 100644 --- a/src/cli/load.js +++ b/src/cli/load.js @@ -5,10 +5,9 @@ import dedent from "../util/dedent"; import {LoggingTaskReporter} from "../util/taskReporter"; import type {Command} from "./command"; import * as Common from "./common"; -import {defaultWeights, fromJSON as weightsFromJSON} from "../analysis/weights"; +import {defaultWeights} from "../analysis/weights"; import {load} from "../api/load"; import {specToProject} from "../plugins/github/specToProject"; -import fs from "fs-extra"; import {partialParams} from "../analysis/timeline/params"; import {DEFAULT_PLUGINS} from "./defaultPlugins"; @@ -93,7 +92,7 @@ const loadCommand: Command = async (args, std) => { let weights = defaultWeights(); if (weightsPath) { - weights = await loadWeightOverrides(weightsPath); + weights = await Common.loadWeights(weightsPath); } const githubToken = Common.githubToken(); @@ -124,20 +123,6 @@ const loadCommand: Command = async (args, std) => { return 0; }; -const loadWeightOverrides = async (path: string) => { - if (!(await fs.exists(path))) { - throw new Error("Could not find the weights file"); - } - - const raw = await fs.readFile(path, "utf-8"); - const weightsJSON = JSON.parse(raw); - try { - return weightsFromJSON(weightsJSON); - } catch (e) { - throw new Error(`provided weights file is invalid:\n${e}`); - } -}; - export const help: Command = async (args, std) => { if (args.length === 0) { usage(std.out);