diff --git a/da/network/dispersal/dispersal.proto b/da/network/dispersal/dispersal.proto index 99fe9eb..5b82eaa 100644 --- a/da/network/dispersal/dispersal.proto +++ b/da/network/dispersal/dispersal.proto @@ -82,5 +82,6 @@ message DispersalMessage { DispersalRes dispersal_res = 2; SampleReq sample_req = 3; SampleRes sample_res = 4; + SessionReq session_req = 5; } } diff --git a/da/network/dispersal/dispersal_pb2.py b/da/network/dispersal/dispersal_pb2.py index 5d7ae35..9984ee2 100644 --- a/da/network/dispersal/dispersal_pb2.py +++ b/da/network/dispersal/dispersal_pb2.py @@ -24,31 +24,37 @@ _sym_db = _symbol_database.Default() -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ispersal.proto\x12\x12nomos.da.dispersal\"%\n\x04\x42lob\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\xb3\x01\n\x0c\x44ispersalErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x43\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32\x31.nomos.da.dispersal.DispersalErr.DispersalErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"4\n\x10\x44ispersalErrType\x12\x0e\n\nCHUNK_SIZE\x10\x00\x12\x10\n\x0cVERIFICATION\x10\x01\"6\n\x0c\x44ispersalReq\x12&\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x18.nomos.da.dispersal.Blob\"b\n\x0c\x44ispersalRes\x12\x11\n\x07\x62lob_id\x18\x01 \x01(\x0cH\x00\x12/\n\x03\x65rr\x18\x02 \x01(\x0b\x32 .nomos.da.dispersal.DispersalErrH\x00\x42\x0e\n\x0cmessage_type\"\x94\x01\n\tSampleErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12=\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32+.nomos.da.dispersal.SampleErr.SampleErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"\x1e\n\rSampleErrType\x12\r\n\tNOT_FOUND\x10\x00\"\x1c\n\tSampleReq\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\"s\n\tSampleRes\x12(\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x18.nomos.da.dispersal.BlobH\x00\x12,\n\x03\x65rr\x18\x02 \x01(\x0b\x32\x1d.nomos.da.dispersal.SampleErrH\x00\x42\x0e\n\x0cmessage_type\"\x82\x02\n\x10\x44ispersalMessage\x12\x39\n\rdispersal_req\x18\x01 \x01(\x0b\x32 .nomos.da.dispersal.DispersalReqH\x00\x12\x39\n\rdispersal_res\x18\x02 \x01(\x0b\x32 .nomos.da.dispersal.DispersalResH\x00\x12\x33\n\nsample_req\x18\x03 \x01(\x0b\x32\x1d.nomos.da.dispersal.SampleReqH\x00\x12\x33\n\nsample_res\x18\x04 \x01(\x0b\x32\x1d.nomos.da.dispersal.SampleResH\x00\x42\x0e\n\x0cmessage_typeb\x06proto3') +DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0f\x64ispersal.proto\x12\x15nomos.da.dispersal.v1\"%\n\x04\x42lob\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x0c\n\x04\x64\x61ta\x18\x02 \x01(\x0c\"\xb6\x01\n\x0c\x44ispersalErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12\x46\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32\x34.nomos.da.dispersal.v1.DispersalErr.DispersalErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"4\n\x10\x44ispersalErrType\x12\x0e\n\nCHUNK_SIZE\x10\x00\x12\x10\n\x0cVERIFICATION\x10\x01\"9\n\x0c\x44ispersalReq\x12)\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x1b.nomos.da.dispersal.v1.Blob\"e\n\x0c\x44ispersalRes\x12\x11\n\x07\x62lob_id\x18\x01 \x01(\x0cH\x00\x12\x32\n\x03\x65rr\x18\x02 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalErrH\x00\x42\x0e\n\x0cmessage_type\"\x97\x01\n\tSampleErr\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\x12@\n\x08\x65rr_type\x18\x02 \x01(\x0e\x32..nomos.da.dispersal.v1.SampleErr.SampleErrType\x12\x17\n\x0f\x65rr_description\x18\x03 \x01(\t\"\x1e\n\rSampleErrType\x12\r\n\tNOT_FOUND\x10\x00\"\x1c\n\tSampleReq\x12\x0f\n\x07\x62lob_id\x18\x01 \x01(\x0c\"y\n\tSampleRes\x12+\n\x04\x62lob\x18\x01 \x01(\x0b\x32\x1b.nomos.da.dispersal.v1.BlobH\x00\x12/\n\x03\x65rr\x18\x02 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleErrH\x00\x42\x0e\n\x0cmessage_type\"\x98\x01\n\x08\x43loseMsg\x12;\n\x06reason\x18\x01 \x01(\x0e\x32+.nomos.da.dispersal.v1.CloseMsg.CloseReason\"O\n\x0b\x43loseReason\x12\x15\n\x11GRACEFUL_SHUTDOWN\x10\x00\x12\x11\n\rSUBNET_CHANGE\x10\x01\x12\x16\n\x12SUBNET_SAMPLE_FAIL\x10\x02\"R\n\nSessionReq\x12\x34\n\tclose_msg\x18\x01 \x01(\x0b\x32\x1f.nomos.da.dispersal.v1.CloseMsgH\x00\x42\x0e\n\x0cmessage_type\"\xc8\x02\n\x10\x44ispersalMessage\x12<\n\rdispersal_req\x18\x01 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalReqH\x00\x12<\n\rdispersal_res\x18\x02 \x01(\x0b\x32#.nomos.da.dispersal.v1.DispersalResH\x00\x12\x36\n\nsample_req\x18\x03 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleReqH\x00\x12\x36\n\nsample_res\x18\x04 \x01(\x0b\x32 .nomos.da.dispersal.v1.SampleResH\x00\x12\x38\n\x0bsession_req\x18\x05 \x01(\x0b\x32!.nomos.da.dispersal.v1.SessionReqH\x00\x42\x0e\n\x0cmessage_typeb\x06proto3') _globals = globals() _builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals) _builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'dispersal_pb2', _globals) if not _descriptor._USE_C_DESCRIPTORS: DESCRIPTOR._loaded_options = None - _globals['_BLOB']._serialized_start=39 - _globals['_BLOB']._serialized_end=76 - _globals['_DISPERSALERR']._serialized_start=79 - _globals['_DISPERSALERR']._serialized_end=258 - _globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_start=206 - _globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_end=258 - _globals['_DISPERSALREQ']._serialized_start=260 - _globals['_DISPERSALREQ']._serialized_end=314 - _globals['_DISPERSALRES']._serialized_start=316 - _globals['_DISPERSALRES']._serialized_end=414 - _globals['_SAMPLEERR']._serialized_start=417 - _globals['_SAMPLEERR']._serialized_end=565 - _globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_start=535 - _globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_end=565 - _globals['_SAMPLEREQ']._serialized_start=567 - _globals['_SAMPLEREQ']._serialized_end=595 - _globals['_SAMPLERES']._serialized_start=597 - _globals['_SAMPLERES']._serialized_end=712 - _globals['_DISPERSALMESSAGE']._serialized_start=715 - _globals['_DISPERSALMESSAGE']._serialized_end=973 + _globals['_BLOB']._serialized_start=42 + _globals['_BLOB']._serialized_end=79 + _globals['_DISPERSALERR']._serialized_start=82 + _globals['_DISPERSALERR']._serialized_end=264 + _globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_start=212 + _globals['_DISPERSALERR_DISPERSALERRTYPE']._serialized_end=264 + _globals['_DISPERSALREQ']._serialized_start=266 + _globals['_DISPERSALREQ']._serialized_end=323 + _globals['_DISPERSALRES']._serialized_start=325 + _globals['_DISPERSALRES']._serialized_end=426 + _globals['_SAMPLEERR']._serialized_start=429 + _globals['_SAMPLEERR']._serialized_end=580 + _globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_start=550 + _globals['_SAMPLEERR_SAMPLEERRTYPE']._serialized_end=580 + _globals['_SAMPLEREQ']._serialized_start=582 + _globals['_SAMPLEREQ']._serialized_end=610 + _globals['_SAMPLERES']._serialized_start=612 + _globals['_SAMPLERES']._serialized_end=733 + _globals['_CLOSEMSG']._serialized_start=736 + _globals['_CLOSEMSG']._serialized_end=888 + _globals['_CLOSEMSG_CLOSEREASON']._serialized_start=809 + _globals['_CLOSEMSG_CLOSEREASON']._serialized_end=888 + _globals['_SESSIONREQ']._serialized_start=890 + _globals['_SESSIONREQ']._serialized_end=972 + _globals['_DISPERSALMESSAGE']._serialized_start=975 + _globals['_DISPERSALMESSAGE']._serialized_end=1303 # @@protoc_insertion_point(module_scope) diff --git a/da/network/dispersal/mock_system.py b/da/network/dispersal/mock_system.py index 6d7ebd2..b1821c4 100644 --- a/da/network/dispersal/mock_system.py +++ b/da/network/dispersal/mock_system.py @@ -15,7 +15,7 @@ class MockTransport: async def read_and_process(self): try: while True: - message = await proto.parse_from_reader(self.reader) + message = await proto.unpack_from_reader(self.reader) await self.handler(self.conn_id, self.writer, message) except Exception as e: print(f"MockTransport: An error occurred: {e}") @@ -102,7 +102,7 @@ class MockExecutor: async def run(self): await asyncio.gather(*(self._connect() for _ in range(self.col_num))) - await self._execute() + await self.executor() class MockSystem: diff --git a/da/network/dispersal/proto.py b/da/network/dispersal/proto.py index 022c833..707a885 100644 --- a/da/network/dispersal/proto.py +++ b/da/network/dispersal/proto.py @@ -3,19 +3,24 @@ from itertools import count MAX_MSG_LEN_BYTES = 2 -async def parse_from_reader(reader): - length_prefix = await reader.readexactly(MAX_MSG_LEN_BYTES) - data_length = int.from_bytes(length_prefix, byteorder='big') - data = await reader.readexactly(data_length) - return unpack_message(data) - def pack_message(message): # SerializeToString method returns an instance of bytes. data = message.SerializeToString() length_prefix = len(data).to_bytes(MAX_MSG_LEN_BYTES, byteorder='big') return length_prefix + data -def unpack_message(data): +async def unpack_from_reader(reader): + length_prefix = await reader.readexactly(MAX_MSG_LEN_BYTES) + data_length = int.from_bytes(length_prefix, byteorder='big') + data = await reader.readexactly(data_length) + return parse(data) + +def unpack_from_bytes(data): + length_prefix = data[:MAX_MSG_LEN_BYTES] + data_length = int.from_bytes(length_prefix, byteorder='big') + return parse(data[MAX_MSG_LEN_BYTES:MAX_MSG_LEN_BYTES + data_length]) + +def parse(data): message = dispersal_pb2.DispersalMessage() message.ParseFromString(data) return message @@ -86,13 +91,16 @@ def new_session_req_close_msg(reason): close_msg = new_close_msg(reason) session_req = dispersal_pb2.SessionReq(close_msg=close_msg) dispersal_message = dispersal_pb2.DispersalMessage(session_req=session_req) - return pack_message(dispersal_message) + return dispersal_message def new_session_req_graceful_shutdown_msg(): - new_session_req_close_msg(dispersal_pb2.CloseMsg.GRACEFUL_SHUTDOWN) + message = new_session_req_close_msg(dispersal_pb2.CloseMsg.GRACEFUL_SHUTDOWN) + return pack_message(message) def new_session_req_subnet_change_msg(): - new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_CHANGE) + message = new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_CHANGE) + return pack_message(message) def new_session_req_subnet_sample_fail_msg(): - new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_SAMPLE_FAIL) + message = new_session_req_close_msg(dispersal_pb2.CloseMsg.SUBNET_SAMPLE_FAIL) + return pack_message(message) diff --git a/da/network/dispersal/test_proto_helpers.py b/da/network/dispersal/test_proto_helpers.py new file mode 100644 index 0000000..cde27ca --- /dev/null +++ b/da/network/dispersal/test_proto_helpers.py @@ -0,0 +1,75 @@ +import dispersal_pb2 +import proto +from unittest import TestCase + +class TestMessageSerialization(TestCase): + + def test_dispersal_req_msg(self): + blob_id = b"dummy_blob_id" + data = b"dummy_data" + packed_message = proto.new_dispersal_req_msg(blob_id, data) + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('dispersal_req')) + self.assertEqual(message.dispersal_req.blob.blob_id, blob_id) + self.assertEqual(message.dispersal_req.blob.data, data) + + def test_dispersal_res_success_msg(self): + blob_id = b"dummy_blob_id" + packed_message = proto.new_dispersal_res_success_msg(blob_id) + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('dispersal_res')) + self.assertEqual(message.dispersal_res.blob_id, blob_id) + + def test_dispersal_res_chunk_size_error_msg(self): + blob_id = b"dummy_blob_id" + description = "Chunk size error" + packed_message = proto.new_dispersal_res_chunk_size_error_msg(blob_id, description) + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('dispersal_res')) + self.assertEqual(message.dispersal_res.err.blob_id, blob_id) + self.assertEqual(message.dispersal_res.err.err_type, dispersal_pb2.DispersalErr.CHUNK_SIZE) + self.assertEqual(message.dispersal_res.err.err_description, description) + + def test_dispersal_res_verification_error_msg(self): + blob_id = b"dummy_blob_id" + description = "Verification error" + packed_message = proto.new_dispersal_res_verification_error_msg(blob_id, description) + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('dispersal_res')) + self.assertEqual(message.dispersal_res.err.blob_id, blob_id) + self.assertEqual(message.dispersal_res.err.err_type, dispersal_pb2.DispersalErr.VERIFICATION) + self.assertEqual(message.dispersal_res.err.err_description, description) + + def test_sample_req_msg(self): + blob_id = b"dummy_blob_id" + packed_message = proto.new_sample_req_msg(blob_id) + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('sample_req')) + self.assertEqual(message.sample_req.blob_id, blob_id) + + def test_sample_res_success_msg(self): + blob_id = b"dummy_blob_id" + data = b"dummy_data" + packed_message = proto.new_sample_res_success_msg(blob_id, data) + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('sample_res')) + self.assertEqual(message.sample_res.blob.blob_id, blob_id) + self.assertEqual(message.sample_res.blob.data, data) + + def test_sample_res_not_found_error_msg(self): + blob_id = b"dummy_blob_id" + description = "Blob not found" + packed_message = proto.new_sample_res_not_found_error_msg(blob_id, description) + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('sample_res')) + self.assertEqual(message.sample_res.err.blob_id, blob_id) + self.assertEqual(message.sample_res.err.err_type, dispersal_pb2.SampleErr.NOT_FOUND) + self.assertEqual(message.sample_res.err.err_description, description) + + def test_session_req_close_msg(self): + reason = dispersal_pb2.CloseMsg.GRACEFUL_SHUTDOWN + packed_message = proto.new_session_req_graceful_shutdown_msg() + message = proto.unpack_from_bytes(packed_message) + self.assertTrue(message.HasField('session_req')) + self.assertTrue(message.session_req.HasField('close_msg')) + self.assertEqual(message.session_req.close_msg.reason, reason)