status-go/vendor/github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/client.go

343 lines
10 KiB
Go

package autonatv2
import (
"context"
"fmt"
"sync"
"time"
"github.com/libp2p/go-libp2p/core/host"
"github.com/libp2p/go-libp2p/core/network"
"github.com/libp2p/go-libp2p/core/peer"
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
"github.com/libp2p/go-msgio/pbio"
ma "github.com/multiformats/go-multiaddr"
"golang.org/x/exp/rand"
)
//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto
// client is the AutoNAT v2 client. It sends dial requests to servers, checks
// that a claimed successful dial is backed by a dial-back carrying the
// expected nonce, and can send dial data when a server asks for it.
type client struct {
	// host opens dial-request streams and receives dial-back streams.
	host host.Host
	// dialData is the payload chunk sent in response to dial data requests.
	dialData []byte
	// normalizeMultiaddr is applied to both addresses before they are
	// compared in areAddrsConsistent.
	normalizeMultiaddr func(ma.Multiaddr) ma.Multiaddr

	// mu guards dialBackQueues.
	mu sync.Mutex
	// dialBackQueues maps a request nonce to the channel that receives the
	// local multiaddr of the connection on which that nonce was dialed back.
	dialBackQueues map[uint64]chan ma.Multiaddr
}

// normalizeMultiaddrer is an optional host capability: when the host passed to
// newClient implements it, its NormalizeMultiaddr is used for address
// comparisons.
type normalizeMultiaddrer interface {
	NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr
}
// newClient returns a client bound to h. Addresses are compared as-is unless
// h implements normalizeMultiaddrer, in which case h.NormalizeMultiaddr is
// applied first. Dial data is sent in chunks of 4000 bytes.
func newClient(h host.Host) *client {
	ac := &client{
		host:           h,
		dialData:       make([]byte, 4000),
		dialBackQueues: make(map[uint64]chan ma.Multiaddr),
		normalizeMultiaddr: func(a ma.Multiaddr) ma.Multiaddr {
			return a
		},
	}
	if n, ok := h.(normalizeMultiaddrer); ok {
		ac.normalizeMultiaddr = n.NormalizeMultiaddr
	}
	return ac
}
// Start installs handleDialBack as the DialBackProtocol stream handler so that
// nonces sent back by servers reach pending GetReachability calls.
func (ac *client) Start() {
	h := ac.host
	h.SetStreamHandler(DialBackProtocol, ac.handleDialBack)
}

