optimised finding server cert (#4148)
* optimised finding server cert * make sure `close(done)` invoked only once * remove sleep * resolve IDE warning * refactor for findServerCert
This commit is contained in:
parent
0881d8cdb0
commit
3326362b90
|
@ -1,5 +1,12 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
|
import "runtime"
|
||||||
|
|
||||||
const (
|
const (
|
||||||
AndroidPlatform = "android"
|
AndroidPlatform = "android"
|
||||||
|
WindowsPlatform = "windows"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func OperatingSystemIs(targetOS string) bool {
|
||||||
|
return runtime.GOOS == targetOS
|
||||||
|
}
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
package common
|
package common
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
|
||||||
|
|
||||||
"github.com/status-im/status-go/protocol/protobuf"
|
"github.com/status-im/status-go/protocol/protobuf"
|
||||||
"github.com/status-im/status-go/protocol/transport"
|
"github.com/status-im/status-go/protocol/transport"
|
||||||
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const MainStatusShardCluster = 16
|
const MainStatusShardCluster = 16
|
||||||
|
|
|
@ -2,7 +2,6 @@ package server
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
@ -98,7 +97,7 @@ func getAndroidLocalIP() ([][]net.IP, error) {
|
||||||
func getLocalAddresses() ([][]net.IP, error) {
|
func getLocalAddresses() ([][]net.IP, error) {
|
||||||
// TODO until we can resolve Android errors when calling net.Interfaces() just return the outbound local address.
|
// TODO until we can resolve Android errors when calling net.Interfaces() just return the outbound local address.
|
||||||
// Sorry Android
|
// Sorry Android
|
||||||
if runtime.GOOS == common.AndroidPlatform {
|
if common.OperatingSystemIs(common.AndroidPlatform) {
|
||||||
return getAndroidLocalIP()
|
return getAndroidLocalIP()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -192,7 +191,7 @@ func getAllAvailableNetworks() ([]net.IPNet, error) {
|
||||||
// that returns a reachable server's address to be used by local pairing client.
|
// that returns a reachable server's address to be used by local pairing client.
|
||||||
func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) {
|
func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) {
|
||||||
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
|
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
|
||||||
if runtime.GOOS == common.AndroidPlatform {
|
if common.OperatingSystemIs(common.AndroidPlatform) {
|
||||||
return serverIps, nil
|
return serverIps, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -114,7 +114,8 @@ func getServerCert(URL *url.URL) (*x509.Certificate, error) {
|
||||||
InsecureSkipVerify: true, // nolint: gosec // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
|
InsecureSkipVerify: true, // nolint: gosec // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
|
||||||
}
|
}
|
||||||
|
|
||||||
conn, err := tls.Dial("tcp", URL.Host, conf)
|
// one second should be enough to get the server's TLS cert in LAN?
|
||||||
|
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", URL.Host, conf)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,10 +8,10 @@ import (
|
||||||
"encoding/pem"
|
"encoding/pem"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/cookiejar"
|
"net/http/cookiejar"
|
||||||
"net/url"
|
"net/url"
|
||||||
"time"
|
|
||||||
|
|
||||||
"go.uber.org/zap"
|
"go.uber.org/zap"
|
||||||
|
|
||||||
|
@ -38,31 +38,52 @@ type BaseClient struct {
|
||||||
challengeTaker *ChallengeTaker
|
challengeTaker *ChallengeTaker
|
||||||
}
|
}
|
||||||
|
|
||||||
func findServerCert(c *ConnectionParams) (*url.URL, *x509.Certificate, error) {
|
func findServerCert(c *ConnectionParams, reachableIPs []net.IP) (*url.URL, *x509.Certificate, error) {
|
||||||
netIps, err := server.FindReachableAddressesForPairingClient(c.netIPs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
|
||||||
}
|
|
||||||
var baseAddress *url.URL
|
var baseAddress *url.URL
|
||||||
var serverCert *x509.Certificate
|
var serverCert *x509.Certificate
|
||||||
var certErrs error
|
|
||||||
for _, ip := range netIps {
|
|
||||||
u := c.BuildURL(ip)
|
|
||||||
|
|
||||||
serverCert, err = getServerCert(u)
|
type connectionError struct {
|
||||||
if err != nil {
|
ip net.IP
|
||||||
var certErr string
|
err error
|
||||||
if certErrs != nil {
|
}
|
||||||
certErr = certErrs.Error()
|
errCh := make(chan connectionError, len(reachableIPs))
|
||||||
}
|
|
||||||
certErrs = fmt.Errorf("%sconnecting to '%s' failed: %s; ", certErr, u, err.Error())
|
type result struct {
|
||||||
continue
|
u *url.URL
|
||||||
}
|
cert *x509.Certificate
|
||||||
|
}
|
||||||
baseAddress = u
|
successCh := make(chan result, 1) // as we close on the first success
|
||||||
break
|
|
||||||
|
for _, ip := range reachableIPs {
|
||||||
|
go func(ip net.IP) {
|
||||||
|
u := c.BuildURL(ip)
|
||||||
|
cert, err := getServerCert(u)
|
||||||
|
if err != nil {
|
||||||
|
errCh <- connectionError{ip: ip, err: fmt.Errorf("connecting to '%s' failed: %s", u, err.Error())}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
// If no error, send the results to the success channel
|
||||||
|
successCh <- result{u: u, cert: cert}
|
||||||
|
}(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep track of error counts
|
||||||
|
errorCount := 0
|
||||||
|
var combinedErrors string
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case success := <-successCh:
|
||||||
|
baseAddress = success.u
|
||||||
|
serverCert = success.cert
|
||||||
|
return baseAddress, serverCert, nil
|
||||||
|
case ipErr := <-errCh:
|
||||||
|
errorCount++
|
||||||
|
combinedErrors += fmt.Sprintf("IP %s: %s; ", ipErr.ip, ipErr.err)
|
||||||
|
if errorCount == len(reachableIPs) {
|
||||||
|
return nil, nil, fmt.Errorf(combinedErrors)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return baseAddress, serverCert, certErrs
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams
|
// NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams
|
||||||
|
@ -71,13 +92,19 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
|
||||||
var serverCert *x509.Certificate
|
var serverCert *x509.Certificate
|
||||||
var certErrs error
|
var certErrs error
|
||||||
|
|
||||||
|
netIPs, err := server.FindReachableAddressesForPairingClient(c.netIPs)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("[local pair client] failed to find reachable addresses", zap.Error(err), zap.Any("netIPs", netIPs))
|
||||||
|
signal.SendLocalPairingEvent(Event{Type: EventConnectionError, Error: err.Error(), Action: ActionConnect})
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
maxRetries := 3
|
maxRetries := 3
|
||||||
for i := 0; i < maxRetries; i++ {
|
for i := 0; i < maxRetries; i++ {
|
||||||
baseAddress, serverCert, certErrs = findServerCert(c)
|
baseAddress, serverCert, certErrs = findServerCert(c, netIPs)
|
||||||
if serverCert == nil {
|
if serverCert == nil {
|
||||||
certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs)
|
certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs)
|
||||||
time.Sleep(1 * time.Second)
|
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs), zap.Any("netIPs", netIPs), zap.Int("retry", i+1))
|
||||||
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs))
|
|
||||||
} else {
|
} else {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
@ -92,7 +119,7 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
|
||||||
// No error on the dial out then the URL.Host is accessible
|
// No error on the dial out then the URL.Host is accessible
|
||||||
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})
|
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})
|
||||||
|
|
||||||
err := verifyCert(serverCert, c.publicKey)
|
err = verifyCert(serverCert, c.publicKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
|
@ -45,7 +45,7 @@ func (p *PeerNotifier) handler(hello *peers.LocalPairingPeerHello) {
|
||||||
|
|
||||||
func (p *PeerNotifier) Search() error {
|
func (p *PeerNotifier) Search() error {
|
||||||
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
|
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
|
||||||
if runtime.GOOS == common.AndroidPlatform {
|
if common.OperatingSystemIs(common.AndroidPlatform) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,12 +12,13 @@ import (
|
||||||
"os"
|
"os"
|
||||||
"path"
|
"path"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"runtime"
|
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/status-im/status-go/common"
|
||||||
|
|
||||||
_ "github.com/stretchr/testify/suite" // required to register testify flags
|
_ "github.com/stretchr/testify/suite" // required to register testify flags
|
||||||
|
|
||||||
"github.com/status-im/status-go/logutils"
|
"github.com/status-im/status-go/logutils"
|
||||||
|
@ -222,7 +223,7 @@ func WaitClosed(c <-chan struct{}, d time.Duration) error {
|
||||||
func MakeTestNodeConfig(networkID int) (*params.NodeConfig, error) {
|
func MakeTestNodeConfig(networkID int) (*params.NodeConfig, error) {
|
||||||
testDir := filepath.Join(TestDataDir, TestNetworkNames[networkID])
|
testDir := filepath.Join(TestDataDir, TestNetworkNames[networkID])
|
||||||
|
|
||||||
if runtime.GOOS == "windows" {
|
if common.OperatingSystemIs(common.WindowsPlatform) {
|
||||||
testDir = filepath.ToSlash(testDir)
|
testDir = filepath.ToSlash(testDir)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue