2016-06-14 11:54:23 -05:00
|
|
|
package natpmp
|
|
|
|
|
|
|
|
import (
|
|
|
|
"fmt"
|
2020-02-21 15:48:53 +01:00
|
|
|
"log"
|
2016-06-14 11:54:23 -05:00
|
|
|
"net"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// Protocol parameters used when talking to the NAT-PMP gateway.
// They are deliberately untyped so they convert implicitly to int ports,
// uint counters and time.Duration arithmetic at their use sites.
const (
	// nAT_PMP_PORT is the UDP port on the gateway that requests are sent to.
	nAT_PMP_PORT = 5351

	// nAT_TRIES is how many timed-out attempts are made before giving up.
	nAT_TRIES = 9

	// nAT_INITIAL_MS is the first reply timeout in milliseconds; the
	// timeout is doubled after each attempt that times out.
	nAT_INITIAL_MS = 250
)
|
|
|
|
|
|
|
|
// network performs NAT-PMP request/response exchanges over UDP with a
// single gateway.
type network struct {
	// gateway is the IP address of the NAT-PMP server; requests are sent
	// to it and replies from any other source IP are ignored.
	gateway net.IP
}
|
|
|
|
|
2020-02-21 15:48:53 +01:00
|
|
|
func (n *network) call(msg []byte) (result []byte, err error) {
|
2016-06-14 11:54:23 -05:00
|
|
|
var server net.UDPAddr
|
|
|
|
server.IP = n.gateway
|
|
|
|
server.Port = nAT_PMP_PORT
|
|
|
|
conn, err := net.DialUDP("udp", nil, &server)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
defer conn.Close()
|
|
|
|
|
|
|
|
// 16 bytes is the maximum result size.
|
|
|
|
result = make([]byte, 16)
|
|
|
|
|
|
|
|
needNewDeadline := true
|
|
|
|
|
|
|
|
var tries uint
|
2020-02-21 15:48:53 +01:00
|
|
|
for tries = 0; tries < nAT_TRIES; {
|
2016-06-14 11:54:23 -05:00
|
|
|
if needNewDeadline {
|
2020-02-21 15:48:53 +01:00
|
|
|
err = conn.SetDeadline(time.Now().Add((nAT_INITIAL_MS << tries) * time.Millisecond))
|
2016-06-14 11:54:23 -05:00
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
needNewDeadline = false
|
|
|
|
}
|
|
|
|
_, err = conn.Write(msg)
|
|
|
|
if err != nil {
|
|
|
|
return
|
|
|
|
}
|
|
|
|
var bytesRead int
|
|
|
|
var remoteAddr *net.UDPAddr
|
|
|
|
bytesRead, remoteAddr, err = conn.ReadFromUDP(result)
|
|
|
|
if err != nil {
|
|
|
|
if err.(net.Error).Timeout() {
|
|
|
|
tries++
|
|
|
|
needNewDeadline = true
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
if !remoteAddr.IP.Equal(n.gateway) {
|
2020-02-21 15:48:53 +01:00
|
|
|
log.Printf("Ignoring packet because IPs differ:", remoteAddr, n.gateway)
|
2016-06-14 11:54:23 -05:00
|
|
|
// Ignore this packet.
|
|
|
|
// Continue without increasing retransmission timeout or deadline.
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
// Trim result to actual number of bytes received
|
|
|
|
if bytesRead < len(result) {
|
|
|
|
result = result[:bytesRead]
|
|
|
|
}
|
|
|
|
return
|
|
|
|
}
|
|
|
|
err = fmt.Errorf("Timed out trying to contact gateway")
|
|
|
|
return
|
|
|
|
}
|