various identify fixes and nits (#922)

* various identify fixes and nits

Co-authored-by: Aarsh Shah <aarshkshah1992@gmail.com>
This commit is contained in:
Steven Allen 2020-05-14 04:54:10 -07:00 committed by GitHub
parent 973933ad7d
commit b42ba0faf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 84 additions and 118 deletions

View File

@ -61,8 +61,6 @@ const transientTTL = 10 * time.Second
type addPeerHandlerReq struct { type addPeerHandlerReq struct {
rp peer.ID rp peer.ID
localConnAddr ma.Multiaddr
remoteConnAddr ma.Multiaddr
resp chan *peerHandler resp chan *peerHandler
} }
@ -194,9 +192,7 @@ func (ids *IDService) loop() {
} }
if ids.Host.Network().Connectedness(rp) == network.Connected { if ids.Host.Network().Connectedness(rp) == network.Connected {
mes := &pb.Identify{} ph = newPeerHandler(rp, ids)
ids.populateMessage(mes, rp, addReq.localConnAddr, addReq.remoteConnAddr)
ph = newPeerHandler(rp, ids, mes)
ph.start() ph.start()
phs[rp] = ph phs[rp] = ph
addReq.resp <- ph addReq.resp <- ph
@ -378,7 +374,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
defer func() { defer func() {
helpers.FullClose(s) helpers.FullClose(s)
if ph != nil { if ph != nil {
ph.msgMu.RUnlock() ph.snapshotMu.RUnlock()
} }
}() }()
@ -386,8 +382,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
phCh := make(chan *peerHandler, 1) phCh := make(chan *peerHandler, 1)
select { select {
case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), c.LocalMultiaddr(), case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), phCh}:
c.RemoteMultiaddr(), phCh}:
case <-ids.ctx.Done(): case <-ids.ctx.Done():
return return
} }
@ -398,9 +393,11 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
return return
} }
ph.msgMu.RLock() ph.snapshotMu.RLock()
mes := &pb.Identify{}
ids.populateMessage(mes, c, ph.snapshot)
w := ggio.NewDelimitedWriter(s) w := ggio.NewDelimitedWriter(s)
w.WriteMsg(ph.idMsgSnapshot) w.WriteMsg(mes)
log.Debugf("%s sent message to %s %s", ID, c.RemotePeer(), c.RemoteMultiaddr()) log.Debugf("%s sent message to %s %s", ID, c.RemotePeer(), c.RemoteMultiaddr())
} }
@ -422,13 +419,29 @@ func (ids *IDService) handleIdentifyResponse(s network.Stream) {
ids.consumeMessage(&mes, c) ids.consumeMessage(&mes, c)
} }
func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, remoteAddr ma.Multiaddr) { func (ids *IDService) getSnapshot() *identifySnapshot {
// set protocols this node is currently handling snapshot := new(identifySnapshot)
protos := ids.Host.Mux().Protocols() if cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()); ok {
mes.Protocols = make([]string, len(protos)) snapshot.record = cab.GetPeerRecord(ids.Host.ID())
for i, p := range protos { if snapshot.record == nil {
mes.Protocols[i] = p log.Errorf("latest peer record does not exist. identify message incomplete!")
} }
}
snapshot.addrs = ids.Host.Addrs()
snapshot.protocols = ids.Host.Mux().Protocols()
return snapshot
}
func (ids *IDService) populateMessage(
mes *pb.Identify,
conn network.Conn,
snapshot *identifySnapshot,
) {
remoteAddr := conn.RemoteMultiaddr()
localAddr := conn.LocalMultiaddr()
// set protocols this node is currently handling
mes.Protocols = snapshot.protocols
// observed address so other side is informed of their // observed address so other side is informed of their
// "public" address, at least in relation to us. // "public" address, at least in relation to us.
@ -436,33 +449,22 @@ func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, r
// populate unsigned addresses. // populate unsigned addresses.
// peers that do not yet support signed addresses will need this. // peers that do not yet support signed addresses will need this.
// set listen addrs, get our latest addrs from Host.
laddrs := ids.Host.Addrs()
// Note: LocalMultiaddr is sometimes 0.0.0.0 // Note: LocalMultiaddr is sometimes 0.0.0.0
viaLoopback := manet.IsIPLoopback(localAddr) || manet.IsIPLoopback(remoteAddr) viaLoopback := manet.IsIPLoopback(localAddr) || manet.IsIPLoopback(remoteAddr)
mes.ListenAddrs = make([][]byte, 0, len(laddrs)) mes.ListenAddrs = make([][]byte, 0, len(snapshot.addrs))
for _, addr := range laddrs { for _, addr := range snapshot.addrs {
if !viaLoopback && manet.IsIPLoopback(addr) { if !viaLoopback && manet.IsIPLoopback(addr) {
continue continue
} }
mes.ListenAddrs = append(mes.ListenAddrs, addr.Bytes()) mes.ListenAddrs = append(mes.ListenAddrs, addr.Bytes())
} }
// populate signed record. recBytes, err := snapshot.record.Marshal()
cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore())
if ok {
rec := cab.GetPeerRecord(ids.Host.ID())
if rec == nil {
log.Errorf("latest peer record does not exist. identify message incomplete!")
} else {
recBytes, err := rec.Marshal()
if err != nil { if err != nil {
log.Errorf("error marshaling peer record: %v", err) log.Errorf("error marshaling peer record: %v", err)
} else { } else {
mes.SignedPeerRecord = recBytes mes.SignedPeerRecord = recBytes
log.Debugf("%s sent peer record to %s", ids.Host.ID(), rp) log.Debugf("%s sent peer record to %s", ids.Host.ID(), conn.RemotePeer())
}
}
} }
// set our public key // set our public key

