optimised finding server cert (#4148)
* optimised finding server cert * make sure `close(done)` invoked only once * remove sleep * resolve IDE warning * refactor for findServerCert
This commit is contained in:
parent
0881d8cdb0
commit
3326362b90
|
@ -1,5 +1,12 @@
|
|||
package common
|
||||
|
||||
import "runtime"
|
||||
|
||||
const (
|
||||
AndroidPlatform = "android"
|
||||
WindowsPlatform = "windows"
|
||||
)
|
||||
|
||||
func OperatingSystemIs(targetOS string) bool {
|
||||
return runtime.GOOS == targetOS
|
||||
}
|
||||
|
|
|
@ -1,9 +1,10 @@
|
|||
package common
|
||||
|
||||
import (
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
|
||||
|
||||
"github.com/status-im/status-go/protocol/protobuf"
|
||||
"github.com/status-im/status-go/protocol/transport"
|
||||
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
|
||||
)
|
||||
|
||||
const MainStatusShardCluster = 16
|
||||
|
|
|
@ -2,7 +2,6 @@ package server
|
|||
|
||||
import (
|
||||
"net"
|
||||
"runtime"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -98,7 +97,7 @@ func getAndroidLocalIP() ([][]net.IP, error) {
|
|||
func getLocalAddresses() ([][]net.IP, error) {
|
||||
// TODO until we can resolve Android errors when calling net.Interfaces() just return the outbound local address.
|
||||
// Sorry Android
|
||||
if runtime.GOOS == common.AndroidPlatform {
|
||||
if common.OperatingSystemIs(common.AndroidPlatform) {
|
||||
return getAndroidLocalIP()
|
||||
}
|
||||
|
||||
|
@ -192,7 +191,7 @@ func getAllAvailableNetworks() ([]net.IPNet, error) {
|
|||
// that returns a reachable server's address to be used by local pairing client.
|
||||
func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) {
|
||||
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
|
||||
if runtime.GOOS == common.AndroidPlatform {
|
||||
if common.OperatingSystemIs(common.AndroidPlatform) {
|
||||
return serverIps, nil
|
||||
}
|
||||
|
||||
|
|
|
@ -114,7 +114,8 @@ func getServerCert(URL *url.URL) (*x509.Certificate, error) {
|
|||
InsecureSkipVerify: true, // nolint: gosec // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
|
||||
}
|
||||
|
||||
conn, err := tls.Dial("tcp", URL.Host, conf)
|
||||
// one second should be enough to get the server's TLS cert in LAN?
|
||||
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", URL.Host, conf)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -8,10 +8,10 @@ import (
|
|||
"encoding/pem"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"go.uber.org/zap"
|
||||
|
||||
|
@ -38,31 +38,52 @@ type BaseClient struct {
|
|||
challengeTaker *ChallengeTaker
|
||||
}
|
||||
|
||||
func findServerCert(c *ConnectionParams) (*url.URL, *x509.Certificate, error) {
|
||||
netIps, err := server.FindReachableAddressesForPairingClient(c.netIPs)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
func findServerCert(c *ConnectionParams, reachableIPs []net.IP) (*url.URL, *x509.Certificate, error) {
|
||||
var baseAddress *url.URL
|
||||
var serverCert *x509.Certificate
|
||||
var certErrs error
|
||||
for _, ip := range netIps {
|
||||
|
||||
type connectionError struct {
|
||||
ip net.IP
|
||||
err error
|
||||
}
|
||||
errCh := make(chan connectionError, len(reachableIPs))
|
||||
|
||||
type result struct {
|
||||
u *url.URL
|
||||
cert *x509.Certificate
|
||||
}
|
||||
successCh := make(chan result, 1) // as we close on the first success
|
||||
|
||||
for _, ip := range reachableIPs {
|
||||
go func(ip net.IP) {
|
||||
u := c.BuildURL(ip)
|
||||
|
||||
serverCert, err = getServerCert(u)
|
||||
cert, err := getServerCert(u)
|
||||
if err != nil {
|
||||
var certErr string
|
||||
if certErrs != nil {
|
||||
certErr = certErrs.Error()
|
||||
errCh <- connectionError{ip: ip, err: fmt.Errorf("connecting to '%s' failed: %s", u, err.Error())}
|
||||
return
|
||||
}
|
||||
certErrs = fmt.Errorf("%sconnecting to '%s' failed: %s; ", certErr, u, err.Error())
|
||||
continue
|
||||
// If no error, send the results to the success channel
|
||||
successCh <- result{u: u, cert: cert}
|
||||
}(ip)
|
||||
}
|
||||
|
||||
baseAddress = u
|
||||
break
|
||||
// Keep track of error counts
|
||||
errorCount := 0
|
||||
var combinedErrors string
|
||||
for {
|
||||
select {
|
||||
case success := <-successCh:
|
||||
baseAddress = success.u
|
||||
serverCert = success.cert
|
||||
return baseAddress, serverCert, nil
|
||||
case ipErr := <-errCh:
|
||||
errorCount++
|
||||
combinedErrors += fmt.Sprintf("IP %s: %s; ", ipErr.ip, ipErr.err)
|
||||
if errorCount == len(reachableIPs) {
|
||||
return nil, nil, fmt.Errorf(combinedErrors)
|
||||
}
|
||||
}
|
||||
}
|
||||
return baseAddress, serverCert, certErrs
|
||||
}
|
||||
|
||||
// NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams
|
||||
|
@ -71,13 +92,19 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
|
|||
var serverCert *x509.Certificate
|
||||
var certErrs error
|
||||
|
||||
netIPs, err := server.FindReachableAddressesForPairingClient(c.netIPs)
|
||||
if err != nil {
|
||||
logger.Error("[local pair client] failed to find reachable addresses", zap.Error(err), zap.Any("netIPs", netIPs))
|
||||
signal.SendLocalPairingEvent(Event{Type: EventConnectionError, Error: err.Error(), Action: ActionConnect})
|
||||
return nil, err
|
||||
}
|
||||
|
||||
maxRetries := 3
|
||||
for i := 0; i < maxRetries; i++ {
|
||||
baseAddress, serverCert, certErrs = findServerCert(c)
|
||||
baseAddress, serverCert, certErrs = findServerCert(c, netIPs)
|
||||
if serverCert == nil {
|
||||
certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs)
|
||||
time.Sleep(1 * time.Second)
|
||||
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs))
|
||||
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs), zap.Any("netIPs", netIPs), zap.Int("retry", i+1))
|
||||
} else {
|
||||
break
|
||||
}
|
||||
|
@ -92,7 +119,7 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
|
|||
// No error on the dial out then the URL.Host is accessible
|
||||
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})
|
||||
|
||||
err := verifyCert(serverCert, c.publicKey)
|
||||
err = verifyCert(serverCert, c.publicKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -45,7 +45,7 @@ func (p *PeerNotifier) handler(hello *peers.LocalPairingPeerHello) {
|
|||
|
||||
func (p *PeerNotifier) Search() error {
|
||||
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
|
||||
if runtime.GOOS == common.AndroidPlatform {
|
||||
if common.OperatingSystemIs(common.AndroidPlatform) {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -12,12 +12,13 @@ import (
|
|||
"os"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/status-im/status-go/common"
|
||||
|
||||
_ "github.com/stretchr/testify/suite" // required to register testify flags
|
||||
|
||||
"github.com/status-im/status-go/logutils"
|
||||
|
@ -222,7 +223,7 @@ func WaitClosed(c <-chan struct{}, d time.Duration) error {
|
|||
func MakeTestNodeConfig(networkID int) (*params.NodeConfig, error) {
|
||||
testDir := filepath.Join(TestDataDir, TestNetworkNames[networkID])
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
if common.OperatingSystemIs(common.WindowsPlatform) {
|
||||
testDir = filepath.ToSlash(testDir)
|
||||
}
|
||||
|
||||
|
|
Loading…
Reference in New Issue