various identify fixes and nits (#922)

* various identify fixes and nits

Co-authored-by: Aarsh Shah <aarshkshah1992@gmail.com>
This commit is contained in:
Steven Allen 2020-05-14 04:54:10 -07:00 committed by GitHub
parent 973933ad7d
commit b42ba0faf3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 84 additions and 118 deletions

View File

@ -61,8 +61,6 @@ const transientTTL = 10 * time.Second
type addPeerHandlerReq struct {
rp peer.ID
localConnAddr ma.Multiaddr
remoteConnAddr ma.Multiaddr
resp chan *peerHandler
}
@ -194,9 +192,7 @@ func (ids *IDService) loop() {
}
if ids.Host.Network().Connectedness(rp) == network.Connected {
mes := &pb.Identify{}
ids.populateMessage(mes, rp, addReq.localConnAddr, addReq.remoteConnAddr)
ph = newPeerHandler(rp, ids, mes)
ph = newPeerHandler(rp, ids)
ph.start()
phs[rp] = ph
addReq.resp <- ph
@ -378,7 +374,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
defer func() {
helpers.FullClose(s)
if ph != nil {
ph.msgMu.RUnlock()
ph.snapshotMu.RUnlock()
}
}()
@ -386,8 +382,7 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
phCh := make(chan *peerHandler, 1)
select {
case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), c.LocalMultiaddr(),
c.RemoteMultiaddr(), phCh}:
case ids.addPeerHandlerCh <- addPeerHandlerReq{c.RemotePeer(), phCh}:
case <-ids.ctx.Done():
return
}
@ -398,9 +393,11 @@ func (ids *IDService) sendIdentifyResp(s network.Stream) {
return
}
ph.msgMu.RLock()
ph.snapshotMu.RLock()
mes := &pb.Identify{}
ids.populateMessage(mes, c, ph.snapshot)
w := ggio.NewDelimitedWriter(s)
w.WriteMsg(ph.idMsgSnapshot)
w.WriteMsg(mes)
log.Debugf("%s sent message to %s %s", ID, c.RemotePeer(), c.RemoteMultiaddr())
}
@ -422,13 +419,29 @@ func (ids *IDService) handleIdentifyResponse(s network.Stream) {
ids.consumeMessage(&mes, c)
}
func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, remoteAddr ma.Multiaddr) {
// set protocols this node is currently handling
protos := ids.Host.Mux().Protocols()
mes.Protocols = make([]string, len(protos))
for i, p := range protos {
mes.Protocols[i] = p
func (ids *IDService) getSnapshot() *identifySnapshot {
snapshot := new(identifySnapshot)
if cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore()); ok {
snapshot.record = cab.GetPeerRecord(ids.Host.ID())
if snapshot.record == nil {
log.Errorf("latest peer record does not exist. identify message incomplete!")
}
}
snapshot.addrs = ids.Host.Addrs()
snapshot.protocols = ids.Host.Mux().Protocols()
return snapshot
}
func (ids *IDService) populateMessage(
mes *pb.Identify,
conn network.Conn,
snapshot *identifySnapshot,
) {
remoteAddr := conn.RemoteMultiaddr()
localAddr := conn.LocalMultiaddr()
// set protocols this node is currently handling
mes.Protocols = snapshot.protocols
// observed address so other side is informed of their
// "public" address, at least in relation to us.
@ -436,33 +449,22 @@ func (ids *IDService) populateMessage(mes *pb.Identify, rp peer.ID, localAddr, r
// populate unsigned addresses.
// peers that do not yet support signed addresses will need this.
// set listen addrs, get our latest addrs from Host.
laddrs := ids.Host.Addrs()
// Note: LocalMultiaddr is sometimes 0.0.0.0
viaLoopback := manet.IsIPLoopback(localAddr) || manet.IsIPLoopback(remoteAddr)
mes.ListenAddrs = make([][]byte, 0, len(laddrs))
for _, addr := range laddrs {
mes.ListenAddrs = make([][]byte, 0, len(snapshot.addrs))
for _, addr := range snapshot.addrs {
if !viaLoopback && manet.IsIPLoopback(addr) {
continue
}
mes.ListenAddrs = append(mes.ListenAddrs, addr.Bytes())
}
// populate signed record.
cab, ok := peerstore.GetCertifiedAddrBook(ids.Host.Peerstore())
if ok {
rec := cab.GetPeerRecord(ids.Host.ID())
if rec == nil {
log.Errorf("latest peer record does not exist. identify message incomplete!")
} else {
recBytes, err := rec.Marshal()
recBytes, err := snapshot.record.Marshal()
if err != nil {
log.Errorf("error marshaling peer record: %v", err)
} else {
mes.SignedPeerRecord = recBytes
log.Debugf("%s sent peer record to %s", ids.Host.ID(), rp)
}
}
log.Debugf("%s sent peer record to %s", ids.Host.ID(), conn.RemotePeer())
}
// set our public key

View File

@ -490,8 +490,6 @@ func TestIdentifyDeltaOnProtocolChange(t *testing.T) {
lk.Unlock()
}
}
close(done)
}()
<-done

View File

@ -11,14 +11,21 @@ import (
"github.com/libp2p/go-libp2p-core/network"
"github.com/libp2p/go-libp2p-core/peer"
"github.com/libp2p/go-libp2p-core/protocol"
"github.com/libp2p/go-libp2p-core/record"
pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
ggio "github.com/gogo/protobuf/io"
ma "github.com/multiformats/go-multiaddr"
)
var errProtocolNotSupported = errors.New("protocol not supported")
var isTesting = false
type identifySnapshot struct {
protocols []string
addrs []ma.Multiaddr
record *record.Envelope
}
type peerHandler struct {
ids *IDService
@ -29,29 +36,24 @@ type peerHandler struct {
pid peer.ID
msgMu sync.RWMutex
idMsgSnapshot *pb.Identify
snapshotMu sync.RWMutex
snapshot *identifySnapshot
pushCh chan struct{}
deltaCh chan struct{}
evalTestCh chan func() // for testing
}
func newPeerHandler(pid peer.ID, ids *IDService, initState *pb.Identify) *peerHandler {
func newPeerHandler(pid peer.ID, ids *IDService) *peerHandler {
ph := &peerHandler{
ids: ids,
pid: pid,
idMsgSnapshot: initState,
snapshot: ids.getSnapshot(),
pushCh: make(chan struct{}, 1),
deltaCh: make(chan struct{}, 1),
}
if isTesting {
ph.evalTestCh = make(chan func())
}
return ph
}
@ -87,9 +89,6 @@ func (ph *peerHandler) loop() {
log.Warnw("failed to send Identify Delta", "peer", ph.pid, "error", err)
}
case fnc := <-ph.evalTestCh:
fnc()
case <-ph.ctx.Done():
return
}
@ -97,11 +96,6 @@ func (ph *peerHandler) loop() {
}
func (ph *peerHandler) sendDelta() error {
mes := ph.mkDelta()
if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) {
return nil
}
// send a push if the peer does not support the Delta protocol.
if !ph.peerSupportsProtos([]string{IDDelta}) {
log.Debugw("will send push as peer does not support delta", "peer", ph.pid)
@ -111,10 +105,11 @@ func (ph *peerHandler) sendDelta() error {
return nil
}
ph.msgMu.Lock()
// update our identify snapshot for this peer by applying the delta to it
ph.applyDelta(mes)
ph.msgMu.Unlock()
// extract a delta message, updating the last state.
mes := ph.nextDelta()
if mes == nil || (len(mes.AddedProtocols) == 0 && len(mes.RmProtocols) == 0) {
return nil
}
ds, err := ph.openStream([]string{IDDelta})
if err != nil {
@ -139,11 +134,13 @@ func (ph *peerHandler) sendPush() error {
conn := dp.Conn()
mes := &pb.Identify{}
ph.ids.populateMessage(mes, ph.pid, conn.LocalMultiaddr(), conn.RemoteMultiaddr())
ph.msgMu.Lock()
ph.idMsgSnapshot = mes
ph.msgMu.Unlock()
snapshot := ph.ids.getSnapshot()
ph.snapshotMu.Lock()
ph.snapshot = snapshot
ph.snapshotMu.Unlock()
ph.ids.populateMessage(mes, conn, snapshot)
if err := ph.sendMessage(dp, mes); err != nil {
return fmt.Errorf("failed to send push message: %w", err)
@ -151,21 +148,6 @@ func (ph *peerHandler) sendPush() error {
return nil
}
func (ph *peerHandler) applyDelta(mes *pb.Delta) {
for _, p1 := range mes.RmProtocols {
for j, p2 := range ph.idMsgSnapshot.Protocols {
if p2 == p1 {
ph.idMsgSnapshot.Protocols[j] = ph.idMsgSnapshot.Protocols[len(ph.idMsgSnapshot.Protocols)-1]
ph.idMsgSnapshot.Protocols = ph.idMsgSnapshot.Protocols[:len(ph.idMsgSnapshot.Protocols)-1]
}
}
}
for _, p := range mes.AddedProtocols {
ph.idMsgSnapshot.Protocols = append(ph.idMsgSnapshot.Protocols, p)
}
}
func (ph *peerHandler) openStream(protos []string) (network.Stream, error) {
// wait for the other peer to send us an Identify response on "all" connections we have with it
// so we can look at it's supported protocols and avoid a multistream-select roundtrip to negotiate the protocol
@ -217,10 +199,18 @@ func (ph *peerHandler) peerSupportsProtos(protos []string) bool {
return true
}
func (ph *peerHandler) mkDelta() *pb.Delta {
old := ph.idMsgSnapshot.GetProtocols()
func (ph *peerHandler) nextDelta() *pb.Delta {
curr := ph.ids.Host.Mux().Protocols()
// Extract the old protocol list and replace the old snapshot with an
// updated one.
ph.snapshotMu.Lock()
snapshot := *ph.snapshot
old := snapshot.protocols
snapshot.protocols = curr
ph.snapshot = &snapshot
ph.snapshotMu.Unlock()
oldProtos := make(map[string]struct{}, len(old))
currProtos := make(map[string]struct{}, len(curr))

View File

@ -9,72 +9,49 @@ import (
blhost "github.com/libp2p/go-libp2p-blankhost"
swarmt "github.com/libp2p/go-libp2p-swarm/testing"
pb "github.com/libp2p/go-libp2p/p2p/protocol/identify/pb"
"github.com/stretchr/testify/require"
)
func doeval(t *testing.T, ph *peerHandler, f func()) {
done := make(chan struct{}, 1)
ph.evalTestCh <- func() {
f()
done <- struct{}{}
}
<-done
}
func TestMakeApplyDelta(t *testing.T) {
isTesting = true
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
defer h1.Close()
ids1 := NewIDService(h1)
ph := newPeerHandler(h1.ID(), ids1, &pb.Identify{})
ph := newPeerHandler(h1.ID(), ids1)
ph.start()
defer ph.close()
m1 := ph.mkDelta()
m1 := ph.nextDelta()
require.NotNil(t, m1)
// all the Id protocols must have been added
require.NotEmpty(t, m1.AddedProtocols)
doeval(t, ph, func() {
ph.applyDelta(m1)
})
// We haven't changed anything since creating the peer handler
require.Empty(t, m1.AddedProtocols)
h1.SetStreamHandler("p1", func(network.Stream) {})
m2 := ph.mkDelta()
m2 := ph.nextDelta()
require.Len(t, m2.AddedProtocols, 1)
require.Contains(t, m2.AddedProtocols, "p1")
require.Empty(t, m2.RmProtocols)
doeval(t, ph, func() {
ph.applyDelta(m2)
})
h1.SetStreamHandler("p2", func(network.Stream) {})
h1.SetStreamHandler("p3", func(stream network.Stream) {})
m3 := ph.mkDelta()
m3 := ph.nextDelta()
require.Len(t, m3.AddedProtocols, 2)
require.Contains(t, m3.AddedProtocols, "p2")
require.Contains(t, m3.AddedProtocols, "p3")
require.Empty(t, m3.RmProtocols)
doeval(t, ph, func() {
ph.applyDelta(m3)
})
h1.RemoveStreamHandler("p3")
m4 := ph.mkDelta()
m4 := ph.nextDelta()
require.Empty(t, m4.AddedProtocols)
require.Len(t, m4.RmProtocols, 1)
require.Contains(t, m4.RmProtocols, "p3")
doeval(t, ph, func() {
ph.applyDelta(m4)
})
h1.RemoveStreamHandler("p2")
h1.RemoveStreamHandler("p1")
m5 := ph.mkDelta()
m5 := ph.nextDelta()
require.Empty(t, m5.AddedProtocols)
require.Len(t, m5.RmProtocols, 2)
require.Contains(t, m5.RmProtocols, "p2")
@ -82,14 +59,13 @@ func TestMakeApplyDelta(t *testing.T) {
}
func TestHandlerClose(t *testing.T) {
isTesting = true
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
h1 := blhost.NewBlankHost(swarmt.GenSwarm(t, ctx))
defer h1.Close()
ids1 := NewIDService(h1)
ph := newPeerHandler(h1.ID(), ids1, nil)
ph := newPeerHandler(h1.ID(), ids1)
ph.start()
require.NoError(t, ph.close())
@ -104,7 +80,7 @@ func TestPeerSupportsProto(t *testing.T) {
ids1 := NewIDService(h1)
rp := peer.ID("test")
ph := newPeerHandler(rp, ids1, nil)
ph := newPeerHandler(rp, ids1)
require.NoError(t, h1.Peerstore().AddProtocols(rp, "test"))
require.True(t, ph.peerSupportsProtos([]string{"test"}))
require.False(t, ph.peerSupportsProtos([]string{"random"}))