// Close removes the DialBackProtocol stream handler installed by Start.
func (ac *client) Close() {
	h := ac.host
	h.RemoveStreamHandler(DialBackProtocol)
}
// GetReachability verifies address reachability with an AutoNAT v2 server p.
//
// It sends p a dial request for reqs tagged with a random nonce. If p asks for
// dial data, the request is checked with validateDialDataRequest and, when
// acceptable, answered before the final DialResponse is read. A response
// status other than OK is returned as an error; E_DIAL_REFUSED wraps
// ErrDialRefused so callers can match it with errors.Is.
//
// When p reports a successful dial, GetReachability waits (up to
// dialBackStreamTimeout, and never past ctx) for the nonce to arrive on the
// dial-back stream. When p reports E_DIAL_BACK_ERROR, it only picks up a nonce
// that has already arrived. The Result is then built by newResult. The whole
// exchange is bounded by streamTimeout.
func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
	ctx, cancel := context.WithTimeout(ctx, streamTimeout)
	defer cancel()

	s, err := ac.host.NewStream(ctx, p, DialProtocol)
	if err != nil {
		return Result{}, fmt.Errorf("open %s stream failed: %w", DialProtocol, err)
	}
	if err := s.Scope().SetService(ServiceName); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("attach stream %s to service %s failed: %w", DialProtocol, ServiceName, err)
	}
	if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err)
	}
	defer s.Scope().ReleaseMemory(maxMsgSize)

	s.SetDeadline(time.Now().Add(streamTimeout))
	defer s.Close()

	// Register the nonce before sending the request: handleDialBack resets
	// dial-backs whose nonce is not registered. The one-slot buffer lets
	// handleDialBack hand over the address without blocking.
	nonce := rand.Uint64()
	ch := make(chan ma.Multiaddr, 1)
	ac.mu.Lock()
	ac.dialBackQueues[nonce] = ch
	ac.mu.Unlock()
	defer func() {
		ac.mu.Lock()
		delete(ac.dialBackQueues, nonce)
		ac.mu.Unlock()
	}()

	msg := newDialRequest(reqs, nonce)
	w := pbio.NewDelimitedWriter(s)
	if err := w.WriteMsg(&msg); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("dial request write failed: %w", err)
	}

	r := pbio.NewDelimitedReader(s, maxMsgSize)
	if err := r.ReadMsg(&msg); err != nil {
		s.Reset()
		return Result{}, fmt.Errorf("dial msg read failed: %w", err)
	}

	switch {
	case msg.GetDialResponse() != nil:
		// The server answered directly without asking for dial data.
	case msg.GetDialDataRequest() != nil:
		// Provide dial data if appropriate, then read the actual response.
		if err := ac.validateDialDataRequest(reqs, &msg); err != nil {
			s.Reset()
			return Result{}, fmt.Errorf("invalid dial data request: %w", err)
		}
		if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil {
			s.Reset()
			return Result{}, fmt.Errorf("dial data send failed: %w", err)
		}
		if err := r.ReadMsg(&msg); err != nil {
			s.Reset()
			return Result{}, fmt.Errorf("dial response read failed: %w", err)
		}
		if msg.GetDialResponse() == nil {
			s.Reset()
			return Result{}, fmt.Errorf("invalid response type: %T", msg.Msg)
		}
	default:
		s.Reset()
		return Result{}, fmt.Errorf("invalid msg type: %T", msg.Msg)
	}

	resp := msg.GetDialResponse()
	if resp.GetStatus() != pb.DialResponse_OK {
		// E_DIAL_REFUSED has implications for deciding future address
		// verification priorities; wrap a distinct error for convenient
		// errors.Is usage.
		if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED {
			return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused)
		}
		return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(),
			pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())])
	}
	if resp.GetDialStatus() == pb.DialStatus_UNUSED {
		return Result{}, fmt.Errorf("invalid response: invalid dial status UNUSED")
	}
	// Compare in uint64: AddrIdx is unsigned, and converting it to int can wrap
	// negative on 32-bit platforms. A negative value would slip past a plain
	// `>= len(reqs)` check, and newResult would then panic indexing reqs.
	if uint64(resp.AddrIdx) >= uint64(len(reqs)) {
		return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs))
	}

	// Collect the dial-back address observed by handleDialBack, if any.
	var dialBackAddr ma.Multiaddr
	switch resp.GetDialStatus() {
	case pb.DialStatus_OK:
		// The server claims the dial back succeeded: wait for the nonce.
		timer := time.NewTimer(dialBackStreamTimeout)
		select {
		case at := <-ch:
			dialBackAddr = at
		case <-ctx.Done():
		case <-timer.C:
		}
		timer.Stop()
	case pb.DialStatus_E_DIAL_BACK_ERROR:
		// The server says the dial back failed, but the nonce may still have
		// reached us. Pick it up if it is already here, without waiting, so
		// newResult can tell a received-but-reported-failed dial-back apart
		// from a missing one.
		select {
		case at := <-ch:
			dialBackAddr = at
		default:
		}
	}
	return ac.newResult(resp, reqs, dialBackAddr)
}
// validateDialDataRequest checks whether the client should honor the
// DialDataRequest carried by msg (msg must carry one). It returns an error if:
//   - the requested address index is out of range for reqs,
//   - more than maxHandshakeSizeBytes bytes are requested, or
//   - the addressed request did not opt in to sending dial data
//     (Request.SendDialData is false).
func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error {
	ddr := msg.GetDialDataRequest()
	// Compare in uint64: AddrIdx is unsigned, and converting it to int can wrap
	// negative on 32-bit platforms. A negative value would pass a `>= len`
	// check and panic on reqs[idx] below.
	if uint64(ddr.AddrIdx) >= uint64(len(reqs)) { // invalid address index
		return fmt.Errorf("addr index out of range: %d [0-%d)", ddr.AddrIdx, len(reqs))
	}
	idx := int(ddr.AddrIdx)
	if ddr.NumBytes > maxHandshakeSizeBytes { // data request is too high
		return fmt.Errorf("requested data too high: %d", ddr.NumBytes)
	}
	if !reqs[idx].SendDialData { // low priority addr
		return fmt.Errorf("low priority addr: %s index %d", reqs[idx].Addr, idx)
	}
	return nil
}
// newResult turns a DialResponse into a Result for reqs[resp.AddrIdx].
//
// dialBackAddr is the local multiaddr of the connection on which the nonce was
// received, or nil if no dial-back was observed. Reachability is derived as:
//   - OK: public, but only if dialBackAddr is consistent with the requested
//     address; otherwise the response is rejected with an error.
//   - E_DIAL_ERROR: private.
//   - E_DIAL_BACK_ERROR: public if a consistent dial-back was nonetheless
//     received, unknown otherwise.
//
// Any other dial status, or an address index outside reqs, is an error.
func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) {
	// Validate the index here as well rather than trusting the caller; the
	// uint64 comparison cannot wrap on 32-bit platforms.
	if uint64(resp.AddrIdx) >= uint64(len(reqs)) {
		return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs))
	}
	addr := reqs[resp.AddrIdx].Addr

	var rch network.Reachability
	switch resp.DialStatus {
	case pb.DialStatus_OK:
		if !ac.areAddrsConsistent(dialBackAddr, addr) {
			// The server is misinforming us about the address it successfully
			// dialed: either we received no dial-back, or the address on the
			// dial-back is inconsistent with what the server is telling us.
			return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
		}
		rch = network.ReachabilityPublic
	case pb.DialStatus_E_DIAL_ERROR:
		rch = network.ReachabilityPrivate
	case pb.DialStatus_E_DIAL_BACK_ERROR:
		if ac.areAddrsConsistent(dialBackAddr, addr) {
			// We received the dial-back but the server claims it errored.
			// As long as the correct nonce arrived on a consistent address,
			// it is safe to assume that we are public.
			rch = network.ReachabilityPublic
		} else {
			rch = network.ReachabilityUnknown
		}
	default:
		// Unexpected response code. Discard the response and fail.
		log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus)
		return Result{}, fmt.Errorf("invalid response: invalid status code for addr %s: %d", addr, resp.DialStatus)
	}
	return Result{
		Addr:         addr,
		Reachability: rch,
		Status:       resp.DialStatus,
	}, nil
}
// sendDialData writes DialDataResponse messages to w until exactly numBytes
// bytes of dial data have been sent. Each message carries up to len(dialData)
// bytes, and the last one is truncated to the remainder. msg is overwritten
// and reused as the envelope for every write.
//
// It returns an error if a write fails, or if numBytes > 0 but dialData is
// empty (the loop could otherwise never make progress).
func sendDialData(dialData []byte, numBytes int, w pbio.Writer, msg *pb.Message) (err error) {
	if numBytes > 0 && len(dialData) == 0 {
		return fmt.Errorf("cannot send %d bytes of dial data: empty dial data buffer", numBytes)
	}
	ddResp := &pb.DialDataResponse{Data: dialData}
	*msg = pb.Message{
		Msg: &pb.Message_DialDataResponse{
			DialDataResponse: ddResp,
		},
	}
	// Decrement by what was actually written, so the count stays correct
	// even after the final chunk is truncated.
	for remain := numBytes; remain > 0; remain -= len(ddResp.Data) {
		if remain < len(ddResp.Data) {
			ddResp.Data = ddResp.Data[:remain]
		}
		if err := w.WriteMsg(msg); err != nil {
			return fmt.Errorf("write failed: %w", err)
		}
	}
	return nil
}
// newDialRequest builds a DialRequest message that lists the binary form of
// every address in reqs, in order, and carries nonce for dial-back
// verification.
func newDialRequest(reqs []Request, nonce uint64) pb.Message {
	dr := &pb.DialRequest{
		Addrs: make([][]byte, 0, len(reqs)),
		Nonce: nonce,
	}
	for _, req := range reqs {
		dr.Addrs = append(dr.Addrs, req.Addr.Bytes())
	}
	return pb.Message{Msg: &pb.Message_DialRequest{DialRequest: dr}}
}
// handleDialBack receives the nonce on a dial-back stream.
//
// It reads a single DialBack message and looks its nonce up in
// ac.dialBackQueues. If a pending GetReachability call registered that nonce,
// the local multiaddr of the connection the stream arrived on is handed to it,
// and an empty DialBackResponse is written back. The stream is reset in these
// cases:
//   - the nonce is unknown;
//   - the waiting channel already holds an address (a repeated dial-back);
//   - any scope, read, or write step fails.
func (ac *client) handleDialBack(s network.Stream) {
	if err := s.Scope().SetService(ServiceName); err != nil {
		// %w is only meaningful to fmt.Errorf; use %s when logging errors.
		log.Debugf("failed to attach stream to service %s: %s", ServiceName, err)
		s.Reset()
		return
	}
	if err := s.Scope().ReserveMemory(dialBackMaxMsgSize, network.ReservationPriorityAlways); err != nil {
		log.Debugf("failed to reserve memory for stream %s: %s", DialBackProtocol, err)
		s.Reset()
		return
	}
	defer s.Scope().ReleaseMemory(dialBackMaxMsgSize)

	s.SetDeadline(time.Now().Add(dialBackStreamTimeout))
	defer s.Close()

	r := pbio.NewDelimitedReader(s, dialBackMaxMsgSize)
	var msg pb.DialBack
	if err := r.ReadMsg(&msg); err != nil {
		log.Debugf("failed to read dialback msg from %s: %s", s.Conn().RemotePeer(), err)
		s.Reset()
		return
	}
	nonce := msg.GetNonce()

	ac.mu.Lock()
	ch := ac.dialBackQueues[nonce]
	ac.mu.Unlock()
	if ch == nil {
		log.Debugf("dialback received with invalid nonce: localAddr: %s peer: %s nonce: %d", s.Conn().LocalMultiaddr(), s.Conn().RemotePeer(), nonce)
		s.Reset()
		return
	}

	// Non-blocking send: if the channel's buffer is already occupied, an
	// earlier dial-back for this nonce is still waiting to be consumed.
	select {
	case ch <- s.Conn().LocalMultiaddr():
	default:
		log.Debugf("multiple dialbacks received: localAddr: %s peer: %s", s.Conn().LocalMultiaddr(), s.Conn().RemotePeer())
		s.Reset()
		return
	}

	w := pbio.NewDelimitedWriter(s)
	res := pb.DialBackResponse{}
	if err := w.WriteMsg(&res); err != nil {
		log.Debugf("failed to write dialback response: %s", err)
		s.Reset()
	}
}
// areAddrsConsistent reports whether connLocalAddr (the local address of the
// connection a dial-back arrived on) has the same protocol structure as
// dialedAddr (the address the server says it dialed). Both are normalized with
// ac.normalizeMultiaddr first, and only protocol codes are compared, never
// values.
//
// A DNS component in the first position of dialedAddr may correspond to an IP
// component locally: /dns and /dnsaddr match /ip4 or /ip6, /dns4 matches only
// /ip4, and /dns6 matches only /ip6. A nil argument yields false.
func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool {
	if connLocalAddr == nil || dialedAddr == nil {
		return false
	}
	connLocalAddr = ac.normalizeMultiaddr(connLocalAddr)
	dialedAddr = ac.normalizeMultiaddr(dialedAddr)
	local, external := connLocalAddr.Protocols(), dialedAddr.Protocols()
	if len(local) != len(external) {
		return false
	}

	for i := range external {
		lc, ec := local[i].Code, external[i].Code
		if i == 0 {
			// Only the leading component may be a name that resolves to an IP.
			switch ec {
			case ma.P_DNS, ma.P_DNSADDR:
				if lc == ma.P_IP4 || lc == ma.P_IP6 {
					continue
				}
				return false
			case ma.P_DNS4:
				if lc == ma.P_IP4 {
					continue
				}
				return false
			case ma.P_DNS6:
				if lc == ma.P_IP6 {
					continue
				}
				return false
			}
		}
		if lc != ec {
			return false
		}
	}
	return true
}