View File

@ -490,8 +490,6 @@ func TestIdentifyDeltaOnProtocolChange(t *testing.T) {
lk.Unlock() lk.Unlock()
} }
} }
close(done)
}() }()
<-done <-done

View File

@ -11,14 +11,21 @@ import (
"github.com/libp2p/go-libp2p-core/network" "github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer" "github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol" "github.com/libp2p/go-libp2p-core/protocol"
"github.com/libp2p/go-libp2p-core/record"
pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb" pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
ggio "github.com/gogo/protobuf/io" ggio "github.com/gogo/protobuf/io"
ma "github.com/multiformats/go-multiaddr"
) )
var errProtocolNotSupported = errors.New("protocol not supported") var errProtocolNotSupported = errors.New("protocol not supported")
var isTesting = false
type identifySnapshot struct {
protocols []string
addrs []ma.Multiaddr
record *record.Envelope
}
type peerHandler struct { type peerHandler struct {
ids *IDService ids *IDService
@ -29,29 +36,24 @@ type peerHandler struct {
pid peer.ID pid peer.ID
msgMu sync.RWMutex snapshotMu sync.RWMutex
idMsgSnapshot *pb.Identify snapshot *identifySnapshot
pushCh chan struct{} pushCh chan struct{}
deltaCh chan struct{} deltaCh chan struct{}
evalTestCh chan func() // for testing
} }
func newPeerHandler(pid peer.ID, ids *IDService, initState *pb.Identify) *peerHandler { func newPeerHandler(pid peer.ID, ids *IDService) *peerHandler {
ph := &peerHandler{ ph := &peerHandler{
ids: ids, ids: ids,
pid: pid, pid: pid,
idMsgSnapshot: initState, snapshot: ids.getSnapshot(),
pushCh: make(chan struct{}, 1), pushCh: make(chan struct{}, 1),
deltaCh: make(chan struct{}, 1), deltaCh: make(chan struct{}, 1),
} }
if isTesting {
ph.evalTestCh = make(chan func())
}
return ph return ph
} }
@ -87,9 +89,6 @@ func (ph *peerHandler) loop() {
log.Warnw("failed to send Identify Delta", "peer", ph.pid, "error", err) log.Warnw("failed to send Identify Delta", "peer", ph.pid, "error", err)
} }
case fnc := <-ph.evalTestCh:
fnc()
case <-ph.ctx.Done(): case <-ph.ctx.Done():
return return
} }
@ -97,11 +96,6 @@ func (ph *peerHandler) loop() {
} }
func (ph *peerHandler) sendDelta() error { func (ph *peerHandler) sendDelta() error {
mes := ph.mkDelta()
if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) {
return nil
}
// send a push if the peer does not support the Delta protocol. // send a push if the peer does not support the Delta protocol.
if !ph.peerSupportsProtos([]string{IDDelta}) { if !ph.peerSupportsProtos([]string{IDDelta}) {
log.Debugw("will send push as peer does not support delta", "peer", ph.pid) log.Debugw("will send push as peer does not support delta", "peer", ph.pid)
@ -111,10 +105,11 @@ func (ph *peerHandler) sendDelta() error {
return nil return nil
} }
ph.msgMu.Lock() // extract a delta message, updating the last state.
// update our identify snapshot for this peer by applying the delta to it mes := ph.nextDelta()
ph.applyDelta(mes) if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) {
ph.msgMu.Unlock() return nil
}
ds, err := ph.openStream([]string{IDDelta}) ds, err := ph.openStream([]string{IDDelta})
if err != nil { if err != nil {
@ -139,11 +134,13 @@ func (ph *peerHandler) sendPush() error {
conn := dp.Conn() conn := dp.Conn()
mes := &pb.Identify{} mes := &pb.Identify{}
ph.ids.populateMessage(mes, ph.pid, conn.LocalMultiaddr(), conn.RemoteMultiaddr())
ph.msgMu.Lock() snapshot := ph.ids.getSnapshot()
ph.idMsgSnapshot = mes ph.snapshotMu.Lock()
ph.msgMu.Unlock() ph.snapshot = snapshot
ph.snapshotMu.Unlock()
ph.ids.populateMessage(mes, conn, snapshot)
if err := ph.sendMessage(dp, mes); err != nil { if err := ph.sendMessage(dp, mes); err != nil {
return fmt.Errorf("failed to send push message: %w", err) return fmt.Errorf("failed to send push message: %w", err)
@ -151,21 +148,6 @@ func (ph *peerHandler) sendPush() error {
return nil return nil
} }
func (ph *peerHandler) applyDelta(mes *pb.Delta) {
for _, p1 := range mes.RmProtocols {
for j, p2 := range ph.idMsgSnapshot.Protocols {
if p2 == p1 {
ph.idMsgSnapshot.Protocols[j] = ph.idMsgSnapshot.Protocols[len(ph.idMsgSnapshot.Protocols)-1]
ph.idMsgSnapshot.Protocols = ph.idMsgSnapshot.Protocols[:len(ph.idMsgSnapshot.Protocols)-1]
}
}
}
for _, p := range mes.AddedProtocols {
ph.idMsgSnapshot.Protocols = append(ph.idMsgSnapshot.Protocols, p)
}
}
func (ph *peerHandler) openStream(protos []string) (network.Stream, error) { func (ph *peerHandler) openStream(protos []string) (network.Stream, error) {
// wait for the other peer to send us an Identify response on "all" connections we have with it // wait for the other peer to send us an Identify response on "all" connections we have with it
// so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol // so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol
@ -217,10 +199,18 @@ func (ph *peerHandler) peerSupportsProtos(protos []string) bool {
return true return true
} }
func (ph *peerHandler) mkDelta() *pb.Delta { func (ph *peerHandler) nextDelta() *pb.Delta {
old := ph.idMsgSnapshot.GetProtocols()
curr := ph.ids.Host.Mux().Protocols() curr := ph.ids.Host.Mux().Protocols()
// Extract the old protocol list and replace the old snapshot with an
// updated one.
ph.snapshotMu.Lock()
snapshot := *ph.snapshot
old := snapshot.protocols
snapshot.protocols = curr
ph.snapshot = &snapshot
ph.snapshotMu.Unlock()
oldProtos := make(map[string]struct{}, len(old)) oldProtos := make(map[string]struct{}, len(old))
currProtos := make(map[string]struct{}, len(curr)) currProtos := make(map[string]struct{}, len(curr))

View File

@ -9,72 +9,49 @@ import (
blhost "github.com/libp2p/go-libp2p-blankhost" blhost "github.com/libp2p/go-libp2p-blankhost"
swarmt "github.com/libp2p/go-libp2p-swarm/testing" swarmt "github.com/libp2p/go-libp2p-swarm/testing"
pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
) )
func doeval(t *testing.T, ph *peerHandler, f func()) {
done := make(chan struct{}, 1)
ph.evalTestCh <- func() {
f()
done <- struct{}{}
}
<-done
}
func TestMakeApplyDelta(t *testing.T) { func TestMakeApplyDelta(t *testing.T) {
isTesting = true
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
defer h1.Close() defer h1.Close()
ids1 := NewIDService(h1) ids1 := NewIDService(h1)
ph := newPeerHandler(h1.ID(), ids1, &pb.Identify{}) ph := newPeerHandler(h1.ID(), ids1)
ph.start() ph.start()
defer ph.close() defer ph.close()
m1 := ph.mkDelta() m1 := ph.nextDelta()
require.NotNil(t, m1) require.NotNil(t, m1)
// all the Id protocols must have been added // We haven't changed anything since creating the peer handler
require.NotEmpty(t, m1.AddedProtocols) require.Empty(t, m1.AddedProtocols)
doeval(t, ph, func() {
ph.applyDelta(m1)
})
h1.SetStreamHandler("p1", func(network.Stream) {}) h1.SetStreamHandler("p1", func(network.Stream) {})
m2 := ph.mkDelta() m2 := ph.nextDelta()
require.Len(t, m2.AddedProtocols, 1) require.Len(t, m2.AddedProtocols, 1)
require.Contains(t, m2.AddedProtocols, "p1") require.Contains(t, m2.AddedProtocols, "p1")
require.Empty(t, m2.RmProtocols) require.Empty(t, m2.RmProtocols)
doeval(t, ph, func() {
ph.applyDelta(m2)
})
h1.SetStreamHandler("p2", func(network.Stream) {}) h1.SetStreamHandler("p2", func(network.Stream) {})
h1.SetStreamHandler("p3", func(stream network.Stream) {}) h1.SetStreamHandler("p3", func(stream network.Stream) {})
m3 := ph.mkDelta() m3 := ph.nextDelta()
require.Len(t, m3.AddedProtocols, 2) require.Len(t, m3.AddedProtocols, 2)
require.Contains(t, m3.AddedProtocols, "p2") require.Contains(t, m3.AddedProtocols, "p2")
require.Contains(t, m3.AddedProtocols, "p3") require.Contains(t, m3.AddedProtocols, "p3")
require.Empty(t, m3.RmProtocols) require.Empty(t, m3.RmProtocols)
doeval(t, ph, func() {
ph.applyDelta(m3)
})
h1.RemoveStreamHandler("p3") h1.RemoveStreamHandler("p3")
m4 := ph.mkDelta() m4 := ph.nextDelta()
require.Empty(t, m4.AddedProtocols) require.Empty(t, m4.AddedProtocols)
require.Len(t, m4.RmProtocols, 1) require.Len(t, m4.RmProtocols, 1)
require.Contains(t, m4.RmProtocols, "p3") require.Contains(t, m4.RmProtocols, "p3")
doeval(t, ph, func() {
ph.applyDelta(m4)
})
h1.RemoveStreamHandler("p2") h1.RemoveStreamHandler("p2")
h1.RemoveStreamHandler("p1") h1.RemoveStreamHandler("p1")
m5 := ph.mkDelta() m5 := ph.nextDelta()
require.Empty(t, m5.AddedProtocols) require.Empty(t, m5.AddedProtocols)
require.Len(t, m5.RmProtocols, 2) require.Len(t, m5.RmProtocols, 2)
require.Contains(t, m5.RmProtocols, "p2") require.Contains(t, m5.RmProtocols, "p2")
@ -82,14 +59,13 @@ func TestMakeApplyDelta(t *testing.T) {
} }
func TestHandlerClose(t *testing.T) { func TestHandlerClose(t *testing.T) {
isTesting = true
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
defer cancel() defer cancel()
h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx)) h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
defer h1.Close() defer h1.Close()
ids1 := NewIDService(h1) ids1 := NewIDService(h1)
ph := newPeerHandler(h1.ID(), ids1, nil) ph := newPeerHandler(h1.ID(), ids1)
ph.start() ph.start()
require.NoError(t, ph.close()) require.NoError(t, ph.close())
@ -104,7 +80,7 @@ func TestPeerSupportsProto(t *testing.T) {
ids1 := NewIDService(h1) ids1 := NewIDService(h1)
rp := peer.ID("test") rp := peer.ID("test")
ph := newPeerHandler(rp, ids1, nil) ph := newPeerHandler(rp, ids1)
require.NoError(t, h1.Peerstore().AddProtocols(rp, "test")) require.NoError(t, h1.Peerstore().AddProtocols(rp, "test"))
require.True(t, ph.peerSupportsProtos([]string{"test"})) require.True(t, ph.peerSupportsProtos([]string{"test"}))
require.False(t, ph.peerSupportsProtos([]string{"random"})) require.False(t, ph.peerSupportsProtos([]string{"random"}))