Local Pairing: connection string version 2 with sharing multiple server addresses (#3909)

* feat: network functions for local pairing (#3898)
This commit is contained in:
Igor Sirotin 2023-08-22 19:18:14 +03:00 committed by GitHub
parent e922fc40d5
commit 5a8f1feea9
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
16 changed files with 680 additions and 102 deletions

View File

@ -1 +1 @@
0.163.14
0.164.0

View File

@ -21,8 +21,8 @@ func makeRandomSerialNumber() (*big.Int, error) {
return rand.Int(rand.Reader, serialNumberLimit)
}
func GenerateX509Cert(sn *big.Int, from, to time.Time, hostname string) *x509.Certificate {
c := &x509.Certificate{
func GenerateX509Cert(sn *big.Int, from, to time.Time, IPAddresses []net.IP, DNSNames []string) *x509.Certificate {
return &x509.Certificate{
SerialNumber: sn,
Subject: pkix.Name{Organization: []string{"Self-signed cert"}},
NotBefore: from,
@ -31,16 +31,9 @@ func GenerateX509Cert(sn *big.Int, from, to time.Time, hostname string) *x509.Ce
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
BasicConstraintsValid: true,
IsCA: true,
IPAddresses: IPAddresses,
DNSNames: DNSNames,
}
ip := net.ParseIP(hostname)
if ip != nil {
c.IPAddresses = []net.IP{ip}
} else {
c.DNSNames = []string{hostname}
}
return c
}
func GenerateX509PEMs(cert *x509.Certificate, key *ecdsa.PrivateKey) (certPem, keyPem []byte, err error) {
@ -59,7 +52,7 @@ func GenerateX509PEMs(cert *x509.Certificate, key *ecdsa.PrivateKey) (certPem, k
return
}
func GenerateTLSCert(notBefore, notAfter time.Time, hostname string) (*tls.Certificate, []byte, error) {
func GenerateTLSCert(notBefore, notAfter time.Time, IPAddresses []net.IP, DNSNames []string) (*tls.Certificate, []byte, error) {
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, nil, err
@ -70,7 +63,7 @@ func GenerateTLSCert(notBefore, notAfter time.Time, hostname string) (*tls.Certi
return nil, nil, err
}
cert := GenerateX509Cert(sn, notBefore, notAfter, hostname)
cert := GenerateX509Cert(sn, notBefore, notAfter, IPAddresses, DNSNames)
certPem, keyPem, err := GenerateX509PEMs(cert, priv)
if err != nil {
return nil, nil, err
@ -88,7 +81,7 @@ func generateMediaTLSCert() error {
notBefore := time.Now()
notAfter := notBefore.Add(365 * 24 * time.Hour)
finalCert, certPem, err := GenerateTLSCert(notBefore, notAfter, Localhost)
finalCert, certPem, err := GenerateTLSCert(notBefore, notAfter, []net.IP{}, []string{Localhost})
if err != nil {
return err
}

View File

@ -1,6 +1,7 @@
package server
import (
"net"
"testing"
"time"
@ -42,12 +43,12 @@ func (s *CertsSuite) TestGenerateX509Cert() {
notBefore := time.Now()
notAfter := notBefore.Add(time.Hour)
c1 := GenerateX509Cert(s.SN, notBefore, notAfter, Localhost)
c1 := GenerateX509Cert(s.SN, notBefore, notAfter, []net.IP{}, []string{Localhost})
s.Require().Exactly([]string{Localhost}, c1.DNSNames)
s.Require().Nil(c1.IPAddresses)
s.Require().Empty(c1.IPAddresses)
c2 := GenerateX509Cert(s.SN, notBefore, notAfter, DefaultIP.String())
c2 := GenerateX509Cert(s.SN, notBefore, notAfter, []net.IP{LocalHostIP}, []string{})
s.Require().Len(c2.IPAddresses, 1)
s.Require().Equal(DefaultIP.String(), c2.IPAddresses[0].String())
s.Require().Nil(c2.DNSNames)
s.Require().Equal(LocalHostIP.String(), c2.IPAddresses[0].String())
s.Require().Empty(c2.DNSNames)
}

View File

@ -2,10 +2,14 @@ package server
import (
"net"
"go.uber.org/zap"
"github.com/status-im/status-go/logutils"
)
var (
DefaultIP = net.IP{127, 0, 0, 1}
LocalHostIP = net.IP{127, 0, 0, 1}
Localhost = "Localhost"
)
@ -20,3 +24,147 @@ func GetOutboundIP() (net.IP, error) {
return localAddr.IP, nil
}
// addrToIPNet casts addr to IPNet.
// Returns nil if addr is not of IPNet type.
func addrToIPNet(addr net.Addr) *net.IPNet {
switch v := addr.(type) {
case *net.IPNet:
return v
default:
return nil
}
}
// filterAddressesForPairingServer filters private unicast addresses.
// ips is a 2-dimensional array, where each sub-array is a list of IP
// addresses for a single network interface.
func filterAddressesForPairingServer(ips [][]net.IP) []net.IP {
var result []net.IP
for _, niIps := range ips {
var ipv4, ipv6 []net.IP
for _, ip := range niIps {
// Only take private global unicast addrs
if !ip.IsGlobalUnicast() || !ip.IsPrivate() {
continue
}
if v := ip.To4(); v != nil {
ipv4 = append(ipv4, ip)
} else {
ipv6 = append(ipv6, ip)
}
}
// Prefer IPv4 over IPv6 for shorter connection string
if len(ipv4) == 0 {
result = append(result, ipv6...)
} else {
result = append(result, ipv4...)
}
}
return result
}
// getLocalAddresses returns an array of all addresses
// of all available network interfaces.
func getLocalAddresses() ([][]net.IP, error) {
nis, err := net.Interfaces()
if err != nil {
return nil, err
}
var ips [][]net.IP
for _, ni := range nis {
var niIps []net.IP
addrs, err := ni.Addrs()
if err != nil {
logutils.ZapLogger().Warn("failed to get addresses of network interface",
zap.String("networkInterface", ni.Name),
zap.Error(err))
continue
}
for _, addr := range addrs {
var ip net.IP
if ipNet := addrToIPNet(addr); ipNet == nil {
continue
} else {
ip = ipNet.IP
}
niIps = append(niIps, ip)
}
if len(niIps) > 0 {
ips = append(ips, niIps)
}
}
return ips, nil
}
// GetLocalAddressesForPairingServer is a high-level func
// that returns a list of addresses to be used by local pairing server.
func GetLocalAddressesForPairingServer() ([]net.IP, error) {
ips, err := getLocalAddresses()
if err != nil {
return nil, err
}
return filterAddressesForPairingServer(ips), nil
}
// findReachableAddresses returns a filtered remoteIps array,
// in which each IP matches one or more of given localNets.
func findReachableAddresses(remoteIPs []net.IP, localNets []net.IPNet) []net.IP {
var result []net.IP
for _, localNet := range localNets {
for _, remoteIP := range remoteIPs {
if localNet.Contains(remoteIP) {
result = append(result, remoteIP)
}
}
}
return result
}
// getAllAvailableNetworks collects all networks
// from available network interfaces.
func getAllAvailableNetworks() ([]net.IPNet, error) {
var localNets []net.IPNet
nis, err := net.Interfaces()
if err != nil {
return nil, err
}
for _, ni := range nis {
addrs, err := ni.Addrs()
if err != nil {
logutils.ZapLogger().Warn("failed to get addresses of network interface",
zap.String("networkInterface", ni.Name),
zap.Error(err))
continue
}
for _, localAddr := range addrs {
localNets = append(localNets, *addrToIPNet(localAddr))
}
}
return localNets, nil
}
// FindReachableAddressesForPairingClient is a high-level func
// that returns a reachable server's address to be used by local pairing client.
func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) {
nets, err := getAllAvailableNetworks()
if err != nil {
return nil, err
}
return findReachableAddresses(serverIps, nets), nil
}

271
server/ips_test.go Normal file
View File

@ -0,0 +1,271 @@
package server
import (
"fmt"
"net"
"testing"
"github.com/stretchr/testify/suite"
"go.uber.org/zap"
"github.com/status-im/status-go/server/servertest"
)
func TestIPsTestingSuite(t *testing.T) {
suite.Run(t, new(IPsTestingSuite))
}
type IPsTestingSuite struct {
suite.Suite
servertest.TestLoggerComponents
}
func (s *IPsTestingSuite) SetupSuite() {
s.SetupLoggerComponents()
}
func (s *IPsTestingSuite) TestConnectionParams_GetLocalAddressesForPairingServer() {
allIps := [][]net.IP{
{
net.IPv4(127, 0, 0, 1),
net.IPv6loopback,
},
{
net.IPv4(192, 168, 1, 42),
net.IP{0xfc, 0x80, 0, 0, 0, 0, 0, 0, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
},
{
net.IPv4(11, 12, 13, 14),
},
{
net.IP{0xfc, 0x80, 0, 0, 0, 0, 0, 0, 0xff, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
},
{
net.IP{0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0xff, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08},
},
}
// First NI is a loop-back
ni0 := allIps[0]
s.Require().NotNil(ni0[0].To4())
s.Require().True(ni0[0].IsLoopback())
s.Require().Len(ni0[1], net.IPv6len)
s.Require().True(ni0[1].IsLoopback())
// Second NI's both IP addresses fits the needs. IPv6 should be filtered out.
ni1 := allIps[1]
s.Require().NotNil(ni1[0].To4())
s.Require().True(ni1[0].IsGlobalUnicast())
s.Require().True(ni1[0].IsPrivate())
s.Require().Len(ni1[1], net.IPv6len)
s.Require().True(ni1[1].IsGlobalUnicast())
s.Require().True(ni1[1].IsPrivate())
// Next NI should be filtered out as non-private
ni2 := allIps[2]
s.Require().NotNil(ni2[0].To4())
s.Require().False(ni2[0].IsPrivate())
// Next NI fits the needs, should be taken,
// as no preferred IPv4 is available on this NI.
ni3 := allIps[3]
s.Require().Len(ni3[0], net.IPv6len)
s.Require().True(ni3[0].IsGlobalUnicast())
s.Require().True(ni3[0].IsPrivate())
// Last NI has a link-local unicast address,
// which should be filtered out as non-private.
ni4 := allIps[4]
s.Require().Len(ni4[0], net.IPv6len)
s.Require().True(ni4[0].IsLinkLocalUnicast())
s.Require().False(ni4[0].IsGlobalUnicast())
s.Require().False(ni4[0].IsPrivate())
ips := filterAddressesForPairingServer(allIps)
s.Require().Len(ips, 2)
s.Require().NotNil(ips[0].To4())
s.Require().NotNil(ni1[0].To4())
s.Require().Equal(ips[0].To4(), ni1[0].To4())
s.Require().Equal(ips[1], ni3[0])
}
func (s *IPsTestingSuite) TestConnectionParams_FindReachableAddresses() {
var remoteIps []net.IP
var localNets []net.IPNet
var ips []net.IP
// Test 1
remoteIps = []net.IP{
net.IPv4(10, 1, 2, 3),
net.IPv4(172, 16, 2, 42),
net.IPv4(192, 168, 1, 42),
}
localNets = []net.IPNet{
{IP: net.IPv4(192, 168, 1, 43), Mask: net.IPv4Mask(255, 255, 255, 0)},
}
ips = findReachableAddresses(remoteIps, localNets)
s.Require().Len(ips, 1)
s.Require().Equal(ips[0], remoteIps[2])
// Test 2
remoteIps = []net.IP{
net.IPv4(10, 1, 2, 3),
net.IPv4(172, 16, 2, 42),
net.IPv4(192, 168, 1, 42),
}
localNets = []net.IPNet{
{IP: net.IPv4(10, 1, 1, 1), Mask: net.IPv4Mask(255, 255, 0, 0)},
{IP: net.IPv4(172, 16, 2, 43), Mask: net.IPv4Mask(255, 255, 255, 0)},
{IP: net.IPv4(192, 168, 2, 43), Mask: net.IPv4Mask(255, 255, 255, 0)},
}
ips = findReachableAddresses(remoteIps, localNets)
s.Require().Len(ips, 2)
s.Require().Equal(ips[0], remoteIps[0])
s.Require().Equal(ips[1], remoteIps[1])
// Test 3
remoteIps = []net.IP{
net.IPv4(10, 1, 2, 3),
net.IPv4(172, 16, 2, 42),
net.IPv4(192, 168, 1, 42),
}
localNets = []net.IPNet{}
ips = findReachableAddresses(remoteIps, localNets)
s.Require().Len(ips, 0)
// Test 4
remoteIps = []net.IP{}
localNets = []net.IPNet{}
ips = findReachableAddresses(remoteIps, localNets)
s.Require().Len(ips, 0)
}
func (s *IPsTestingSuite) TestConnectionParams_RealNetworksTest() {
// This test is intended to be run manually.
// 1. set `printDetails` to true
// 2. run Part 1 on 2 devices
// 3. copy printed results to Part 2
// 4. update expected results in Part 3
// 5. run Part 3
// printing is disabled by default to avoid showing sensitive information
const printDetails = false
printLocalAddresses := func(in [][]net.IP) {
fmt.Println("{")
for _, a := range in {
fmt.Println(" {")
for _, v := range a {
fmt.Println(" net.ParseIP(\"", v.String(), "\"),")
}
fmt.Println(" },")
}
fmt.Println("}")
}
printNets := func(in []net.IPNet) {
fmt.Println("{")
for _, n := range in {
fmt.Println(" \"", n.String(), "\",")
}
fmt.Println("}")
}
parseNets := func(in []string) []net.IPNet {
var out []net.IPNet
for _, v := range in {
_, network, err := net.ParseCIDR(v)
s.Require().NoError(err)
out = append(out, *network)
}
return out
}
// Part 1:
// print needed stuff. Run on both machines.
addrs, err := getLocalAddresses()
s.Require().NoError(err)
s.Logger.Info("MacOS:", zap.Any("addrs", addrs))
if printDetails {
printLocalAddresses(addrs)
}
nets, err := getAllAvailableNetworks()
s.Require().NoError(err)
s.Logger.Info("MacOS:", zap.Any("nets", nets))
if printDetails {
printNets(nets)
}
// Part 2:
// Input all printed devices details below
macNIs := [][]net.IP{
{
net.ParseIP("127.0.0.1"),
net.ParseIP("::1"),
net.ParseIP("fe80::1"),
},
{
net.ParseIP("fe80::c1f:ee0d:1476:dd9a"),
net.ParseIP("192.168.1.36"),
},
{
net.ParseIP("172.16.9.1"),
},
}
macNets := parseNets([]string{
"127.0.0.1/8",
"::1/128",
"fe80::1/64",
"fe80::c1f:ee0d:1476:dd9a/64",
"192.168.1.36/24",
"172.16.9.1/24",
})
winNIs := [][]net.IP{
{
net.ParseIP("fe80::6fd7:5ce4:554f:165a"),
net.ParseIP("192.168.1.33"),
},
{
net.ParseIP("fe80::ffa5:98e1:285c:42eb"),
net.ParseIP("10.0.85.2"),
},
{
net.ParseIP("::1"),
net.ParseIP("127.0.0.1"),
},
}
winNets := parseNets([]string{
"fe80::6fd7:5ce4:554f:165a/64",
"192.168.1.33/24",
"fe80::ffa5:98e1:285c:42eb/64",
"10.0.85.2/32",
"::1/128",
"127.0.0.1/8",
})
// Part 3:
// The test itself
// Windows as server, Mac as client
winIPs := filterAddressesForPairingServer(winNIs)
winReachableIps := findReachableAddresses(winIPs, macNets)
s.Require().Len(winReachableIps, 1)
s.Require().Equal(winReachableIps[0].String(), "192.168.1.33")
// Windows as server, Mac as client
macIPs := filterAddressesForPairingServer(macNIs)
macReachableIps := findReachableAddresses(macIPs, winNets)
s.Require().Len(macReachableIps, 1)
s.Require().Equal(macReachableIps[0].String(), "192.168.1.36")
}

View File

@ -9,11 +9,14 @@ import (
"encoding/pem"
"fmt"
"math/big"
"net"
"net/url"
"time"
"go.uber.org/zap"
"github.com/status-im/status-go/logutils"
"github.com/status-im/status-go/server"
"github.com/status-im/status-go/signal"
)
func makeSerialNumberFromKey(pk *ecdsa.PrivateKey) *big.Int {
@ -23,8 +26,8 @@ func makeSerialNumberFromKey(pk *ecdsa.PrivateKey) *big.Int {
return new(big.Int).SetBytes(h.Sum(nil))
}
func GenerateCertFromKey(pk *ecdsa.PrivateKey, from time.Time, hostname string) (tls.Certificate, []byte, error) {
cert := server.GenerateX509Cert(makeSerialNumberFromKey(pk), from, from.Add(time.Hour), hostname)
func GenerateCertFromKey(pk *ecdsa.PrivateKey, from time.Time, IPAddresses []net.IP, DNSNames []string) (tls.Certificate, []byte, error) {
cert := server.GenerateX509Cert(makeSerialNumberFromKey(pk), from, from.Add(time.Hour), IPAddresses, DNSNames)
certPem, keyPem, err := server.GenerateX509PEMs(cert, pk)
if err != nil {
return tls.Certificate{}, nil, err
@ -111,13 +114,13 @@ func getServerCert(URL *url.URL) (*x509.Certificate, error) {
conn, err := tls.Dial("tcp", URL.Host, conf)
if err != nil {
signal.SendLocalPairingEvent(Event{Type: EventConnectionError, Error: err.Error(), Action: ActionConnect})
return nil, err
}
defer conn.Close()
// No error on the dial out then the URL.Host is accessible
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})
defer func(conn *tls.Conn) {
if e := conn.Close(); e != nil {
logutils.ZapLogger().Warn("failed to close temporary TLS connection:", zap.Error(e))
}
}(conn)
certs := conn.ConnectionState().PeerCertificates
if len(certs) != 1 {

View File

@ -14,6 +14,7 @@ import (
"github.com/status-im/status-go/api"
"github.com/status-im/status-go/logutils"
"github.com/status-im/status-go/server"
"github.com/status-im/status-go/signal"
)
@ -36,16 +37,41 @@ type BaseClient struct {
// NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams
func NewBaseClient(c *ConnectionParams) (*BaseClient, error) {
u, err := c.URL()
var baseAddress *url.URL
var serverCert *x509.Certificate
var certErrs error
netIps, err := server.FindReachableAddressesForPairingClient(c.netIPs)
if err != nil {
return nil, err
}
serverCert, err := getServerCert(u)
for i := range netIps {
u, err := c.URL(i)
if err != nil {
return nil, err
}
serverCert, err = getServerCert(u)
if err != nil {
certErrs = fmt.Errorf("%sconnecting to '%s' failed: %s; ", certErrs.Error(), u, err.Error())
continue
}
baseAddress = u
break
}
if serverCert == nil {
certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs)
signal.SendLocalPairingEvent(Event{Type: EventConnectionError, Error: certErrs.Error(), Action: ActionConnect})
return nil, certErrs
}
// No error on the dial out then the URL.Host is accessible
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})
err = verifyCert(serverCert, c.publicKey)
if err != nil {
return nil, err
@ -78,7 +104,7 @@ func NewBaseClient(c *ConnectionParams) (*BaseClient, error) {
Client: &http.Client{Transport: tr, Jar: cj},
serverCert: serverCert,
challengeTaker: NewChallengeTaker(NewPayloadEncryptor(c.aesKey)),
baseAddress: u,
baseAddress: baseAddress,
}, nil
}

View File

@ -18,7 +18,7 @@ import (
type TestPairingServerComponents struct {
EphemeralPK *ecdsa.PrivateKey
EphemeralAES []byte
OutboundIP net.IP
IPAddresses []net.IP
Cert tls.Certificate
SS *SenderServer
RS *ReceiverServer
@ -36,19 +36,20 @@ func (tpsc *TestPairingServerComponents) SetupPairingServerComponents(t *testing
tpsc.EphemeralAES, err = common.MakeECDHSharedKey(tpsc.EphemeralPK, &tpsc.EphemeralPK.PublicKey)
require.NoError(t, err)
// 3) Device outbound IP address
tpsc.OutboundIP, err = server.GetOutboundIP()
// 3) Device IP address
tpsc.IPAddresses, err = server.GetLocalAddressesForPairingServer()
require.NoError(t, err)
// Generate tls.Certificate and Server
tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, time.Now(), tpsc.OutboundIP.String())
tpsc.Cert, _, err = GenerateCertFromKey(tpsc.EphemeralPK, time.Now(), tpsc.IPAddresses, []string{})
require.NoError(t, err)
sc := &ServerConfig{
PK: &tpsc.EphemeralPK.PublicKey,
EK: tpsc.EphemeralAES,
Cert: &tpsc.Cert,
Hostname: tpsc.OutboundIP.String(),
IPAddresses: tpsc.IPAddresses,
ListenIP: net.IPv4zero,
}
tpsc.SS, err = NewSenderServer(nil, &SenderServerConfig{ServerConfig: sc, SenderConfig: &SenderConfig{}})

View File

@ -3,6 +3,7 @@ package pairing
import (
"crypto/ecdsa"
"crypto/tls"
"net"
"github.com/status-im/status-go/multiaccounts"
"github.com/status-im/status-go/params"
@ -67,7 +68,8 @@ type ServerConfig struct {
PK *ecdsa.PublicKey `json:"-"`
EK []byte `json:"-"`
Cert *tls.Certificate `json:"-"`
Hostname string `json:"-"`
ListenIP net.IP `json:"-"`
IPAddresses []net.IP `json:"-"`
}
type ClientConfig struct{}

View File

@ -20,16 +20,16 @@ const (
type ConnectionParams struct {
version versioning.ConnectionParamVersion
netIP net.IP
netIPs []net.IP
port int
publicKey *ecdsa.PublicKey
aesKey []byte
}
func NewConnectionParams(netIP net.IP, port int, publicKey *ecdsa.PublicKey, aesKey []byte) *ConnectionParams {
func NewConnectionParams(netIPs []net.IP, port int, publicKey *ecdsa.PublicKey, aesKey []byte) *ConnectionParams {
cp := new(ConnectionParams)
cp.version = versioning.LatestConnectionParamVer
cp.netIP = netIP
cp.netIPs = netIPs
cp.port = port
cp.publicKey = publicKey
cp.aesKey = aesKey
@ -45,17 +45,72 @@ func NewConnectionParams(netIP net.IP, port int, publicKey *ecdsa.PublicKey, aes
// - string type identifier
// - version
// - net.IP
// - version 1: a single net.IP
// - version 2: array of IPs in next form:
// | 1 byte | 4*N bytes | 1 byte | 16*N bytes |
// | N | N * IPv4 | M | M * IPv6 |
// - port
// - ecdsa CompressedPublicKey
// - AES encryption key
func (cp *ConnectionParams) ToString() string {
v := base58.Encode(new(big.Int).SetInt64(int64(cp.version)).Bytes())
ip := base58.Encode(cp.netIP)
ips := base58.Encode(SerializeNetIps(cp.netIPs))
p := base58.Encode(new(big.Int).SetInt64(int64(cp.port)).Bytes())
k := base58.Encode(elliptic.MarshalCompressed(cp.publicKey.Curve, cp.publicKey.X, cp.publicKey.Y))
ek := base58.Encode(cp.aesKey)
return fmt.Sprintf("%s%s:%s:%s:%s:%s", connectionStringID, v, ip, p, k, ek)
return fmt.Sprintf("%s%s:%s:%s:%s:%s", connectionStringID, v, ips, p, k, ek)
}
func SerializeNetIps(ips []net.IP) []byte {
var out []byte
var ipv4 []net.IP
var ipv6 []net.IP
for _, ip := range ips {
if v := ip.To4(); v != nil {
ipv4 = append(ipv4, v)
} else {
ipv6 = append(ipv6, ip)
}
}
for _, arr := range [][]net.IP{ipv4, ipv6} {
out = append(out, uint8(len(arr)))
for _, ip := range arr {
out = append(out, ip...)
}
}
return out
}
func ParseNetIps(in []byte) ([]net.IP, error) {
var out []net.IP
if len(in) < 1 {
return nil, fmt.Errorf("net.ip field is too short: '%d', at least 1 byte required", len(in))
}
for _, ipLen := range []int{net.IPv4len, net.IPv6len} {
count := int(in[0])
in = in[1:]
if expectedLen := ipLen * count; len(in) < expectedLen {
return nil, fmt.Errorf("net.ip.ip%d field is too short, expected at least '%d' bytes, '%d' bytes found", ipLen, expectedLen, len(in))
}
for i := 0; i < count; i++ {
offset := i * ipLen
ip := in[offset : ipLen+offset]
out = append(out, ip)
}
in = in[ipLen*count:]
}
return out, nil
}
// FromString parses a connection params string required for to securely connect to another Status device.
@ -63,7 +118,7 @@ func (cp *ConnectionParams) ToString() string {
func (cp *ConnectionParams) FromString(s string) error {
if len(s) < 2 {
return fmt.Errorf("connection string is invalid: '%s'", s)
return fmt.Errorf("connection string is too short: '%s'", s)
}
if s[:2] != connectionStringID {
@ -78,7 +133,22 @@ func (cp *ConnectionParams) FromString(s string) error {
}
cp.version = versioning.ConnectionParamVersion(new(big.Int).SetBytes(base58.Decode(sData[0])).Int64())
cp.netIP = base58.Decode(sData[1])
netIpsBytes := base58.Decode(sData[1])
switch cp.version {
case versioning.ConnectionParamsV1:
if len(netIpsBytes) != net.IPv4len {
return fmt.Errorf("invalid IP size: '%d' bytes, expected: '%d' bytes", len(netIpsBytes), net.IPv4len)
}
cp.netIPs = []net.IP{netIpsBytes}
case versioning.ConnectionParamsV2:
netIps, err := ParseNetIps(netIpsBytes)
if err != nil {
return err
}
cp.netIPs = netIps
}
cp.port = int(new(big.Int).SetBytes(base58.Decode(sData[2])).Int64())
cp.publicKey = new(ecdsa.PublicKey)
cp.publicKey.X, cp.publicKey.Y = elliptic.UnmarshalCompressed(elliptic.P256(), base58.Decode(sData[3]))
@ -120,8 +190,10 @@ func (cp *ConnectionParams) validateVersion() error {
}
func (cp *ConnectionParams) validateNetIP() error {
if ok := net.ParseIP(cp.netIP.String()); ok == nil {
return fmt.Errorf("invalid net ip '%s'", cp.netIP)
for _, ip := range cp.netIPs {
if ok := net.ParseIP(ip.String()); ok == nil {
return fmt.Errorf("invalid net ip '%s'", cp.netIPs)
}
}
return nil
}
@ -154,7 +226,11 @@ func (cp *ConnectionParams) validateAESKey() error {
return nil
}
func (cp *ConnectionParams) URL() (*url.URL, error) {
func (cp *ConnectionParams) URL(IPIndex int) (*url.URL, error) {
if IPIndex < 0 || IPIndex >= len(cp.netIPs) {
return nil, fmt.Errorf("invalid IP index '%d'", IPIndex)
}
err := cp.validate()
if err != nil {
return nil, err
@ -162,7 +238,7 @@ func (cp *ConnectionParams) URL() (*url.URL, error) {
u := &url.URL{
Scheme: "https",
Host: fmt.Sprintf("%s:%d", cp.netIP, cp.port),
Host: fmt.Sprintf("%s:%d", cp.netIPs[IPIndex], cp.port),
}
return u, nil
}

View File

@ -1,6 +1,10 @@
package pairing
import (
"fmt"
"net"
"sort"
"strconv"
"testing"
"github.com/stretchr/testify/suite"
@ -9,8 +13,10 @@ import (
"github.com/status-im/status-go/server/servertest"
)
var (
connectionString = "cs2:4FHRnp:Q4:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6"
const (
connectionStringV1 = "cs2:4FHRnp:Q4:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6"
connectionStringV2 = "cs3:kDDauj5:Q4:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6"
port = 1337
)
func TestConnectionParamsSuite(t *testing.T) {
@ -31,17 +37,27 @@ func (s *ConnectionParamsSuite) SetupSuite() {
s.SetupCertComponents(s.T())
s.SetupLoggerComponents()
cert, _, err := GenerateCertFromKey(s.PK, s.NotBefore, server.DefaultIP.String())
ip := server.LocalHostIP
ips := []net.IP{ip}
cert, _, err := GenerateCertFromKey(s.PK, s.NotBefore, ips, []string{})
s.Require().NoError(err)
bs := server.NewServer(&cert, server.DefaultIP.String(), nil, s.Logger)
err = bs.SetPort(1337)
sc := ServerConfig{
PK: &s.PK.PublicKey,
EK: s.AES,
Cert: &cert,
IPAddresses: ips,
ListenIP: net.IPv4zero,
}
bs := server.NewServer(&cert, net.IPv4zero.String(), nil, s.Logger)
err = bs.SetPort(port)
s.Require().NoError(err)
s.server = &BaseServer{
Server: bs,
pk: &s.PK.PublicKey,
ek: s.AES,
config: sc,
}
}
@ -50,21 +66,73 @@ func (s *ConnectionParamsSuite) TestConnectionParams_ToString() {
s.Require().NoError(err)
cps := cp.ToString()
s.Require().Equal(connectionString, cps)
s.Require().Equal(connectionStringV2, cps)
}
func (s *ConnectionParamsSuite) TestConnectionParams_Generate() {
testCases := []struct {
description string
cs string
}{
{description: "ConnectionString_version1", cs: connectionStringV1},
{description: "ConnectionString_version2", cs: connectionStringV2},
}
for _, tc := range testCases {
s.T().Run(tc.description, func(t *testing.T) {
cp := new(ConnectionParams)
err := cp.FromString(connectionString)
err := cp.FromString(connectionStringV2)
s.Require().NoError(err)
u, err := cp.URL()
u, err := cp.URL(0)
s.Require().NoError(err)
s.Require().Equal("https://127.0.0.1:1337", u.String())
s.Require().Equal(server.DefaultIP.String(), u.Hostname())
s.Require().Equal("1337", u.Port())
expectedURL := fmt.Sprintf("https://%s:%d", server.LocalHostIP.String(), port)
s.Require().Equal(expectedURL, u.String())
s.Require().Equal(server.LocalHostIP.String(), u.Hostname())
s.Require().Equal(strconv.Itoa(port), u.Port())
s.Require().True(cp.publicKey.Equal(&s.PK.PublicKey))
s.Require().Equal(s.AES, cp.aesKey)
})
}
}
func (s *ConnectionParamsSuite) TestConnectionParams_ParseNetIps() {
in := []net.IP{
{192, 168, 1, 42},
net.ParseIP("fe80::6fd7:5ce4:554f:165a"),
{172, 16, 9, 1},
net.ParseIP("fe80::ffa5:98e1:285c:42eb"),
net.ParseIP("fe80::c1f:ee0d:1476:dd9a"),
}
bytes := SerializeNetIps(in)
s.Require().Equal(bytes,
[]byte{
2, // v4 count
192, 168, 1, 42, // v4 1
172, 16, 9, 1, // v4 2
3, // v6 count
0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0x6f, 0xd7, 0x5c, 0xe4, 0x55, 0x4f, 0x16, 0x5a, // v6 1
0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0xff, 0xa5, 0x98, 0xe1, 0x28, 0x5c, 0x42, 0xeb, // v6 2
0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0x0c, 0x1f, 0xee, 0x0d, 0x14, 0x76, 0xdd, 0x9a, // v6 3
})
out, err := ParseNetIps(bytes)
s.Require().NoError(err)
s.Require().Len(in, 5)
sort.SliceStable(in, func(i, j int) bool {
return in[i].String() < in[j].String()
})
sort.SliceStable(out, func(i, j int) bool {
return out[i].String() < out[j].String()
})
s.Require().Equal(in, out)
}

View File

@ -36,7 +36,7 @@ func preflightHandler(w http.ResponseWriter, r *http.Request) {
func makeCert(address net.IP) (*tls.Certificate, []byte, error) {
notBefore := time.Now()
notAfter := notBefore.Add(time.Minute)
return server.GenerateTLSCert(notBefore, notAfter, address.String())
return server.GenerateTLSCert(notBefore, notAfter, []net.IP{address}, []string{})
}
func makeAndStartServer(cert *tls.Certificate, address net.IP) (string, func() error, error) {

View File

@ -5,7 +5,6 @@ import (
"crypto/elliptic"
"crypto/rand"
"encoding/json"
"fmt"
"net"
"time"
@ -29,8 +28,7 @@ type BaseServer struct {
server.Server
challengeGiver *ChallengeGiver
pk *ecdsa.PublicKey
ek []byte
config ServerConfig
}
// NewBaseServer returns a *BaseServer init from the given *SenderServerConfig
@ -43,13 +41,12 @@ func NewBaseServer(logger *zap.Logger, e *PayloadEncryptor, config *ServerConfig
bs := &BaseServer{
Server: server.NewServer(
config.Cert,
config.Hostname,
config.ListenIP.String(),
nil,
logger,
),
challengeGiver: cg,
pk: config.PK,
ek: config.EK,
config: *config,
}
bs.SetTimeout(config.Timeout)
return bs, nil
@ -57,18 +54,7 @@ func NewBaseServer(logger *zap.Logger, e *PayloadEncryptor, config *ServerConfig
// MakeConnectionParams generates a *ConnectionParams based on the Server's current state
func (s *BaseServer) MakeConnectionParams() (*ConnectionParams, error) {
hostname := s.GetHostname()
netIP := net.ParseIP(hostname)
if netIP == nil {
return nil, fmt.Errorf("invalid ip address given '%s'", hostname)
}
netIP4 := netIP.To4()
if netIP4 != nil {
netIP = netIP4
}
return NewConnectionParams(netIP, s.MustGetPort(), s.pk, s.ek), nil
return NewConnectionParams(s.config.IPAddresses, s.MustGetPort(), s.config.PK, s.config.EK), nil
}
func MakeServerConfig(config *ServerConfig) error {
@ -83,12 +69,12 @@ func MakeServerConfig(config *ServerConfig) error {
return err
}
outboundIP, err := server.GetOutboundIP()
ips, err := server.GetLocalAddressesForPairingServer()
if err != nil {
return err
}
tlsCert, _, err := GenerateCertFromKey(tlsKey, time.Now(), outboundIP.String())
tlsCert, _, err := GenerateCertFromKey(tlsKey, time.Now(), ips, []string{})
if err != nil {
return err
}
@ -96,7 +82,9 @@ func MakeServerConfig(config *ServerConfig) error {
config.PK = &tlsKey.PublicKey
config.EK = AESKey
config.Cert = &tlsCert
config.Hostname = outboundIP.String()
config.IPAddresses = ips
config.ListenIP = net.IPv4zero
return nil
}

View File

@ -4,6 +4,7 @@ type ConnectionParamVersion int
const (
ConnectionParamsV1 ConnectionParamVersion = iota + 1
ConnectionParamsV2
)
type LocalPairingVersion int
@ -13,6 +14,6 @@ const (
)
const (
LatestConnectionParamVer = ConnectionParamsV1
LatestConnectionParamVer = ConnectionParamsV2
LatestLocalPairingVer = LocalPairingV1
)

View File

@ -51,7 +51,7 @@ func (s *QROpsTestSuite) SetupTest() {
s.Require().NoError(err)
s.serverNoPort = &MediaServer{Server: Server{
hostname: DefaultIP.String(),
hostname: LocalHostIP.String(),
portManger: newPortManager(s.Logger, nil),
}}
go func() {

View File

@ -59,14 +59,14 @@ func (s *ServerURLSuite) SetupTest() {
s.Require().NoError(err)
s.server = &MediaServer{Server: Server{
hostname: DefaultIP.String(),
hostname: LocalHostIP.String(),
portManger: newPortManager(s.Logger, nil),
}}
err = s.server.SetPort(customPortForTests)
s.Require().NoError(err)
s.serverNoPort = &MediaServer{Server: Server{
hostname: DefaultIP.String(),
hostname: LocalHostIP.String(),
portManger: newPortManager(s.Logger, nil),
}}
go func() {