tracker: Add Announce.Context
Use it to rewrite a test that fails with recent go versions due to logging after test completion.
This commit is contained in:
parent
f8d827e7d6
commit
3e3f2e4128
@ -98,6 +98,9 @@ func announceHTTP(opt Announce, _url *url.URL) (ret AnnounceResponse, err error)
|
|||||||
req, err := http.NewRequest("GET", _url.String(), nil)
|
req, err := http.NewRequest("GET", _url.String(), nil)
|
||||||
req.Header.Set("User-Agent", opt.UserAgent)
|
req.Header.Set("User-Agent", opt.UserAgent)
|
||||||
req.Host = opt.HostHeader
|
req.Host = opt.HostHeader
|
||||||
|
if opt.Context != nil {
|
||||||
|
req = req.WithContext(opt.Context)
|
||||||
|
}
|
||||||
resp, err := (&http.Client{
|
resp, err := (&http.Client{
|
||||||
Timeout: time.Second * 15,
|
Timeout: time.Second * 15,
|
||||||
Transport: &http.Transport{
|
Transport: &http.Transport{
|
||||||
|
@ -1,6 +1,7 @@
|
|||||||
package tracker
|
package tracker
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/url"
|
"net/url"
|
||||||
@ -61,6 +62,7 @@ type Announce struct {
|
|||||||
ClientIp4 krpc.NodeAddr
|
ClientIp4 krpc.NodeAddr
|
||||||
// If the port is zero, it's assumed to be the same as the Request.Port
|
// If the port is zero, it's assumed to be the same as the Request.Port
|
||||||
ClientIp6 krpc.NodeAddr
|
ClientIp6 krpc.NodeAddr
|
||||||
|
Context context.Context
|
||||||
}
|
}
|
||||||
|
|
||||||
// In an FP language with currying, what order what you put these params?
|
// In an FP language with currying, what order what you put these params?
|
||||||
|
@ -2,9 +2,9 @@ package tracker
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding"
|
"encoding"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/anacrolix/dht/krpc"
|
"github.com/anacrolix/dht/krpc"
|
||||||
"github.com/anacrolix/missinggo"
|
"github.com/anacrolix/missinggo"
|
||||||
"github.com/anacrolix/missinggo/pproffd"
|
"github.com/anacrolix/missinggo/pproffd"
|
||||||
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Action int32
|
type Action int32
|
||||||
@ -188,39 +189,55 @@ func write(w io.Writer, data interface{}) error {
|
|||||||
|
|
||||||
// args is the binary serializable request body. trailer is optional data
|
// args is the binary serializable request body. trailer is optional data
|
||||||
// following it, such as for BEP 41.
|
// following it, such as for BEP 41.
|
||||||
func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (responseBody *bytes.Buffer, err error) {
|
func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (*bytes.Buffer, error) {
|
||||||
tid := newTransactionId()
|
tid := newTransactionId()
|
||||||
err = c.write(&RequestHeader{
|
if err := errors.Wrap(
|
||||||
ConnectionId: c.connectionId,
|
c.write(
|
||||||
Action: action,
|
&RequestHeader{
|
||||||
TransactionId: tid,
|
ConnectionId: c.connectionId,
|
||||||
}, args, options)
|
Action: action,
|
||||||
if err != nil {
|
TransactionId: tid,
|
||||||
return
|
}, args, options),
|
||||||
|
"writing request",
|
||||||
|
); err != nil {
|
||||||
|
return nil, err
|
||||||
}
|
}
|
||||||
c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts)))
|
c.socket.SetReadDeadline(time.Now().Add(timeout(c.contiguousTimeouts)))
|
||||||
b := make([]byte, 0x800) // 2KiB
|
b := make([]byte, 0x800) // 2KiB
|
||||||
for {
|
for {
|
||||||
var n int
|
var (
|
||||||
n, err = c.socket.Read(b)
|
n int
|
||||||
if opE, ok := err.(*net.OpError); ok {
|
readErr error
|
||||||
if opE.Timeout() {
|
readDone = make(chan struct{})
|
||||||
c.contiguousTimeouts++
|
)
|
||||||
return
|
go func() {
|
||||||
}
|
defer close(readDone)
|
||||||
|
n, readErr = c.socket.Read(b)
|
||||||
|
}()
|
||||||
|
ctx := c.a.Context
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
}
|
}
|
||||||
if err != nil {
|
select {
|
||||||
return
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case <-readDone:
|
||||||
|
}
|
||||||
|
if opE, ok := readErr.(*net.OpError); ok && opE.Timeout() {
|
||||||
|
c.contiguousTimeouts++
|
||||||
|
}
|
||||||
|
if readErr != nil {
|
||||||
|
return nil, errors.Wrap(readErr, "reading from socket")
|
||||||
}
|
}
|
||||||
buf := bytes.NewBuffer(b[:n])
|
buf := bytes.NewBuffer(b[:n])
|
||||||
var h ResponseHeader
|
var h ResponseHeader
|
||||||
err = binary.Read(buf, binary.BigEndian, &h)
|
err := binary.Read(buf, binary.BigEndian, &h)
|
||||||
switch err {
|
switch err {
|
||||||
case io.ErrUnexpectedEOF:
|
default:
|
||||||
|
panic(err)
|
||||||
|
case io.ErrUnexpectedEOF, io.EOF:
|
||||||
continue
|
continue
|
||||||
case nil:
|
case nil:
|
||||||
default:
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
if h.TransactionId != tid {
|
if h.TransactionId != tid {
|
||||||
continue
|
continue
|
||||||
@ -229,8 +246,7 @@ func (c *udpAnnounce) request(action Action, args interface{}, options []byte) (
|
|||||||
if h.Action == ActionError {
|
if h.Action == ActionError {
|
||||||
err = errors.New(buf.String())
|
err = errors.New(buf.String())
|
||||||
}
|
}
|
||||||
responseBody = buf
|
return buf, err
|
||||||
return
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -2,6 +2,7 @@ package tracker
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"crypto/rand"
|
"crypto/rand"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
@ -160,8 +161,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
|
|||||||
rand.Read(req.PeerId[:])
|
rand.Read(req.PeerId[:])
|
||||||
rand.Read(req.InfoHash[:])
|
rand.Read(req.InfoHash[:])
|
||||||
wg := sync.WaitGroup{}
|
wg := sync.WaitGroup{}
|
||||||
success := make(chan bool)
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
fail := make(chan struct{})
|
|
||||||
for _, url := range trackers {
|
for _, url := range trackers {
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func(url string) {
|
go func(url string) {
|
||||||
@ -169,6 +169,7 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
|
|||||||
resp, err := Announce{
|
resp, err := Announce{
|
||||||
TrackerUrl: url,
|
TrackerUrl: url,
|
||||||
Request: req,
|
Request: req,
|
||||||
|
Context: ctx,
|
||||||
}.Do()
|
}.Do()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("error announcing to %s: %s", url, err)
|
t.Logf("error announcing to %s: %s", url, err)
|
||||||
@ -180,21 +181,10 @@ func TestAnnounceRandomInfoHashThirdParty(t *testing.T) {
|
|||||||
t.Fatal(resp)
|
t.Fatal(resp)
|
||||||
}
|
}
|
||||||
t.Logf("announced to %s", url)
|
t.Logf("announced to %s", url)
|
||||||
// TODO: Can probably get stuck here, but it's just a throwaway
|
cancel()
|
||||||
// test.
|
|
||||||
success <- true
|
|
||||||
}(url)
|
}(url)
|
||||||
}
|
}
|
||||||
go func() {
|
wg.Wait()
|
||||||
wg.Wait()
|
|
||||||
close(fail)
|
|
||||||
}()
|
|
||||||
select {
|
|
||||||
case <-fail:
|
|
||||||
// It doesn't matter if they all fail, the servers could just be down.
|
|
||||||
case <-success:
|
|
||||||
// Bail as quickly as we can. One success is enough.
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Check that URLPath option is done correctly.
|
// Check that URLPath option is done correctly.
|
||||||
|
Loading…
x
Reference in New Issue
Block a user