count the number of streams on a connection for the stats

This commit is contained in:
Marten Seemann 2021-12-10 15:28:54 +04:00
parent 212b671494
commit f417a8d5ce
5 changed files with 57 additions and 13 deletions

View File

@ -212,7 +212,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
)
// create the Stat object, initializing with the underlying connection Stat if available
var stat network.Stat
var stat network.ConnStats
if cs, ok := tc.(network.ConnStat); ok {
stat = cs.Stat()
}

View File

@ -39,7 +39,7 @@ type Conn struct {
m map[*Stream]struct{}
}
stat network.Stat
stat network.ConnStats
}
func (c *Conn) ID() string {
@ -90,6 +90,7 @@ func (c *Conn) doClose() {
func (c *Conn) removeStream(s *Stream) {
c.streams.Lock()
c.stat.NumStreams--
delete(c.streams.m, s)
c.streams.Unlock()
}
@ -171,7 +172,9 @@ func (c *Conn) RemotePublicKey() ic.PubKey {
}
// Stat returns metadata pertaining to this connection
func (c *Conn) Stat() network.Stat {
func (c *Conn) Stat() network.ConnStats {
c.streams.Lock()
defer c.streams.Unlock()
return c.stat
}
@ -201,16 +204,16 @@ func (c *Conn) addStream(ts mux.MuxedStream, dir network.Direction) (*Stream, er
}
// Wrap and register the stream.
stat := network.Stat{
Direction: dir,
Opened: time.Now(),
}
s := &Stream{
stream: ts,
conn: c,
stat: stat,
id: atomic.AddUint64(&c.swarm.nextStreamID, 1),
stat: network.Stats{
Direction: dir,
Opened: time.Now(),
},
id: atomic.AddUint64(&c.swarm.nextStreamID, 1),
}
c.stat.NumStreams++
c.streams.m[s] = struct{}{}
// Released once the stream disconnect notifications have finished

View File

@ -130,9 +130,9 @@ func TestNetworkOpenStream(t *testing.T) {
t.Fatal(err)
}
numStreams := 0
var numStreams int
for _, conn := range nets[0].ConnsToPeer(nets[1].LocalPeer()) {
numStreams += len(conn.GetStreams())
numStreams += conn.Stat().NumStreams
}
if numStreams != 1 {

View File

@ -28,7 +28,7 @@ type Stream struct {
protocol atomic.Value
stat network.Stat
stat network.Stats
}
func (s *Stream) ID() string {
@ -151,6 +151,6 @@ func (s *Stream) SetWriteDeadline(t time.Time) error {
}
// Stat returns metadata information for this stream.
func (s *Stream) Stat() network.Stat {
func (s *Stream) Stat() network.Stats {
return s.stat
}

View File

@ -424,3 +424,44 @@ func TestPreventDialListenAddr(t *testing.T) {
t.Fatal("expected dial to fail: %w", err)
}
}
func TestStreamCount(t *testing.T) {
s1 := GenSwarm(t)
s2 := GenSwarm(t)
connectSwarms(t, context.Background(), []*swarm.Swarm{s2, s1})
countStreams := func() (n int) {
var num int
for _, c := range s1.ConnsToPeer(s2.LocalPeer()) {
n += c.Stat().NumStreams
num += len(c.GetStreams())
}
require.Equal(t, n, num, "inconsistent stream count")
return
}
streams := make(chan network.Stream, 20)
streamAccepted := make(chan struct{}, 1)
s1.SetStreamHandler(func(str network.Stream) {
streams <- str
streamAccepted <- struct{}{}
})
for i := 0; i < 10; i++ {
str, err := s2.NewStream(context.Background(), s1.LocalPeer())
require.NoError(t, err)
str.Write([]byte("foobar"))
<-streamAccepted
}
require.Eventually(t, func() bool { return len(streams) == 10 }, 5*time.Second, 10*time.Millisecond)
require.Equal(t, countStreams(), 10)
(<-streams).Reset()
(<-streams).Close()
require.Equal(t, countStreams(), 8)
str, err := s1.NewStream(context.Background(), s2.LocalPeer())
require.NoError(t, err)
require.Equal(t, countStreams(), 9)
str.Close()
require.Equal(t, countStreams(), 8)
}