From 3e3f2e41284117822b66842e523cae39a559b49b Mon Sep 17 00:00:00 2001 From: Matt Joiner Date: Wed, 28 Nov 2018 12:02:12 +1100 Subject: [PATCH] tracker: Add Announce.Context Use it to rewrite a test that fails with recent go versions due to logging after test completion. --- tracker/http.go | 3 +++ tracker/tracker.go | 2 ++ tracker/udp.go | 64 ++++++++++++++++++++++++++++----------------- tracker/udp_test.go | 20 ++++---------- 4 files changed, 50 insertions(+), 39 deletions(-) diff --git a/tracker/http.go b/tracker/http.go index 69764040..87927e77 100644 --- a/tracker/http.go +++ b/tracker/http.go @@ -98,6 +98,9 @@ func announceHTTP(opt Announce, _url *url.URL) (ret AnnounceResponse, err error) req, err := http.NewRequest("GET", _url.String(), nil) req.Header.Set("User-Agent", opt.UserAgent) req.Host = opt.HostHeader + if opt.Context != nil { + req = req.WithContext(opt.Context) + } resp, err := (&http.Client{ Timeout: time.Second * 15, Transport: &http.Transport{ diff --git a/tracker/tracker.go b/tracker/tracker.go index a56a99b7..e7260503 100644 --- a/tracker/tracker.go +++ b/tracker/tracker.go @@ -1,6 +1,7 @@ package tracker import ( + "context" "errors" "net/http" "net/url" @@ -61,6 +62,7 @@ type Announce struct { ClientIp4 krpc.NodeAddr // If the port is zero, it's assumed to be the same as the Request.Port ClientIp6 krpc.NodeAddr + Context context.Context } // In an FP language with currying, what order what you put these params? diff --git a/tracker/udp.go b/tracker/udp.go index 656cc7df..c694285d 100644 --- a/tracker/udp.go +++ b/tracker/udp.go @@ -2,9 +2,9 @@ package tracker import ( "bytes" + "context" "encoding" "encoding/binary" - "errors" "fmt" "io" "math/rand" @@ -15,6 +15,7 @@ import ( "github.com/anacrolix/dht/krpc" "github.com/anacrolix/missinggo" "github.com/anacrolix/missinggo/pproffd" + "github.com/pkg/errors" ) type Action int32 @@ -188,39 +189,55 @@ func write(w io.Writer, data interface{}) error { // args is the binary serializable request body. trailer is optional data // following it, such as for BEP 41. -func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (responseBody *bytes.Buffer, err error) { +func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (*bytes.Buffer, error) { tid := newTransactionId() - err = c.write(&RequestHeader{ - ConnectionId: c.connectionId, - Action: action, - TransactionId: tid, - }, args, options) - if err != nil { - return + if err := errors.Wrap( + c.write( + &RequestHeader{ + ConnectionId: c.connectionId, + Action: action, + TransactionId: tid, + }, args, options), + "writing request", + ); err != nil { + return nil, err } c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts))) b := make([]byte, 0x800) // 2KiB for { - var n int - n, err = c.socket.Read(b) - if opE, ok := err.(*net.OpError); ok { - if opE.Timeout() { - c.contiguousTimeouts++ - return - } + var ( + n int + readErr error + readDone = make(chan struct{}) + ) + go func() { + defer close(readDone) + n, readErr = c.socket.Read(b) + }() + ctx := c.a.Context + if ctx == nil { + ctx = context.Background() } - if err != nil { - return + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-readDone: + } + if opE, ok := readErr.(*net.OpError); ok && opE.Timeout() { + c.contiguousTimeouts++ + } + if readErr != nil { + return nil, errors.Wrap(readErr, "reading from socket") } buf := bytes.NewBuffer(b[:n]) var h ResponseHeader - err = binary.Read(buf, binary.BigEndian, &h) + err := binary.Read(buf, binary.BigEndian, &h) switch err { - case io.ErrUnexpectedEOF: + default: + panic(err) + case io.ErrUnexpectedEOF, io.EOF: continue case nil: - default: - return } if h.TransactionId != tid { continue @@ -229,8 +246,7 @@ func (c *udpAnnounce) request(action Action, args interface{}, options []byte) ( if h.Action == ActionError { err = errors.New(buf.String()) } - responseBody = buf - return + return buf, err } } diff --git a/tracker/udp_test.go b/tracker/udp_test.go index a6d4cb15..ea500f18 100644 --- a/tracker/udp_test.go +++ b/tracker/udp_test.go @@ -2,6 +2,7 @@ package tracker import ( "bytes" + "context" "crypto/rand" "encoding/binary" "fmt" @@ -160,8 +161,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) { rand.Read(req.PeerId[:]) rand.Read(req.InfoHash[:]) wg := sync.WaitGroup{} - success := make(chan bool) - fail := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) for _, url := range trackers { wg.Add(1) go func(url string) { @@ -169,6 +169,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) { resp, err := Announce{ TrackerUrl: url, Request: req, + Context: ctx, }.Do() if err != nil { t.Logf("error announcing to %s: %s", url, err) @@ -180,21 +181,10 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) { t.Fatal(resp) } t.Logf("announced to %s", url) - // TODO: Can probably get stuck here, but it's just a throwaway - // test. - success <- true + cancel() }(url) } - go func() { - wg.Wait() - close(fail) - }() - select { - case <-fail: - // It doesn't matter if they all fail, the servers could just be down. - case <-success: - // Bail as quickly as we can. One success is enough. - } + wg.Wait() } // Check that URLPath option is done correctly.