diff --git a/crc/api/workflow.py b/crc/api/workflow.py index c8fd95ab..b2e0bd4a 100644 --- a/crc/api/workflow.py +++ b/crc/api/workflow.py @@ -1,7 +1,7 @@ import uuid -from crc.api.common import ApiError, ApiErrorSchema from crc import session +from crc.api.common import ApiError, ApiErrorSchema from crc.models.workflow import WorkflowModel, WorkflowModelSchema, WorkflowSpecModelSchema, WorkflowSpecModel, \ Task, TaskSchema from crc.workflow_processor import WorkflowProcessor @@ -20,26 +20,20 @@ def add_workflow_specification(body): def update_workflow_specification(spec_id, body): - if spec_id is None: error = ApiError('unknown_study', 'Please provide a valid Workflow Specification ID.') return ApiErrorSchema.dump(error), 404 - db_spec = session.query(WorkflowSpecModel).filter_by(id=spec_id).first() - """:type: WorkflowSpecModel""" + spec: WorkflowSpecModel = session.query(WorkflowSpecModel).filter_by(id=spec_id).first() - if db_spec is None: + if spec is None: error = ApiError('unknown_study', 'The Workflow Specification "' + spec_id + '" is not recognized.') return ApiErrorSchema.dump(error), 404 - new_spec = WorkflowSpecModelSchema().load(body, session=session) - """:type: WorkflowSpecModel""" - - db_spec.id = new_spec.id - db_spec.display_name = new_spec.display_name - db_spec.description = new_spec.description + spec = WorkflowSpecModelSchema().load(body, session=session) + session.add(spec) session.commit() - return WorkflowSpecModelSchema().dump(db_spec) + return WorkflowSpecModelSchema().dump(spec) def get_workflow(workflow_id): diff --git a/tests/test_api.py b/tests/test_api.py index 2acd5a70..77390ec2 100644 --- a/tests/test_api.py +++ b/tests/test_api.py @@ -103,7 +103,7 @@ class TestStudy(BaseTest, unittest.TestCase): def test_modify_workflow_specification(self): self.load_example_data() old_id = 'random_fact' - spec = session.query(WorkflowSpecModel).filter_by(id=old_id).first() + spec: WorkflowSpecModel = session.query(WorkflowSpecModel).filter_by(id=old_id).first() """:type: WorkflowSpecModel""" spec.id = 'odd_datum' @@ -114,6 +114,10 @@ class TestStudy(BaseTest, unittest.TestCase): self.assert_success(rv) db_spec = session.query(WorkflowSpecModel).filter_by(id=spec.id).first() self.assertEqual(spec.display_name, db_spec.display_name) + + num_old_after = session.query(WorkflowSpecModel).filter_by(id=old_id).count() + self.assertEqual(num_old_after, 0) + num_after = session.query(WorkflowSpecModel).count() self.assertEqual(num_after, num_before)