diff --git a/crc/services/workflow_service.py b/crc/services/workflow_service.py index a76553ce..ee37c688 100644 --- a/crc/services/workflow_service.py +++ b/crc/services/workflow_service.py @@ -22,7 +22,7 @@ from jinja2 import Template from crc import db, app from crc.api.common import ApiError from crc.models.api_models import Task, MultiInstanceType, WorkflowApi -from crc.models.file import LookupDataModel +from crc.models.file import LookupDataModel, FileModel from crc.models.study import StudyModel from crc.models.task_event import TaskEventModel from crc.models.user import UserModel, UserModelSchema @@ -808,3 +808,12 @@ class WorkflowService(object): def get_standalone_workflow_specs(): specs = db.session.query(WorkflowSpecModel).filter_by(standalone=True).all() return specs + + @staticmethod + def get_primary_workflow(workflow_spec_id): + # Returns the FileModel of the primary workflow for a workflow_spec + primary = None + file = db.session.query(FileModel).filter(FileModel.workflow_spec_id==workflow_spec_id, FileModel.primary==True).first() + if file: + primary = file + return primary diff --git a/tests/workflow/test_workflow_service.py b/tests/workflow/test_workflow_service.py index a4af4edc..cb4da9c2 100644 --- a/tests/workflow/test_workflow_service.py +++ b/tests/workflow/test_workflow_service.py @@ -10,6 +10,7 @@ from example_data import ExampleDataLoader from crc import db from crc.models.task_event import TaskEventModel from crc.models.api_models import Task +from crc.models.file import FileModel from crc.api.common import ApiError @@ -114,3 +115,12 @@ class TestWorkflowService(BaseTest): result2 = WorkflowService.get_dot_value(path, {"a.b.c":"garbage"}) self.assertEqual("garbage", result2) + + def test_get_primary_workflow(self): + + workflow = self.create_workflow('hello_world') + workflow_spec_id = workflow.workflow_spec.id + primary_workflow = WorkflowService.get_primary_workflow(workflow_spec_id) + self.assertIsInstance(primary_workflow, FileModel) + self.assertEqual(workflow_spec_id, primary_workflow.workflow_spec_id) + self.assertEqual('hello_world.bpmn', primary_workflow.name)