2
0
mirror of synced 2025-02-24 06:38:14 +00:00

tracker: Add Announce.Context

Use it to rewrite a test that fails with recent go versions due to logging after test completion.
This commit is contained in:
Matt Joiner 2018-11-28 12:02:12 +11:00
parent f8d827e7d6
commit 3e3f2e4128
4 changed files with 50 additions and 39 deletions

View File

@ -98,6 +98,9 @@ func announceHTTP(opt Announce, _url *url.URL) (ret AnnounceResponse, err error)
req, err := http.NewRequest("GET", _url.String(), nil) req, err := http.NewRequest("GET", _url.String(), nil)
req.Header.Set("User-Agent", opt.UserAgent) req.Header.Set("User-Agent", opt.UserAgent)
req.Host = opt.HostHeader req.Host = opt.HostHeader
if opt.Context != nil {
req = req.WithContext(opt.Context)
}
resp, err := (&http.Client{ resp, err := (&http.Client{
Timeout: time.Second * 15, Timeout: time.Second * 15,
Transport: &http.Transport{ Transport: &http.Transport{

View File

@ -1,6 +1,7 @@
package tracker package tracker
import ( import (
"context"
"errors" "errors"
"net/http" "net/http"
"net/url" "net/url"
@ -61,6 +62,7 @@ type Announce struct {
ClientIp4 krpc.NodeAddr ClientIp4 krpc.NodeAddr
// If the port is zero, it's assumed to be the same as the Request.Port // If the port is zero, it's assumed to be the same as the Request.Port
ClientIp6 krpc.NodeAddr ClientIp6 krpc.NodeAddr
Context context.Context
} }
// In an FP language with currying, what order what you put these params? // In an FP language with currying, what order what you put these params?

View File

@ -2,9 +2,9 @@ package tracker
import ( import (
"bytes" "bytes"
"context"
"encoding" "encoding"
"encoding/binary" "encoding/binary"
"errors"
"fmt" "fmt"
"io" "io"
"math/rand" "math/rand"
@ -15,6 +15,7 @@ import (
"github.com/anacrolix/dht/krpc" "github.com/anacrolix/dht/krpc"
"github.com/anacrolix/missinggo" "github.com/anacrolix/missinggo"
"github.com/anacrolix/missinggo/pproffd" "github.com/anacrolix/missinggo/pproffd"
"github.com/pkg/errors"
) )
type Action int32 type Action int32
@ -188,39 +189,55 @@ func write(w io.Writer, data interface{}) error {
// args is the binary serializable request body. trailer is optional data // args is the binary serializable request body. trailer is optional data
// following it, such as for BEP 41. // following it, such as for BEP 41.
func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (responseBody *bytes.Buffer, err error) { func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (*bytes.Buffer, error) {
tid := newTransactionId() tid := newTransactionId()
err = c.write(&RequestHeader{ if err := errors.Wrap(
ConnectionId: c.connectionId, c.write(
Action: action, &RequestHeader{
TransactionId: tid, ConnectionId: c.connectionId,
}, args, options) Action: action,
if err != nil { TransactionId: tid,
return }, args, options),
"writing request",
); err != nil {
return nil, err
} }
c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts))) c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts)))
b := make([]byte, 0x800) // 2KiB b := make([]byte, 0x800) // 2KiB
for { for {
var n int var (
n, err = c.socket.Read(b) n int
if opE, ok := err.(*net.OpError); ok { readErr error
if opE.Timeout() { readDone = make(chan struct{})
c.contiguousTimeouts++ )
return go func() {
} defer close(readDone)
n, readErr = c.socket.Read(b)
}()
ctx := c.a.Context
if ctx == nil {
ctx = context.Background()
} }
if err != nil { select {
return case <-ctx.Done():
return nil, ctx.Err()
case <-readDone:
}
if opE, ok := readErr.(*net.OpError); ok && opE.Timeout() {
c.contiguousTimeouts++
}
if readErr != nil {
return nil, errors.Wrap(readErr, "reading from socket")
} }
buf := bytes.NewBuffer(b[:n]) buf := bytes.NewBuffer(b[:n])
var h ResponseHeader var h ResponseHeader
err = binary.Read(buf, binary.BigEndian, &h) err := binary.Read(buf, binary.BigEndian, &h)
switch err { switch err {
case io.ErrUnexpectedEOF: default:
panic(err)
case io.ErrUnexpectedEOF, io.EOF:
continue continue
case nil: case nil:
default:
return
} }
if h.TransactionId != tid { if h.TransactionId != tid {
continue continue
@ -229,8 +246,7 @@ func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (
if h.Action == ActionError { if h.Action == ActionError {
err = errors.New(buf.String()) err = errors.New(buf.String())
} }
responseBody = buf return buf, err
return
} }
} }

View File

@ -2,6 +2,7 @@ package tracker
import ( import (
"bytes" "bytes"
"context"
"crypto/rand" "crypto/rand"
"encoding/binary" "encoding/binary"
"fmt" "fmt"
@ -160,8 +161,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
rand.Read(req.PeerId[:]) rand.Read(req.PeerId[:])
rand.Read(req.InfoHash[:]) rand.Read(req.InfoHash[:])
wg := sync.WaitGroup{} wg := sync.WaitGroup{}
success := make(chan bool) ctx, cancel := context.WithCancel(context.Background())
fail := make(chan struct{})
for _, url := range trackers { for _, url := range trackers {
wg.Add(1) wg.Add(1)
go func(url string) { go func(url string) {
@ -169,6 +169,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
resp, err := Announce{ resp, err := Announce{
TrackerUrl: url, TrackerUrl: url,
Request: req, Request: req,
Context: ctx,
}.Do() }.Do()
if err != nil { if err != nil {
t.Logf("error announcing to %s: %s", url, err) t.Logf("error announcing to %s: %s", url, err)
@ -180,21 +181,10 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
t.Fatal(resp) t.Fatal(resp)
} }
t.Logf("announced to %s", url) t.Logf("announced to %s", url)
// TODO: Can probably get stuck here, but it's just a throwaway cancel()
// test.
success <- true
}(url) }(url)
} }
go func() { wg.Wait()
wg.Wait()
close(fail)
}()
select {
case <-fail:
// It doesn't matter if they all fail, the servers could just be down.
case <-success:
// Bail as quickly as we can. One success is enough.
}
} }
// Check that URLPath option is done correctly. // Check that URLPath option is done correctly.