various identify fixes and nits (#922)
* various identify fixes and nits Co-authored-by: Aarsh Shah <aarshkshah1992@gmail.com>
This commit is contained in:
parent
973933ad7d
commit
b42ba0faf3
|
@ -61,8 +61,6 @@ const transientTTL = 10 * time.Second
|
|||
|
||||
type addPeerHandlerReq struct {
|
||||
rp peer.ID
|
||||
localConnAddr ma.Multiaddr
|
||||
remoteConnAddr ma.Multiaddr
|
||||
resp chan *peerHandler
|
||||
}
|
||||
|
||||
|
@ -194,9 +192,7 @@ func (ids *IDService) loop() {
|
|||
}
|
||||
|
||||
if ids.Host.Network().Connectedness(rp) == network.Connected {
|
||||
mes := &pb.Identify{}
|
||||
ids.populateMessage(mes, rp, addReq.localConnAddr, addReq.remoteConnAddr)
|
||||
ph = newPeerHandler(rp, ids, mes)
|
||||
ph = newPeerHandler(rp, ids)
|
||||
ph.start()
|
||||
phs[rp] = ph
|
||||
addReq.resp <- ph
|
||||
|
@ -378,7 +374,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
|
|||
defer func() {
|
||||
helpers.FullClose(s)
|
||||
if ph != nil {
|
||||
ph.msgMu.RUnlock()
|
||||
ph.snapshotMu.RUnlock()
|
||||
}
|
||||
}()
|
||||
|
||||
|
@ -386,8 +382,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
|
|||
|
||||
phCh := make(chan *peerHandler, 1)
|
||||
select {
|
||||
case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), c.LocalMultiaddr(),
|
||||
c.RemoteMultiaddr(), phCh}:
|
||||
case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), phCh}:
|
||||
case <-ids.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
@ -398,9 +393,11 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
|
|||
return
|
||||
}
|
||||
|
||||
ph.msgMu.RLock()
|
||||
ph.snapshotMu.RLock()
|
||||
mes := &pb.Identify{}
|
||||
ids.populateMessage(mes, c, ph.snapshot)
|
||||
w := ggio.NewDelimitedWriter(s)
|
||||
w.WriteMsg(ph.idMsgSnapshot)
|
||||
w.WriteMsg(mes)
|
||||
|
||||
log.Debugf("%s sent message to %s %s", ID, c.RemotePeer(), c.RemoteMultiaddr())
|
||||
}
|
||||
|
@ -422,13 +419,29 @@ func (ids *IDService) handleIdentifyResponse(s network.Stream) {
|
|||
ids.consumeMessage(&mes, c)
|
||||
}
|
||||
|
||||
func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, remoteAddr ma.Multiaddr) {
|
||||
// set protocols this node is currently handling
|
||||
protos := ids.Host.Mux().Protocols()
|
||||
mes.Protocols = make([]string, len(protos))
|
||||
for i, p := range protos {
|
||||
mes.Protocols[i] = p
|
||||
func (ids *IDService) getSnapshot() *identifySnapshot {
|
||||
snapshot := new(identifySnapshot)
|
||||
if cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()); ok {
|
||||
snapshot.record = cab.GetPeerRecord(ids.Host.ID())
|
||||
if snapshot.record == nil {
|
||||
log.Errorf("latest peer record does not exist. identify message incomplete!")
|
||||
}
|
||||
}
|
||||
snapshot.addrs = ids.Host.Addrs()
|
||||
snapshot.protocols = ids.Host.Mux().Protocols()
|
||||
return snapshot
|
||||
}
|
||||
|
||||
func (ids *IDService) populateMessage(
|
||||
mes *pb.Identify,
|
||||
conn network.Conn,
|
||||
snapshot *identifySnapshot,
|
||||
) {
|
||||
remoteAddr := conn.RemoteMultiaddr()
|
||||
localAddr := conn.LocalMultiaddr()
|
||||
|
||||
// set protocols this node is currently handling
|
||||
mes.Protocols = snapshot.protocols
|
||||
|
||||
// observed address so other side is informed of their
|
||||
// "public" address, at least in relation to us.
|
||||
|
@ -436,33 +449,22 @@ func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, r
|
|||
|
||||
// populate unsigned addresses.
|
||||
// peers that do not yet support signed addresses will need this.
|
||||
// set listen addrs, get our latest addrs from Host.
|
||||
laddrs := ids.Host.Addrs()
|
||||
// Note: LocalMultiaddr is sometimes 0.0.0.0
|
||||
viaLoopback := manet.IsIPLoopback(localAddr) || manet.IsIPLoopback(remoteAddr)
|
||||
mes.ListenAddrs = make([][]byte, 0, len(laddrs))
|
||||
for _, addr := range laddrs {
|
||||
mes.ListenAddrs = make([][]byte, 0, len(snapshot.addrs))
|
||||
for _, addr := range snapshot.addrs {
|
||||
if !viaLoopback && manet.IsIPLoopback(addr) {
|
||||
continue
|
||||
}
|
||||
mes.ListenAddrs = append(mes.ListenAddrs, addr.Bytes())
|
||||
}
|
||||
|
||||
// populate signed record.
|
||||
cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore())
|
||||
if ok {
|
||||
rec := cab.GetPeerRecord(ids.Host.ID())
|
||||
if rec == nil {
|
||||
log.Errorf("latest peer record does not exist. identify message incomplete!")
|
||||
} else {
|
||||
recBytes, err := rec.Marshal()
|
||||
recBytes, err := snapshot.record.Marshal()
|
||||
if err != nil {
|
||||
log.Errorf("error marshaling peer record: %v", err)
|
||||
} else {
|
||||
mes.SignedPeerRecord = recBytes
|
||||
log.Debugf("%s sent peer record to %s", ids.Host.ID(), rp)
|
||||
}
|
||||
}
|
||||
log.Debugf("%s sent peer record to %s", ids.Host.ID(), conn.RemotePeer())
|
||||
}
|
||||
|
||||
// set our public key
|
||||
|
|
|
@ -490,8 +490,6 @@ func TestIdentifyDeltaOnProtocolChange(t *testing.T) {
|
|||
lk.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
close(done)
|
||||
}()
|
||||
|
||||
<-done
|
||||
|
|
|
@ -11,14 +11,21 @@ import (
|
|||
"github.com/libp2p/go-libp2p-core/network"
|
||||
"github.com/libp2p/go-libp2p-core/peer"
|
||||
"github.com/libp2p/go-libp2p-core/protocol"
|
||||
"github.com/libp2p/go-libp2p-core/record"
|
||||
|
||||
pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
|
||||
|
||||
ggio "github.com/gogo/protobuf/io"
|
||||
ma "github.com/multiformats/go-multiaddr"
|
||||
)
|
||||
|
||||
var errProtocolNotSupported = errors.New("protocol not supported")
|
||||
var isTesting = false
|
||||
|
||||
type identifySnapshot struct {
|
||||
protocols []string
|
||||
addrs []ma.Multiaddr
|
||||
record *record.Envelope
|
||||
}
|
||||
|
||||
type peerHandler struct {
|
||||
ids *IDService
|
||||
|
@ -29,29 +36,24 @@ type peerHandler struct {
|
|||
|
||||
pid peer.ID
|
||||
|
||||
msgMu sync.RWMutex
|
||||
idMsgSnapshot *pb.Identify
|
||||
snapshotMu sync.RWMutex
|
||||
snapshot *identifySnapshot
|
||||
|
||||
pushCh chan struct{}
|
||||
deltaCh chan struct{}
|
||||
evalTestCh chan func() // for testing
|
||||
}
|
||||
|
||||
func newPeerHandler(pid peer.ID, ids *IDService, initState *pb.Identify) *peerHandler {
|
||||
func newPeerHandler(pid peer.ID, ids *IDService) *peerHandler {
|
||||
ph := &peerHandler{
|
||||
ids: ids,
|
||||
pid: pid,
|
||||
|
||||
idMsgSnapshot: initState,
|
||||
snapshot: ids.getSnapshot(),
|
||||
|
||||
pushCh: make(chan struct{}, 1),
|
||||
deltaCh: make(chan struct{}, 1),
|
||||
}
|
||||
|
||||
if isTesting {
|
||||
ph.evalTestCh = make(chan func())
|
||||
}
|
||||
|
||||
return ph
|
||||
}
|
||||
|
||||
|
@ -87,9 +89,6 @@ func (ph *peerHandler) loop() {
|
|||
log.Warnw("failed to send Identify Delta", "peer", ph.pid, "error", err)
|
||||
}
|
||||
|
||||
case fnc := <-ph.evalTestCh:
|
||||
fnc()
|
||||
|
||||
case <-ph.ctx.Done():
|
||||
return
|
||||
}
|
||||
|
@ -97,11 +96,6 @@ func (ph *peerHandler) loop() {
|
|||
}
|
||||
|
||||
func (ph *peerHandler) sendDelta() error {
|
||||
mes := ph.mkDelta()
|
||||
if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// send a push if the peer does not support the Delta protocol.
|
||||
if !ph.peerSupportsProtos([]string{IDDelta}) {
|
||||
log.Debugw("will send push as peer does not support delta", "peer", ph.pid)
|
||||
|
@ -111,10 +105,11 @@ func (ph *peerHandler) sendDelta() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
ph.msgMu.Lock()
|
||||
// update our identify snapshot for this peer by applying the delta to it
|
||||
ph.applyDelta(mes)
|
||||
ph.msgMu.Unlock()
|
||||
// extract a delta message, updating the last state.
|
||||
mes := ph.nextDelta()
|
||||
if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) {
|
||||
return nil
|
||||
}
|
||||
|
||||
ds, err := ph.openStream([]string{IDDelta})
|
||||
if err != nil {
|
||||
|
@ -139,11 +134,13 @@ func (ph *peerHandler) sendPush() error {
|
|||
|
||||
conn := dp.Conn()
|
||||
mes := &pb.Identify{}
|
||||
ph.ids.populateMessage(mes, ph.pid, conn.LocalMultiaddr(), conn.RemoteMultiaddr())
|
||||
|
||||
ph.msgMu.Lock()
|
||||
ph.idMsgSnapshot = mes
|
||||
ph.msgMu.Unlock()
|
||||
snapshot := ph.ids.getSnapshot()
|
||||
ph.snapshotMu.Lock()
|
||||
ph.snapshot = snapshot
|
||||
ph.snapshotMu.Unlock()
|
||||
|
||||
ph.ids.populateMessage(mes, conn, snapshot)
|
||||
|
||||
if err := ph.sendMessage(dp, mes); err != nil {
|
||||
return fmt.Errorf("failed to send push message: %w", err)
|
||||
|
@ -151,21 +148,6 @@ func (ph *peerHandler) sendPush() error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (ph *peerHandler) applyDelta(mes *pb.Delta) {
|
||||
for _, p1 := range mes.RmProtocols {
|
||||
for j, p2 := range ph.idMsgSnapshot.Protocols {
|
||||
if p2 == p1 {
|
||||
ph.idMsgSnapshot.Protocols[j] = ph.idMsgSnapshot.Protocols[len(ph.idMsgSnapshot.Protocols)-1]
|
||||
ph.idMsgSnapshot.Protocols = ph.idMsgSnapshot.Protocols[:len(ph.idMsgSnapshot.Protocols)-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for _, p := range mes.AddedProtocols {
|
||||
ph.idMsgSnapshot.Protocols = append(ph.idMsgSnapshot.Protocols, p)
|
||||
}
|
||||
}
|
||||
|
||||
func (ph *peerHandler) openStream(protos []string) (network.Stream, error) {
|
||||
// wait for the other peer to send us an Identify response on "all" connections we have with it
|
||||
// so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol
|
||||
|
@ -217,10 +199,18 @@ func (ph *peerHandler) peerSupportsProtos(protos []string) bool {
|
|||
return true
|
||||
}
|
||||
|
||||
func (ph *peerHandler) mkDelta() *pb.Delta {
|
||||
old := ph.idMsgSnapshot.GetProtocols()
|
||||
func (ph *peerHandler) nextDelta() *pb.Delta {
|
||||
curr := ph.ids.Host.Mux().Protocols()
|
||||
|
||||
// Extract the old protocol list and replace the old snapshot with an
|
||||
// updated one.
|
||||
ph.snapshotMu.Lock()
|
||||
snapshot := *ph.snapshot
|
||||
old := snapshot.protocols
|
||||
snapshot.protocols = curr
|
||||
ph.snapshot = &snapshot
|
||||
ph.snapshotMu.Unlock()
|
||||
|
||||
oldProtos := make(map[string]struct{}, len(old))
|
||||
currProtos := make(map[string]struct{}, len(curr))
|
||||
|
||||
|
|
|
@ -9,72 +9,49 @@ import (
|
|||
|
||||
blhost "github.com/libp2p/go-libp2p-blankhost"
|
||||
swarmt "github.com/libp2p/go-libp2p-swarm/testing"
|
||||
pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func doeval(t *testing.T, ph *peerHandler, f func()) {
|
||||
done := make(chan struct{}, 1)
|
||||
ph.evalTestCh <- func() {
|
||||
f()
|
||||
done <- struct{}{}
|
||||
}
|
||||
<-done
|
||||
}
|
||||
|
||||
func TestMakeApplyDelta(t *testing.T) {
|
||||
isTesting = true
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
|
||||
defer h1.Close()
|
||||
ids1 := NewIDService(h1)
|
||||
ph := newPeerHandler(h1.ID(), ids1, &pb.Identify{})
|
||||
ph := newPeerHandler(h1.ID(), ids1)
|
||||
ph.start()
|
||||
defer ph.close()
|
||||
|
||||
m1 := ph.mkDelta()
|
||||
m1 := ph.nextDelta()
|
||||
require.NotNil(t, m1)
|
||||
// all the Id protocols must have been added
|
||||
require.NotEmpty(t, m1.AddedProtocols)
|
||||
doeval(t, ph, func() {
|
||||
ph.applyDelta(m1)
|
||||
})
|
||||
// We haven't changed anything since creating the peer handler
|
||||
require.Empty(t, m1.AddedProtocols)
|
||||
|
||||
h1.SetStreamHandler("p1", func(network.Stream) {})
|
||||
m2 := ph.mkDelta()
|
||||
m2 := ph.nextDelta()
|
||||
require.Len(t, m2.AddedProtocols, 1)
|
||||
require.Contains(t, m2.AddedProtocols, "p1")
|
||||
require.Empty(t, m2.RmProtocols)
|
||||
doeval(t, ph, func() {
|
||||
ph.applyDelta(m2)
|
||||
})
|
||||
|
||||
h1.SetStreamHandler("p2", func(network.Stream) {})
|
||||
h1.SetStreamHandler("p3", func(stream network.Stream) {})
|
||||
m3 := ph.mkDelta()
|
||||
m3 := ph.nextDelta()
|
||||
require.Len(t, m3.AddedProtocols, 2)
|
||||
require.Contains(t, m3.AddedProtocols, "p2")
|
||||
require.Contains(t, m3.AddedProtocols, "p3")
|
||||
require.Empty(t, m3.RmProtocols)
|
||||
doeval(t, ph, func() {
|
||||
ph.applyDelta(m3)
|
||||
})
|
||||
|
||||
h1.RemoveStreamHandler("p3")
|
||||
m4 := ph.mkDelta()
|
||||
m4 := ph.nextDelta()
|
||||
require.Empty(t, m4.AddedProtocols)
|
||||
require.Len(t, m4.RmProtocols, 1)
|
||||
require.Contains(t, m4.RmProtocols, "p3")
|
||||
doeval(t, ph, func() {
|
||||
ph.applyDelta(m4)
|
||||
})
|
||||
|
||||
h1.RemoveStreamHandler("p2")
|
||||
h1.RemoveStreamHandler("p1")
|
||||
m5 := ph.mkDelta()
|
||||
m5 := ph.nextDelta()
|
||||
require.Empty(t, m5.AddedProtocols)
|
||||
require.Len(t, m5.RmProtocols, 2)
|
||||
require.Contains(t, m5.RmProtocols, "p2")
|
||||
|
@ -82,14 +59,13 @@ func TestMakeApplyDelta(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestHandlerClose(t *testing.T) {
|
||||
isTesting = true
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
|
||||
defer h1.Close()
|
||||
ids1 := NewIDService(h1)
|
||||
ph := newPeerHandler(h1.ID(), ids1, nil)
|
||||
ph := newPeerHandler(h1.ID(), ids1)
|
||||
ph.start()
|
||||
|
||||
require.NoError(t, ph.close())
|
||||
|
@ -104,7 +80,7 @@ func TestPeerSupportsProto(t *testing.T) {
|
|||
ids1 := NewIDService(h1)
|
||||
|
||||
rp := peer.ID("test")
|
||||
ph := newPeerHandler(rp, ids1, nil)
|
||||
ph := newPeerHandler(rp, ids1)
|
||||
require.NoError(t, h1.Peerstore().AddProtocols(rp, "test"))
|
||||
require.True(t, ph.peerSupportsProtos([]string{"test"}))
|
||||
require.False(t, ph.peerSupportsProtos([]string{"random"}))
|
||||
|
|
Loading…
Reference in New Issue