diff --git a/gossipsub.go b/gossipsub.go
index d125c1c..2e90b61 100644
--- a/gossipsub.go
+++ b/gossipsub.go
@@ -939,10 +939,14 @@ func fragmentRPC(rpc *RPC, limit int) ([]*RPC, error) {
 
 	// outRPC returns the current RPC message if it will fit sizeToAdd more bytes
 	// otherwise, it will create a new RPC message and add it to the list.
-	// if withCtl is true, the new RPC message will have a non-nil empty Control message.
+	// if withCtl is true, the returned message will have a non-nil empty Control message.
 	outRPC := func(sizeToAdd int, withCtl bool) *RPC {
 		current := rpcs[len(rpcs)-1]
-		if current.Size()+sizeToAdd < limit {
+		// check if we can fit the new data, plus an extra byte for the protobuf field tag
+		if current.Size()+sizeToAdd+1 < limit {
+			if withCtl && current.Control == nil {
+				current.Control = &pb.ControlMessage{}
+			}
 			return current
 		}
 		var ctl *pb.ControlMessage
@@ -982,9 +986,6 @@ func fragmentRPC(rpc *RPC, limit int) ([]*RPC, error) {
 	}
 
 	// we need to split up the control messages into multiple RPCs
-	// add a blank rpc message to the end of the list, then use outRPC to get or create
-	// RPC messages to fit each control message
-	rpcs = append(rpcs, &RPC{RPC: pb.RPC{Control: &pb.ControlMessage{}}, from: rpc.from})
 	for _, graft := range ctl.Graft {
 		out := outRPC(graft.Size(), true)
 		out.Control.Graft = append(out.Control.Graft, graft)
@@ -993,17 +994,64 @@ func fragmentRPC(rpc *RPC, limit int) ([]*RPC, error) {
 		out := outRPC(prune.Size(), true)
 		out.Control.Prune = append(out.Control.Prune, prune)
 	}
+
+	// An individual IWANT or IHAVE message could be larger than the limit if we have
+	// a lot of message IDs. fragmentMessageIds will split them into buckets that
+	// fit within the limit, with some overhead for the control messages themselves
 	for _, iwant := range ctl.Iwant {
-		out := outRPC(iwant.Size(), true)
-		out.Control.Iwant = append(out.Control.Iwant, iwant)
+		const protobufOverhead = 6
+		idBuckets := fragmentMessageIds(iwant.MessageIDs, limit-protobufOverhead)
+		for _, ids := range idBuckets {
+			iwant := &pb.ControlIWant{MessageIDs: ids}
+			out := outRPC(iwant.Size(), true)
+			out.Control.Iwant = append(out.Control.Iwant, iwant)
+		}
 	}
 	for _, ihave := range ctl.Ihave {
-		out := outRPC(ihave.Size(), true)
-		out.Control.Ihave = append(out.Control.Ihave, ihave)
+		const protobufOverhead = 6
+		// every fragment must repeat the topic ID of the IHAVE it was split from,
+		// so reserve room for that field before bucketing the message IDs
+		topicOverhead := (&pb.ControlIHave{TopicID: ihave.TopicID}).Size()
+		idBuckets := fragmentMessageIds(ihave.MessageIDs, limit-protobufOverhead-topicOverhead)
+		for _, ids := range idBuckets {
+			ihave := &pb.ControlIHave{TopicID: ihave.TopicID, MessageIDs: ids}
+			out := outRPC(ihave.Size(), true)
+			out.Control.Ihave = append(out.Control.Ihave, ihave)
+		}
 	}
 	return rpcs, nil
 }
 
+// fragmentMessageIds splits msgIds, in order, into buckets whose estimated
+// encoded size (the ID length plus two bytes of protobuf overhead per ID)
+// does not exceed limit. An ID that is larger than limit on its own is
+// logged and dropped. At least one bucket is always returned; it is empty
+// when msgIds is empty or every ID was dropped.
+func fragmentMessageIds(msgIds []string, limit int) [][]string {
+	// account for two bytes of protobuf overhead per array element
+	const protobufOverhead = 2
+
+	out := [][]string{{}}
+	var currentBucket int
+	var bucketLen int
+	for i := 0; i < len(msgIds); i++ {
+		size := len(msgIds[i]) + protobufOverhead
+		if size > limit {
+			// pathological case where a single message ID exceeds the limit.
+			log.Warnf("message ID length %d exceeds limit %d, removing from outgoing gossip", size, limit)
+			continue
+		}
+		bucketLen += size
+		if bucketLen > limit {
+			out = append(out, []string{})
+			currentBucket++
+			bucketLen = size
+		}
+		out[currentBucket] = append(out[currentBucket], msgIds[i])
+	}
+	return out
+}
+
 func (gs *GossipSubRouter) heartbeatTimer() {
 	time.Sleep(GossipSubHeartbeatInitialDelay)
 	select {
diff --git a/gossipsub_test.go b/gossipsub_test.go
index ac9c80a..946ace0 100644
--- a/gossipsub_test.go
+++ b/gossipsub_test.go
@@ -1853,3 +1853,183 @@ func (iwe *iwantEverything) handleStream(s network.Stream) {
 		iwe.lk.Unlock()
 	}
 }
+
+// TestFragmentRPCFunction exercises fragmentRPC directly: RPCs that already fit,
+// oversized publish messages, splitting of publish and control messages, and
+// gossip message IDs that are themselves larger than the limit.
+func TestFragmentRPCFunction(t *testing.T) {
+	p := peer.ID("some-peer")
+	topic := "test"
+	rpc := &RPC{from: p}
+	limit := 1024
+
+	mkMsg := func(size int) *pb.Message {
+		msg := &pb.Message{}
+		msg.Data = make([]byte, size-4) // subtract the protobuf overhead, so msg.Size() returns requested size
+		rand.Read(msg.Data)
+		return msg
+	}
+
+	ensureBelowLimit := func(rpcs []*RPC) {
+		for _, r := range rpcs {
+			if r.Size() > limit {
+				t.Fatalf("expected fragmented RPC to be below %d bytes, was %d", limit, r.Size())
+			}
+		}
+	}
+
+	// it should not fragment if everything fits in one RPC
+	rpc.Publish = []*pb.Message{}
+	rpc.Publish = []*pb.Message{mkMsg(10), mkMsg(10)}
+	results, err := fragmentRPC(rpc, limit)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(results) != 1 {
+		t.Fatalf("expected single RPC if input is < limit, got %d", len(results))
+	}
+
+	// if there's a message larger than the limit, we should fail
+	rpc.Publish = []*pb.Message{mkMsg(10), mkMsg(limit * 2)}
+	results, err = fragmentRPC(rpc, limit)
+	if err == nil {
+		t.Fatalf("expected an error if a message exceeds limit, got %d RPCs instead", len(results))
+	}
+
+	// if the individual messages are below the limit, but the RPC as a whole is larger, we should fragment
+	nMessages := 100
+	msgSize := 200
+	truth := true
+	rpc.Subscriptions = []*pb.RPC_SubOpts{
+		{
+			Subscribe: &truth,
+			Topicid:   &topic,
+		},
+	}
+	rpc.Publish = make([]*pb.Message, nMessages)
+	for i := 0; i < nMessages; i++ {
+		rpc.Publish[i] = mkMsg(msgSize)
+	}
+	results, err = fragmentRPC(rpc, limit)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ensureBelowLimit(results)
+	msgsPerRPC := limit / msgSize
+	expectedRPCs := nMessages / msgsPerRPC
+	if len(results) != expectedRPCs {
+		t.Fatalf("expected %d RPC messages in output, got %d", expectedRPCs, len(results))
+	}
+	var nMessagesFragmented int
+	var nSubscriptions int
+	for _, r := range results {
+		nMessagesFragmented += len(r.Publish)
+		nSubscriptions += len(r.Subscriptions)
+	}
+	if nMessagesFragmented != nMessages {
+		t.Fatalf("expected fragemented RPCs to contain same number of messages as input, got %d / %d", nMessagesFragmented, nMessages)
+	}
+	if nSubscriptions != 1 {
+		t.Fatal("expected subscription to be present in one of the fragmented messages, but not found")
+	}
+
+	// if we're fragmenting, and the input RPC has control messages,
+	// the control messages should be in a separate RPC at the end
+	// reuse RPC from prev test, but add a control message
+	rpc.Control = &pb.ControlMessage{
+		Graft: []*pb.ControlGraft{{TopicID: &topic}},
+		Prune: []*pb.ControlPrune{{TopicID: &topic}},
+		Ihave: []*pb.ControlIHave{{MessageIDs: []string{"foo"}}},
+		Iwant: []*pb.ControlIWant{{MessageIDs: []string{"bar"}}},
+	}
+	results, err = fragmentRPC(rpc, limit)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ensureBelowLimit(results)
+	// we expect one more RPC than last time, with the final one containing the control messages
+	expectedCtrl := 1
+	expectedRPCs = (nMessages / msgsPerRPC) + expectedCtrl
+	if len(results) != expectedRPCs {
+		t.Fatalf("expected %d RPC messages in output, got %d", expectedRPCs, len(results))
+	}
+	ctl := results[len(results)-1].Control
+	if ctl == nil {
+		t.Fatal("expected final fragmented RPC to contain control messages, but .Control was nil")
+	}
+	// since it was not altered, the original control message should be identical to the output control message
+	originalBytes, err := rpc.Control.Marshal()
+	if err != nil {
+		t.Fatal(err)
+	}
+	receivedBytes, err := ctl.Marshal()
+	if err != nil {
+		t.Fatal(err)
+	}
+	if !bytes.Equal(originalBytes, receivedBytes) {
+		t.Fatal("expected control message to be unaltered if it fits within one RPC message")
+	}
+
+	// if the control message is too large to fit into a single RPC, it should be split into multiple RPCs
+	nTopics := 5 // pretend we're subscribed to multiple topics and sending IHAVE / IWANTs for each
+	messageIdSize := 32
+	msgsPerTopic := 100 // enough that a single IHAVE or IWANT will exceed the limit
+	rpc.Control.Ihave = make([]*pb.ControlIHave, nTopics)
+	rpc.Control.Iwant = make([]*pb.ControlIWant, nTopics)
+	for i := 0; i < nTopics; i++ {
+		messageIds := make([]string, msgsPerTopic)
+		for m := 0; m < msgsPerTopic; m++ {
+			mid := make([]byte, messageIdSize)
+			rand.Read(mid)
+			messageIds[m] = string(mid)
+		}
+		rpc.Control.Ihave[i] = &pb.ControlIHave{TopicID: &topic, MessageIDs: messageIds}
+		rpc.Control.Iwant[i] = &pb.ControlIWant{MessageIDs: messageIds}
+	}
+	results, err = fragmentRPC(rpc, limit)
+	if err != nil {
+		t.Fatal(err)
+	}
+	ensureBelowLimit(results)
+	// every IHAVE fragment must still name the topic it was split from
+	for _, r := range results {
+		if r.Control == nil {
+			continue
+		}
+		for _, ihave := range r.Control.Ihave {
+			if ihave.TopicID == nil || *ihave.TopicID != topic {
+				t.Fatal("expected every fragmented IHAVE to keep the topic ID of the original")
+			}
+		}
+	}
+	minExpectedCtl := rpc.Control.Size() / limit
+	minExpectedRPCs := (nMessages / msgsPerRPC) + minExpectedCtl
+	if len(results) < minExpectedRPCs {
+		t.Fatalf("expected at least %d total RPCs (at least %d with control messages), got %d total", minExpectedRPCs, minExpectedCtl, len(results))
+	}
+
+	// Test the pathological case where a single gossip message ID exceeds the limit.
+	// It should not be present in the fragmented messages, but smaller IDs should be
+	rpc.Reset()
+	giantIdBytes := make([]byte, limit*2)
+	rand.Read(giantIdBytes)
+	rpc.Control = &pb.ControlMessage{
+		Iwant: []*pb.ControlIWant{
+			{MessageIDs: []string{"hello", string(giantIdBytes)}},
+		},
+	}
+	results, err = fragmentRPC(rpc, limit)
+	if err != nil {
+		t.Fatal(err)
+	}
+	if len(results) != 1 {
+		t.Fatalf("expected 1 RPC, got %d", len(results))
+	}
+	if len(results[0].Control.Iwant) != 1 {
+		t.Fatalf("expected 1 IWANT, got %d", len(results[0].Control.Iwant))
+	}
+	if results[0].Control.Iwant[0].MessageIDs[0] != "hello" {
+		t.Fatalf("expected small message ID to be included unaltered, got %s instead",
+			results[0].Control.Iwant[0].MessageIDs[0])
+	}
+}