Always return errors when handshakes or dialing fails
I intend to use xerrors-style error handling for special cases.
This commit is contained in:
parent
63d215568a
commit
338287486f
66
client.go
66
client.go
|
@ -37,6 +37,7 @@ import (
|
|||
"github.com/dustin/go-humanize"
|
||||
"github.com/google/btree"
|
||||
"golang.org/x/time/rate"
|
||||
"golang.org/x/xerrors"
|
||||
)
|
||||
|
||||
// Clients contain zero or more Torrents. A Client manages a blocklist, the
|
||||
|
@ -617,16 +618,13 @@ func (cl *Client) handshakesConnection(ctx context.Context, nc net.Conn, t *Torr
|
|||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ok, err = cl.initiateHandshakes(c, t)
|
||||
if !ok {
|
||||
c = nil
|
||||
}
|
||||
err = cl.initiateHandshakes(c, t)
|
||||
return
|
||||
}
|
||||
|
||||
// Returns nil connection and nil error if no connection could be established
|
||||
// for valid reasons.
|
||||
func (cl *Client) establishOutgoingConnEx(t *Torrent, addr IpPort, obfuscatedHeader bool) (c *connection, err error) {
|
||||
func (cl *Client) establishOutgoingConnEx(t *Torrent, addr IpPort, obfuscatedHeader bool) (*connection, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), func() time.Duration {
|
||||
cl.rLock()
|
||||
defer cl.rUnlock()
|
||||
|
@ -636,14 +634,16 @@ func (cl *Client) establishOutgoingConnEx(t *Torrent, addr IpPort, obfuscatedHea
|
|||
dr := cl.dialFirst(ctx, addr.String())
|
||||
nc := dr.Conn
|
||||
if nc == nil {
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if c == nil || err != nil {
|
||||
nc.Close()
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}()
|
||||
return cl.handshakesConnection(ctx, nc, t, obfuscatedHeader, addr, dr.Network)
|
||||
return nil, errors.New("dial failed")
|
||||
}
|
||||
c, err := cl.handshakesConnection(ctx, nc, t, obfuscatedHeader, addr, dr.Network)
|
||||
if err != nil {
|
||||
nc.Close()
|
||||
}
|
||||
return c, err
|
||||
}
|
||||
|
||||
// Returns nil connection and nil error if no connection could be established
|
||||
|
@ -652,13 +652,11 @@ func (cl *Client) establishOutgoingConn(t *Torrent, addr IpPort) (c *connection,
|
|||
torrent.Add("establish outgoing connection", 1)
|
||||
obfuscatedHeaderFirst := cl.config.HeaderObfuscationPolicy.Preferred
|
||||
c, err = cl.establishOutgoingConnEx(t, addr, obfuscatedHeaderFirst)
|
||||
if err != nil {
|
||||
//cl.logger.Printf("error establish connection to %s (obfuscatedHeader=%t): %v", addr, obfuscatedHeaderFirst, err)
|
||||
}
|
||||
if c != nil {
|
||||
if err == nil {
|
||||
torrent.Add("initiated conn with preferred header obfuscation", 1)
|
||||
return
|
||||
}
|
||||
//cl.logger.Printf("error establishing connection to %s (obfuscatedHeader=%t): %v", addr, obfuscatedHeaderFirst, err)
|
||||
if cl.config.HeaderObfuscationPolicy.RequirePreferred {
|
||||
// We should have just tried with the preferred header obfuscation. If it was required,
|
||||
// there's nothing else to try.
|
||||
|
@ -666,9 +664,10 @@ func (cl *Client) establishOutgoingConn(t *Torrent, addr IpPort) (c *connection,
|
|||
}
|
||||
// Try again with encryption if we didn't earlier, or without if we did.
|
||||
c, err = cl.establishOutgoingConnEx(t, addr, !obfuscatedHeaderFirst)
|
||||
if c != nil {
|
||||
if err == nil {
|
||||
torrent.Add("initiated conn with fallback header obfuscation", 1)
|
||||
}
|
||||
//cl.logger.Printf("error establishing fallback connection to %v: %v", addr, err)
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -688,9 +687,6 @@ func (cl *Client) outgoingConnection(t *Torrent, addr IpPort, ps peerSource) {
|
|||
}
|
||||
return
|
||||
}
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
c.Discovery = ps
|
||||
cl.runHandshookConn(c, t)
|
||||
|
@ -702,9 +698,10 @@ func (cl *Client) incomingPeerPort() int {
|
|||
return cl.LocalPort()
|
||||
}
|
||||
|
||||
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err error) {
|
||||
func (cl *Client) initiateHandshakes(c *connection, t *Torrent) error {
|
||||
if c.headerEncrypted {
|
||||
var rw io.ReadWriter
|
||||
var err error
|
||||
rw, c.cryptoMethod, err = mse.InitiateHandshake(
|
||||
struct {
|
||||
io.Reader
|
||||
|
@ -716,14 +713,17 @@ func (cl *Client) initiateHandshakes(c *connection, t *Torrent) (ok bool, err er
|
|||
)
|
||||
c.setRW(rw)
|
||||
if err != nil {
|
||||
return
|
||||
return xerrors.Errorf("header obfuscation handshake: %w", err)
|
||||
}
|
||||
}
|
||||
ih, ok, err := cl.connBTHandshake(c, &t.infoHash)
|
||||
if ih != t.infoHash {
|
||||
ok = false
|
||||
ih, err := cl.connBtHandshake(c, &t.infoHash)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("bittorrent protocol handshake: %w", err)
|
||||
}
|
||||
return
|
||||
if ih != t.infoHash {
|
||||
return errors.New("bittorrent protocol handshake: peer infohash didn't match")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Calls f with any secret keys.
|
||||
|
@ -775,12 +775,9 @@ func (cl *Client) receiveHandshakes(c *connection) (t *Torrent, err error) {
|
|||
err = errors.New("connection not have required header obfuscation")
|
||||
return
|
||||
}
|
||||
ih, ok, err := cl.connBTHandshake(c, nil)
|
||||
ih, err := cl.connBtHandshake(c, nil)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("error during bt handshake: %s", err)
|
||||
return
|
||||
}
|
||||
if !ok {
|
||||
err = xerrors.Errorf("during bt handshake: %w", err)
|
||||
return
|
||||
}
|
||||
cl.lock()
|
||||
|
@ -789,10 +786,9 @@ func (cl *Client) receiveHandshakes(c *connection) (t *Torrent, err error) {
|
|||
return
|
||||
}
|
||||
|
||||
// Returns !ok if handshake failed for valid reasons.
|
||||
func (cl *Client) connBTHandshake(c *connection, ih *metainfo.Hash) (ret metainfo.Hash, ok bool, err error) {
|
||||
res, ok, err := pp.Handshake(c.rw(), ih, cl.peerID, cl.extensionBytes)
|
||||
if err != nil || !ok {
|
||||
func (cl *Client) connBtHandshake(c *connection, ih *metainfo.Hash) (ret metainfo.Hash, err error) {
|
||||
res, err := pp.Handshake(c.rw(), ih, cl.peerID, cl.extensionBytes)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
ret = res.Hash
|
||||
|
|
1
go.mod
1
go.mod
|
@ -34,6 +34,7 @@ require (
|
|||
golang.org/x/net v0.0.0-20190628185345-da137c7871d7
|
||||
golang.org/x/sys v0.0.0-20190712062909-fae7ac547cb7 // indirect
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7
|
||||
)
|
||||
|
||||
go 1.13
|
||||
|
|
2
go.sum
2
go.sum
|
@ -189,3 +189,5 @@ golang.org/x/time v0.0.0-20181108054448-85acf8d2951c h1:fqgJT0MGcGpPgpWU7VRdRjuA
|
|||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4 h1:SvFZT6jyqRaOeXpc5h/JSfZenJ2O330aBsf7JfSUXmQ=
|
||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7 h1:9zdDQZ7Thm29KFXgAX/+yaf3eVbP7djjWp/dXAppNCc=
|
||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
|
|
|
@ -5,6 +5,8 @@ import (
|
|||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/xerrors"
|
||||
|
||||
"github.com/anacrolix/missinggo"
|
||||
"github.com/anacrolix/torrent/metainfo"
|
||||
)
|
||||
|
@ -75,7 +77,7 @@ type HandshakeResult struct {
|
|||
func Handshake(
|
||||
sock io.ReadWriter, ih *metainfo.Hash, peerID [20]byte, extensions PeerExtensionBits,
|
||||
) (
|
||||
res HandshakeResult, ok bool, err error,
|
||||
res HandshakeResult, err error,
|
||||
) {
|
||||
// Bytes to be sent to the peer. Should never block the sender.
|
||||
postCh := make(chan []byte, 4)
|
||||
|
@ -86,11 +88,8 @@ func Handshake(
|
|||
|
||||
defer func() {
|
||||
close(postCh) // Done writing.
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
panic(err)
|
||||
return
|
||||
}
|
||||
// Wait until writes complete before returning from handshake.
|
||||
err = <-writeDone
|
||||
|
@ -116,10 +115,11 @@ func Handshake(
|
|||
var b [68]byte
|
||||
_, err = io.ReadFull(sock, b[:68])
|
||||
if err != nil {
|
||||
err = nil
|
||||
err = xerrors.Errorf("while reading: %w", err)
|
||||
return
|
||||
}
|
||||
if string(b[:20]) != Protocol {
|
||||
err = xerrors.Errorf("unexpected protocol string")
|
||||
return
|
||||
}
|
||||
missinggo.CopyExact(&res.PeerExtensionBits, b[20:28])
|
||||
|
@ -135,6 +135,5 @@ func Handshake(
|
|||
post(peerID[:])
|
||||
}
|
||||
|
||||
ok = true
|
||||
return
|
||||
}
|
||||
|
|
|
@ -183,9 +183,8 @@ func TestTorrentMetainfoIncompleteMetadata(t *testing.T) {
|
|||
|
||||
var pex PeerExtensionBits
|
||||
pex.SetBit(pp.ExtensionBitExtended)
|
||||
hr, ok, err := pp.Handshake(nc, &ih, [20]byte{}, pex)
|
||||
hr, err := pp.Handshake(nc, &ih, [20]byte{}, pex)
|
||||
require.NoError(t, err)
|
||||
assert.True(t, ok)
|
||||
assert.True(t, hr.PeerExtensionBits.GetBit(pp.ExtensionBitExtended))
|
||||
assert.EqualValues(t, cl.PeerID(), hr.PeerID)
|
||||
assert.EqualValues(t, ih, hr.Hash)
|
||||
|
|
Loading…
Reference in New Issue