count the number of streams on a connection for the stats
This commit is contained in:
parent
212b671494
commit
f417a8d5ce
|
@ -212,7 +212,7 @@ func (s *Swarm) addConn(tc transport.CapableConn, dir network.Direction) (*Conn,
|
|||
)
|
||||
|
||||
// create the Stat object, initializing with the underlying connection Stat if available
|
||||
var stat network.Stat
|
||||
var stat network.ConnStats
|
||||
if cs, ok := tc.(network.ConnStat); ok {
|
||||
stat = cs.Stat()
|
||||
}
|
||||
|
|
|
@ -39,7 +39,7 @@ type Conn struct {
|
|||
m map[*Stream]struct{}
|
||||
}
|
||||
|
||||
stat network.Stat
|
||||
stat network.ConnStats
|
||||
}
|
||||
|
||||
func (c *Conn) ID() string {
|
||||
|
@ -90,6 +90,7 @@ func (c *Conn) doClose() {
|
|||
|
||||
func (c *Conn) removeStream(s *Stream) {
|
||||
c.streams.Lock()
|
||||
c.stat.NumStreams--
|
||||
delete(c.streams.m, s)
|
||||
c.streams.Unlock()
|
||||
}
|
||||
|
@ -171,7 +172,9 @@ func (c *Conn) RemotePublicKey() ic.PubKey {
|
|||
}
|
||||
|
||||
// Stat returns metadata pertaining to this connection
|
||||
func (c *Conn) Stat() network.Stat {
|
||||
func (c *Conn) Stat() network.ConnStats {
|
||||
c.streams.Lock()
|
||||
defer c.streams.Unlock()
|
||||
return c.stat
|
||||
}
|
||||
|
||||
|
@ -201,16 +204,16 @@ func (c *Conn) addStream(ts mux.MuxedStream, dir network.Direction) (*Stream, er
|
|||
}
|
||||
|
||||
// Wrap and register the stream.
|
||||
stat := network.Stat{
|
||||
Direction: dir,
|
||||
Opened: time.Now(),
|
||||
}
|
||||
s := &Stream{
|
||||
stream: ts,
|
||||
conn: c,
|
||||
stat: stat,
|
||||
id: atomic.AddUint64(&c.swarm.nextStreamID, 1),
|
||||
stat: network.Stats{
|
||||
Direction: dir,
|
||||
Opened: time.Now(),
|
||||
},
|
||||
id: atomic.AddUint64(&c.swarm.nextStreamID, 1),
|
||||
}
|
||||
c.stat.NumStreams++
|
||||
c.streams.m[s] = struct{}{}
|
||||
|
||||
// Released once the stream disconnect notifications have finished
|
||||
|
|
|
@ -130,9 +130,9 @@ func TestNetworkOpenStream(t *testing.T) {
|
|||
t.Fatal(err)
|
||||
}
|
||||
|
||||
numStreams := 0
|
||||
var numStreams int
|
||||
for _, conn := range nets[0].ConnsToPeer(nets[1].LocalPeer()) {
|
||||
numStreams += len(conn.GetStreams())
|
||||
numStreams += conn.Stat().NumStreams
|
||||
}
|
||||
|
||||
if numStreams != 1 {
|
||||
|
|
|
@ -28,7 +28,7 @@ type Stream struct {
|
|||
|
||||
protocol atomic.Value
|
||||
|
||||
stat network.Stat
|
||||
stat network.Stats
|
||||
}
|
||||
|
||||
func (s *Stream) ID() string {
|
||||
|
@ -151,6 +151,6 @@ func (s *Stream) SetWriteDeadline(t time.Time) error {
|
|||
}
|
||||
|
||||
// Stat returns metadata information for this stream.
|
||||
func (s *Stream) Stat() network.Stat {
|
||||
func (s *Stream) Stat() network.Stats {
|
||||
return s.stat
|
||||
}
|
||||
|
|
|
@ -424,3 +424,44 @@ func TestPreventDialListenAddr(t *testing.T) {
|
|||
t.Fatal("expected dial to fail: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStreamCount(t *testing.T) {
|
||||
s1 := GenSwarm(t)
|
||||
s2 := GenSwarm(t)
|
||||
connectSwarms(t, context.Background(), []*swarm.Swarm{s2, s1})
|
||||
|
||||
countStreams := func() (n int) {
|
||||
var num int
|
||||
for _, c := range s1.ConnsToPeer(s2.LocalPeer()) {
|
||||
n += c.Stat().NumStreams
|
||||
num += len(c.GetStreams())
|
||||
}
|
||||
require.Equal(t, n, num, "inconsistent stream count")
|
||||
return
|
||||
}
|
||||
|
||||
streams := make(chan network.Stream, 20)
|
||||
streamAccepted := make(chan struct{}, 1)
|
||||
s1.SetStreamHandler(func(str network.Stream) {
|
||||
streams <- str
|
||||
streamAccepted <- struct{}{}
|
||||
})
|
||||
|
||||
for i := 0; i < 10; i++ {
|
||||
str, err := s2.NewStream(context.Background(), s1.LocalPeer())
|
||||
require.NoError(t, err)
|
||||
str.Write([]byte("foobar"))
|
||||
<-streamAccepted
|
||||
}
|
||||
require.Eventually(t, func() bool { return len(streams) == 10 }, 5*time.Second, 10*time.Millisecond)
|
||||
require.Equal(t, countStreams(), 10)
|
||||
(<-streams).Reset()
|
||||
(<-streams).Close()
|
||||
require.Equal(t, countStreams(), 8)
|
||||
|
||||
str, err := s1.NewStream(context.Background(), s2.LocalPeer())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, countStreams(), 9)
|
||||
str.Close()
|
||||
require.Equal(t, countStreams(), 8)
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue