Use the new firewall callback support in go-libutp
This commit is contained in:
parent
6dd3b9c12c
commit
2940f27f9f
31
client.go
31
client.go
@ -220,7 +220,7 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
cl.conns, err = listenAll(cl.enabledPeerNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL)
|
cl.conns, err = listenAll(cl.enabledPeerNetworks(), cl.config.ListenHost, cl.config.ListenPort, cl.config.ProxyURL, cl.firewallCallback)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@ -249,6 +249,18 @@ func NewClient(cfg *ClientConfig) (cl *Client, err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cl *Client) firewallCallback(net.Addr) bool {
|
||||||
|
cl.rLock()
|
||||||
|
block := !cl.wantConns()
|
||||||
|
cl.rUnlock()
|
||||||
|
if block {
|
||||||
|
torrent.Add("connections firewalled", 1)
|
||||||
|
} else {
|
||||||
|
torrent.Add("connections not firewalled", 1)
|
||||||
|
}
|
||||||
|
return block
|
||||||
|
}
|
||||||
|
|
||||||
func (cl *Client) enabledPeerNetworks() (ns []string) {
|
func (cl *Client) enabledPeerNetworks() (ns []string) {
|
||||||
for _, n := range allPeerNetworks {
|
for _, n := range allPeerNetworks {
|
||||||
if peerNetworkEnabled(n, cl.config) {
|
if peerNetworkEnabled(n, cl.config) {
|
||||||
@ -340,16 +352,23 @@ func (cl *Client) ipIsBlocked(ip net.IP) bool {
|
|||||||
return blocked
|
return blocked
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (cl *Client) wantConns() bool {
|
||||||
|
for _, t := range cl.torrents {
|
||||||
|
if t.wantConns() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
func (cl *Client) waitAccept() {
|
func (cl *Client) waitAccept() {
|
||||||
for {
|
for {
|
||||||
for _, t := range cl.torrents {
|
|
||||||
if t.wantConns() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if cl.closed.IsSet() {
|
if cl.closed.IsSet() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if cl.wantConns() {
|
||||||
|
return
|
||||||
|
}
|
||||||
cl.event.Wait()
|
cl.event.Wait()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1012,7 +1012,7 @@ func TestMultipleTorrentsWithEncryption(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestClientAddressInUse(t *testing.T) {
|
func TestClientAddressInUse(t *testing.T) {
|
||||||
s, _ := NewUtpSocket("udp", ":50007")
|
s, _ := NewUtpSocket("udp", ":50007", nil)
|
||||||
if s != nil {
|
if s != nil {
|
||||||
defer s.Close()
|
defer s.Close()
|
||||||
}
|
}
|
||||||
|
@ -23,7 +23,7 @@ func testListenerNetwork(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func listenUtpListener(net, addr string) (l net.Listener, err error) {
|
func listenUtpListener(net, addr string) (l net.Listener, err error) {
|
||||||
l, err = NewUtpSocket(net, addr)
|
l, err = NewUtpSocket(net, addr, nil)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
20
socket.go
20
socket.go
@ -31,11 +31,11 @@ func getProxyDialer(proxyURL string) (proxy.Dialer, error) {
|
|||||||
return proxy.FromURL(fixedURL, proxy.Direct)
|
return proxy.FromURL(fixedURL, proxy.Direct)
|
||||||
}
|
}
|
||||||
|
|
||||||
func listen(network, addr, proxyURL string) (socket, error) {
|
func listen(network, addr, proxyURL string, f firewallCallback) (socket, error) {
|
||||||
if isTcpNetwork(network) {
|
if isTcpNetwork(network) {
|
||||||
return listenTcp(network, addr, proxyURL)
|
return listenTcp(network, addr, proxyURL)
|
||||||
} else if isUtpNetwork(network) {
|
} else if isUtpNetwork(network) {
|
||||||
return listenUtp(network, addr, proxyURL)
|
return listenUtp(network, addr, proxyURL, f)
|
||||||
} else {
|
} else {
|
||||||
panic(fmt.Sprintf("unknown network %q", network))
|
panic(fmt.Sprintf("unknown network %q", network))
|
||||||
}
|
}
|
||||||
@ -97,7 +97,7 @@ func setPort(addr string, port int) string {
|
|||||||
return net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
|
return net.JoinHostPort(host, strconv.FormatInt(int64(port), 10))
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenAll(networks []string, getHost func(string) string, port int, proxyURL string) ([]socket, error) {
|
func listenAll(networks []string, getHost func(string) string, port int, proxyURL string, f firewallCallback) ([]socket, error) {
|
||||||
if len(networks) == 0 {
|
if len(networks) == 0 {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
@ -106,7 +106,7 @@ func listenAll(networks []string, getHost func(string) string, port int, proxyUR
|
|||||||
nahs = append(nahs, networkAndHost{n, getHost(n)})
|
nahs = append(nahs, networkAndHost{n, getHost(n)})
|
||||||
}
|
}
|
||||||
for {
|
for {
|
||||||
ss, retry, err := listenAllRetry(nahs, port, proxyURL)
|
ss, retry, err := listenAllRetry(nahs, port, proxyURL, f)
|
||||||
if !retry {
|
if !retry {
|
||||||
return ss, err
|
return ss, err
|
||||||
}
|
}
|
||||||
@ -118,10 +118,10 @@ type networkAndHost struct {
|
|||||||
Host string
|
Host string
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []socket, retry bool, err error) {
|
func listenAllRetry(nahs []networkAndHost, port int, proxyURL string, f firewallCallback) (ss []socket, retry bool, err error) {
|
||||||
ss = make([]socket, 1, len(nahs))
|
ss = make([]socket, 1, len(nahs))
|
||||||
portStr := strconv.FormatInt(int64(port), 10)
|
portStr := strconv.FormatInt(int64(port), 10)
|
||||||
ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL)
|
ss[0], err = listen(nahs[0].Network, net.JoinHostPort(nahs[0].Host, portStr), proxyURL, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, false, fmt.Errorf("first listen: %s", err)
|
return nil, false, fmt.Errorf("first listen: %s", err)
|
||||||
}
|
}
|
||||||
@ -135,7 +135,7 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []sock
|
|||||||
}()
|
}()
|
||||||
portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
|
portStr = strconv.FormatInt(int64(missinggo.AddrPort(ss[0].Addr())), 10)
|
||||||
for _, nah := range nahs[1:] {
|
for _, nah := range nahs[1:] {
|
||||||
s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL)
|
s, err := listen(nah.Network, net.JoinHostPort(nah.Host, portStr), proxyURL, f)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return ss,
|
return ss,
|
||||||
missinggo.IsAddrInUse(err) && port == 0,
|
missinggo.IsAddrInUse(err) && port == 0,
|
||||||
@ -146,8 +146,10 @@ func listenAllRetry(nahs []networkAndHost, port int, proxyURL string) (ss []sock
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func listenUtp(network, addr, proxyURL string) (s socket, err error) {
|
type firewallCallback func(net.Addr) bool
|
||||||
us, err := NewUtpSocket(network, addr)
|
|
||||||
|
func listenUtp(network, addr, proxyURL string, fc firewallCallback) (s socket, err error) {
|
||||||
|
us, err := NewUtpSocket(network, addr, fc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
@ -6,7 +6,7 @@ import (
|
|||||||
"github.com/anacrolix/utp"
|
"github.com/anacrolix/utp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUtpSocket(network, addr string) (utpSocket, error) {
|
func NewUtpSocket(network, addr string, _ firewallCallback) (utpSocket, error) {
|
||||||
s, err := utp.NewSocket(network, addr)
|
s, err := utp.NewSocket(network, addr)
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -6,11 +6,16 @@ import (
|
|||||||
"github.com/anacrolix/go-libutp"
|
"github.com/anacrolix/go-libutp"
|
||||||
)
|
)
|
||||||
|
|
||||||
func NewUtpSocket(network, addr string) (utpSocket, error) {
|
func NewUtpSocket(network, addr string, fc firewallCallback) (utpSocket, error) {
|
||||||
s, err := utp.NewSocket(network, addr)
|
s, err := utp.NewSocket(network, addr)
|
||||||
if s == nil {
|
if s == nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
} else {
|
}
|
||||||
|
if err != nil {
|
||||||
return s, err
|
return s, err
|
||||||
}
|
}
|
||||||
|
if fc != nil {
|
||||||
|
s.SetFirewallCallback(utp.FirewallCallback(fc))
|
||||||
|
}
|
||||||
|
return s, err
|
||||||
}
|
}
|
||||||
|
@ -7,7 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
func TestNewUtpSocketErrorNilInterface(t *testing.T) {
|
func TestNewUtpSocketErrorNilInterface(t *testing.T) {
|
||||||
s, err := NewUtpSocket("fix", "your:language")
|
s, err := NewUtpSocket("fix", "your:language", nil)
|
||||||
assert.Error(t, err)
|
assert.Error(t, err)
|
||||||
if s != nil {
|
if s != nil {
|
||||||
t.Fatalf("expected nil, got %#v", s)
|
t.Fatalf("expected nil, got %#v", s)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user