diff --git a/agent/grpc-external/services/peerstream/stream_resources.go b/agent/grpc-external/services/peerstream/stream_resources.go index c67f7da041..5c69d08a72 100644 --- a/agent/grpc-external/services/peerstream/stream_resources.go +++ b/agent/grpc-external/services/peerstream/stream_resources.go @@ -5,6 +5,7 @@ import ( "fmt" "io" "strings" + "sync" "github.com/golang/protobuf/jsonpb" "github.com/golang/protobuf/proto" @@ -204,6 +205,25 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { ) subCh := mgr.subscribe(streamReq.Stream.Context(), streamReq.LocalID, streamReq.PeerName, streamReq.Partition) + // We need a mutex to protect against simultaneous sends to the client. + var sendMutex sync.Mutex + + // streamSend is a helper function that sends msg over the stream + // respecting the send mutex. It also logs the send and calls status.TrackSendError + // on error. + streamSend := func(msg *pbpeerstream.ReplicationMessage) error { + logTraceSend(logger, msg) + + sendMutex.Lock() + err := streamReq.Stream.Send(msg) + sendMutex.Unlock() + + if err != nil { + status.TrackSendError(err.Error()) + } + return err + } + // Subscribe to all relevant resource types. for _, resourceURL := range []string{ pbpeerstream.TypeURLExportedService, @@ -213,16 +233,12 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { ResourceURL: resourceURL, PeerID: streamReq.RemoteID, }) - logTraceSend(logger, sub) - - if err := streamReq.Stream.Send(sub); err != nil { + if err := streamSend(sub); err != nil { if err == io.EOF { logger.Info("stream ended by peer") - status.TrackReceiveError(err.Error()) return nil } // TODO(peering) Test error handling in calls to Send/Recv - status.TrackSendError(err.Error()) return fmt.Errorf("failed to send subscription for %q to stream: %w", resourceURL, err) } } @@ -261,10 +277,7 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { Terminated: &pbpeerstream.ReplicationMessage_Terminated{}, }, } - logTraceSend(logger, term) - - if err := streamReq.Stream.Send(term); err != nil { - status.TrackSendError(err.Error()) + if err := streamSend(term); err != nil { return fmt.Errorf("failed to send to stream: %v", err) } @@ -401,9 +414,7 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { status.TrackReceiveSuccess() } - logTraceSend(logger, reply) - if err := streamReq.Stream.Send(reply); err != nil { - status.TrackSendError(err.Error()) + if err := streamSend(reply); err != nil { return fmt.Errorf("failed to send to stream: %v", err) } @@ -451,10 +462,7 @@ func (s *Server) HandleStream(streamReq HandleStreamRequest) error { } replResp := makeReplicationResponse(resp) - - logTraceSend(logger, replResp) - if err := streamReq.Stream.Send(replResp); err != nil { - status.TrackSendError(err.Error()) + if err := streamSend(replResp); err != nil { return fmt.Errorf("failed to push data for %q: %w", update.CorrelationID, err) } }