343 lines
10 KiB
Go
343 lines
10 KiB
Go
package autonatv2
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/libp2p/go-libp2p/core/host"
|
|
"github.com/libp2p/go-libp2p/core/network"
|
|
"github.com/libp2p/go-libp2p/core/peer"
|
|
"github.com/libp2p/go-libp2p/p2p/protocol/autonatv2/pb"
|
|
"github.com/libp2p/go-msgio/pbio"
|
|
ma "github.com/multiformats/go-multiaddr"
|
|
"golang.org/x/exp/rand"
|
|
)
|
|
|
|
//go:generate protoc --go_out=. --go_opt=Mpb/autonatv2.proto=./pb pb/autonatv2.proto
|
|
|
|
// client implements the client for making dial requests for AutoNAT v2. It verifies successful
|
|
// dials and provides an option to send data for dial requests.
|
|
type client struct {
|
|
host host.Host
|
|
dialData []byte
|
|
normalizeMultiaddr func(ma.Multiaddr) ma.Multiaddr
|
|
|
|
mu sync.Mutex
|
|
// dialBackQueues maps nonce to the channel for providing the local multiaddr of the connection
|
|
// the nonce was received on
|
|
dialBackQueues map[uint64]chan ma.Multiaddr
|
|
}
|
|
|
|
type normalizeMultiaddrer interface {
|
|
NormalizeMultiaddr(ma.Multiaddr) ma.Multiaddr
|
|
}
|
|
|
|
func newClient(h host.Host) *client {
|
|
normalizeMultiaddr := func(a ma.Multiaddr) ma.Multiaddr { return a }
|
|
if hn, ok := h.(normalizeMultiaddrer); ok {
|
|
normalizeMultiaddr = hn.NormalizeMultiaddr
|
|
}
|
|
return &client{
|
|
host: h,
|
|
dialData: make([]byte, 4000),
|
|
normalizeMultiaddr: normalizeMultiaddr,
|
|
dialBackQueues: make(map[uint64]chan ma.Multiaddr),
|
|
}
|
|
}
|
|
|
|
func (ac *client) Start() {
|
|
ac.host.SetStreamHandler(DialBackProtocol, ac.handleDialBack)
|
|
}
|
|
|
|
func (ac *client) Close() {
|
|
ac.host.RemoveStreamHandler(DialBackProtocol)
|
|
}
|
|
|
|
// GetReachability verifies address reachability with a AutoNAT v2 server p.
|
|
func (ac *client) GetReachability(ctx context.Context, p peer.ID, reqs []Request) (Result, error) {
|
|
ctx, cancel := context.WithTimeout(ctx, streamTimeout)
|
|
defer cancel()
|
|
|
|
s, err := ac.host.NewStream(ctx, p, DialProtocol)
|
|
if err != nil {
|
|
return Result{}, fmt.Errorf("open %s stream failed: %w", DialProtocol, err)
|
|
}
|
|
|
|
if err := s.Scope().SetService(ServiceName); err != nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("attach stream %s to service %s failed: %w", DialProtocol, ServiceName, err)
|
|
}
|
|
|
|
if err := s.Scope().ReserveMemory(maxMsgSize, network.ReservationPriorityAlways); err != nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("failed to reserve memory for stream %s: %w", DialProtocol, err)
|
|
}
|
|
defer s.Scope().ReleaseMemory(maxMsgSize)
|
|
|
|
s.SetDeadline(time.Now().Add(streamTimeout))
|
|
defer s.Close()
|
|
|
|
nonce := rand.Uint64()
|
|
ch := make(chan ma.Multiaddr, 1)
|
|
ac.mu.Lock()
|
|
ac.dialBackQueues[nonce] = ch
|
|
ac.mu.Unlock()
|
|
defer func() {
|
|
ac.mu.Lock()
|
|
delete(ac.dialBackQueues, nonce)
|
|
ac.mu.Unlock()
|
|
}()
|
|
|
|
msg := newDialRequest(reqs, nonce)
|
|
w := pbio.NewDelimitedWriter(s)
|
|
if err := w.WriteMsg(&msg); err != nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("dial request write failed: %w", err)
|
|
}
|
|
|
|
r := pbio.NewDelimitedReader(s, maxMsgSize)
|
|
if err := r.ReadMsg(&msg); err != nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("dial msg read failed: %w", err)
|
|
}
|
|
|
|
switch {
|
|
case msg.GetDialResponse() != nil:
|
|
break
|
|
// provide dial data if appropriate
|
|
case msg.GetDialDataRequest() != nil:
|
|
if err := ac.validateDialDataRequest(reqs, &msg); err != nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("invalid dial data request: %w", err)
|
|
}
|
|
// dial data request is valid and we want to send data
|
|
if err := sendDialData(ac.dialData, int(msg.GetDialDataRequest().GetNumBytes()), w, &msg); err != nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("dial data send failed: %w", err)
|
|
}
|
|
if err := r.ReadMsg(&msg); err != nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("dial response read failed: %w", err)
|
|
}
|
|
if msg.GetDialResponse() == nil {
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("invalid response type: %T", msg.Msg)
|
|
}
|
|
default:
|
|
s.Reset()
|
|
return Result{}, fmt.Errorf("invalid msg type: %T", msg.Msg)
|
|
}
|
|
|
|
resp := msg.GetDialResponse()
|
|
if resp.GetStatus() != pb.DialResponse_OK {
|
|
// E_DIAL_REFUSED has implication for deciding future address verificiation priorities
|
|
// wrap a distinct error for convenient errors.Is usage
|
|
if resp.GetStatus() == pb.DialResponse_E_DIAL_REFUSED {
|
|
return Result{}, fmt.Errorf("dial request failed: %w", ErrDialRefused)
|
|
}
|
|
return Result{}, fmt.Errorf("dial request failed: response status %d %s", resp.GetStatus(),
|
|
pb.DialResponse_ResponseStatus_name[int32(resp.GetStatus())])
|
|
}
|
|
if resp.GetDialStatus() == pb.DialStatus_UNUSED {
|
|
return Result{}, fmt.Errorf("invalid response: invalid dial status UNUSED")
|
|
}
|
|
if int(resp.AddrIdx) >= len(reqs) {
|
|
return Result{}, fmt.Errorf("invalid response: addr index out of range: %d [0-%d)", resp.AddrIdx, len(reqs))
|
|
}
|
|
|
|
// wait for nonce from the server
|
|
var dialBackAddr ma.Multiaddr
|
|
if resp.GetDialStatus() == pb.DialStatus_OK {
|
|
timer := time.NewTimer(dialBackStreamTimeout)
|
|
select {
|
|
case at := <-ch:
|
|
dialBackAddr = at
|
|
case <-ctx.Done():
|
|
case <-timer.C:
|
|
}
|
|
timer.Stop()
|
|
}
|
|
return ac.newResult(resp, reqs, dialBackAddr)
|
|
}
|
|
|
|
func (ac *client) validateDialDataRequest(reqs []Request, msg *pb.Message) error {
|
|
idx := int(msg.GetDialDataRequest().AddrIdx)
|
|
if idx >= len(reqs) { // invalid address index
|
|
return fmt.Errorf("addr index out of range: %d [0-%d)", idx, len(reqs))
|
|
}
|
|
if msg.GetDialDataRequest().NumBytes > maxHandshakeSizeBytes { // data request is too high
|
|
return fmt.Errorf("requested data too high: %d", msg.GetDialDataRequest().NumBytes)
|
|
}
|
|
if !reqs[idx].SendDialData { // low priority addr
|
|
return fmt.Errorf("low priority addr: %s index %d", reqs[idx].Addr, idx)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (ac *client) newResult(resp *pb.DialResponse, reqs []Request, dialBackAddr ma.Multiaddr) (Result, error) {
|
|
idx := int(resp.AddrIdx)
|
|
addr := reqs[idx].Addr
|
|
|
|
var rch network.Reachability
|
|
switch resp.DialStatus {
|
|
case pb.DialStatus_OK:
|
|
if !ac.areAddrsConsistent(dialBackAddr, addr) {
|
|
// the server is misinforming us about the address it successfully dialed
|
|
// either we received no dialback or the address on the dialback is inconsistent with
|
|
// what the server is telling us
|
|
return Result{}, fmt.Errorf("invalid response: dialBackAddr: %s, respAddr: %s", dialBackAddr, addr)
|
|
}
|
|
rch = network.ReachabilityPublic
|
|
case pb.DialStatus_E_DIAL_ERROR:
|
|
rch = network.ReachabilityPrivate
|
|
case pb.DialStatus_E_DIAL_BACK_ERROR:
|
|
if ac.areAddrsConsistent(dialBackAddr, addr) {
|
|
// We received the dial back but the server claims the dial back errored.
|
|
// As long as we received the correct nonce in dial back it is safe to assume
|
|
// that we are public.
|
|
rch = network.ReachabilityPublic
|
|
} else {
|
|
rch = network.ReachabilityUnknown
|
|
}
|
|
default:
|
|
// Unexpected response code. Discard the response and fail.
|
|
log.Warnf("invalid status code received in response for addr %s: %d", addr, resp.DialStatus)
|
|
return Result{}, fmt.Errorf("invalid response: invalid status code for addr %s: %d", addr, resp.DialStatus)
|
|
}
|
|
|
|
return Result{
|
|
Addr: addr,
|
|
Reachability: rch,
|
|
Status: resp.DialStatus,
|
|
}, nil
|
|
}
|
|
|
|
func sendDialData(dialData []byte, numBytes int, w pbio.Writer, msg *pb.Message) (err error) {
|
|
ddResp := &pb.DialDataResponse{Data: dialData}
|
|
*msg = pb.Message{
|
|
Msg: &pb.Message_DialDataResponse{
|
|
DialDataResponse: ddResp,
|
|
},
|
|
}
|
|
for remain := numBytes; remain > 0; {
|
|
if remain < len(ddResp.Data) {
|
|
ddResp.Data = ddResp.Data[:remain]
|
|
}
|
|
if err := w.WriteMsg(msg); err != nil {
|
|
return fmt.Errorf("write failed: %w", err)
|
|
}
|
|
remain -= len(dialData)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func newDialRequest(reqs []Request, nonce uint64) pb.Message {
|
|
addrbs := make([][]byte, len(reqs))
|
|
for i, r := range reqs {
|
|
addrbs[i] = r.Addr.Bytes()
|
|
}
|
|
return pb.Message{
|
|
Msg: &pb.Message_DialRequest{
|
|
DialRequest: &pb.DialRequest{
|
|
Addrs: addrbs,
|
|
Nonce: nonce,
|
|
},
|
|
},
|
|
}
|
|
}
|
|
|
|
// handleDialBack receives the nonce on the dial-back stream
|
|
func (ac *client) handleDialBack(s network.Stream) {
|
|
if err := s.Scope().SetService(ServiceName); err != nil {
|
|
log.Debugf("failed to attach stream to service %s: %w", ServiceName, err)
|
|
s.Reset()
|
|
return
|
|
}
|
|
|
|
if err := s.Scope().ReserveMemory(dialBackMaxMsgSize, network.ReservationPriorityAlways); err != nil {
|
|
log.Debugf("failed to reserve memory for stream %s: %w", DialBackProtocol, err)
|
|
s.Reset()
|
|
return
|
|
}
|
|
defer s.Scope().ReleaseMemory(dialBackMaxMsgSize)
|
|
|
|
s.SetDeadline(time.Now().Add(dialBackStreamTimeout))
|
|
defer s.Close()
|
|
|
|
r := pbio.NewDelimitedReader(s, dialBackMaxMsgSize)
|
|
var msg pb.DialBack
|
|
if err := r.ReadMsg(&msg); err != nil {
|
|
log.Debugf("failed to read dialback msg from %s: %s", s.Conn().RemotePeer(), err)
|
|
s.Reset()
|
|
return
|
|
}
|
|
nonce := msg.GetNonce()
|
|
|
|
ac.mu.Lock()
|
|
ch := ac.dialBackQueues[nonce]
|
|
ac.mu.Unlock()
|
|
if ch == nil {
|
|
log.Debugf("dialback received with invalid nonce: localAdds: %s peer: %s nonce: %d", s.Conn().LocalMultiaddr(), s.Conn().RemotePeer(), nonce)
|
|
s.Reset()
|
|
return
|
|
}
|
|
select {
|
|
case ch <- s.Conn().LocalMultiaddr():
|
|
default:
|
|
log.Debugf("multiple dialbacks received: localAddr: %s peer: %s", s.Conn().LocalMultiaddr(), s.Conn().RemotePeer())
|
|
s.Reset()
|
|
return
|
|
}
|
|
w := pbio.NewDelimitedWriter(s)
|
|
res := pb.DialBackResponse{}
|
|
if err := w.WriteMsg(&res); err != nil {
|
|
log.Debugf("failed to write dialback response: %s", err)
|
|
s.Reset()
|
|
}
|
|
}
|
|
|
|
func (ac *client) areAddrsConsistent(connLocalAddr, dialedAddr ma.Multiaddr) bool {
|
|
if connLocalAddr == nil || dialedAddr == nil {
|
|
return false
|
|
}
|
|
connLocalAddr = ac.normalizeMultiaddr(connLocalAddr)
|
|
dialedAddr = ac.normalizeMultiaddr(dialedAddr)
|
|
|
|
localProtos := connLocalAddr.Protocols()
|
|
externalProtos := dialedAddr.Protocols()
|
|
if len(localProtos) != len(externalProtos) {
|
|
return false
|
|
}
|
|
for i := 0; i < len(localProtos); i++ {
|
|
if i == 0 {
|
|
switch externalProtos[i].Code {
|
|
case ma.P_DNS, ma.P_DNSADDR:
|
|
if localProtos[i].Code == ma.P_IP4 || localProtos[i].Code == ma.P_IP6 {
|
|
continue
|
|
}
|
|
return false
|
|
case ma.P_DNS4:
|
|
if localProtos[i].Code == ma.P_IP4 {
|
|
continue
|
|
}
|
|
return false
|
|
case ma.P_DNS6:
|
|
if localProtos[i].Code == ma.P_IP6 {
|
|
continue
|
|
}
|
|
return false
|
|
}
|
|
if localProtos[i].Code != externalProtos[i].Code {
|
|
return false
|
|
}
|
|
} else {
|
|
if localProtos[i].Code != externalProtos[i].Code {
|
|
return false
|
|
}
|
|
}
|
|
}
|
|
return true
|
|
}
|