optimised finding server cert (#4148)

* optimised finding server cert

* make sure `close(done)` invoked only once

* remove sleep

* resolve IDE warning

* refactor for findServerCert
This commit is contained in:
frank 2023-10-18 14:17:49 +08:00 committed by GitHub
parent 0881d8cdb0
commit 3326362b90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 70 additions and 34 deletions

View File

@ -1,5 +1,12 @@
package common package common
import "runtime"
const ( const (
AndroidPlatform = "android" AndroidPlatform = "android"
WindowsPlatform = "windows"
) )
func OperatingSystemIs(targetOS string) bool {
return runtime.GOOS == targetOS
}

View File

@ -1,9 +1,10 @@
package common package common
import ( import (
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
"github.com/status-im/status-go/protocol/protobuf" "github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/transport" "github.com/status-im/status-go/protocol/transport"
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
) )
const MainStatusShardCluster = 16 const MainStatusShardCluster = 16

View File

@ -2,7 +2,6 @@ package server
import ( import (
"net" "net"
"runtime"
"go.uber.org/zap" "go.uber.org/zap"
@ -98,7 +97,7 @@ func getAndroidLocalIP() ([][]net.IP, error) {
func getLocalAddresses() ([][]net.IP, error) { func getLocalAddresses() ([][]net.IP, error) {
// TODO until we can resolve Android errors when calling net.Interfaces() just return the outbound local address. // TODO until we can resolve Android errors when calling net.Interfaces() just return the outbound local address.
// Sorry Android // Sorry Android
if runtime.GOOS == common.AndroidPlatform { if common.OperatingSystemIs(common.AndroidPlatform) {
return getAndroidLocalIP() return getAndroidLocalIP()
} }
@ -192,7 +191,7 @@ func getAllAvailableNetworks() ([]net.IPNet, error) {
// that returns a reachable server's address to be used by local pairing client. // that returns a reachable server's address to be used by local pairing client.
func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) { func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) {
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android // TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
if runtime.GOOS == common.AndroidPlatform { if common.OperatingSystemIs(common.AndroidPlatform) {
return serverIps, nil return serverIps, nil
} }

View File

@ -114,7 +114,8 @@ func getServerCert(URL *url.URL) (*x509.Certificate, error) {
InsecureSkipVerify: true, // nolint: gosec // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason. InsecureSkipVerify: true, // nolint: gosec // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
} }
conn, err := tls.Dial("tcp", URL.Host, conf) // one second should be enough to get the server's TLS cert in LAN?
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", URL.Host, conf)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -8,10 +8,10 @@ import (
"encoding/pem" "encoding/pem"
"fmt" "fmt"
"io" "io"
"net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/url" "net/url"
"time"
"go.uber.org/zap" "go.uber.org/zap"
@ -38,31 +38,52 @@ type BaseClient struct {
challengeTaker *ChallengeTaker challengeTaker *ChallengeTaker
} }
func findServerCert(c *ConnectionParams) (*url.URL, *x509.Certificate, error) { func findServerCert(c *ConnectionParams, reachableIPs []net.IP) (*url.URL, *x509.Certificate, error) {
netIps, err := server.FindReachableAddressesForPairingClient(c.netIPs)
if err != nil {
return nil, nil, err
}
var baseAddress *url.URL var baseAddress *url.URL
var serverCert *x509.Certificate var serverCert *x509.Certificate
var certErrs error
for _, ip := range netIps {
u := c.BuildURL(ip)
serverCert, err = getServerCert(u) type connectionError struct {
if err != nil { ip net.IP
var certErr string err error
if certErrs != nil { }
certErr = certErrs.Error() errCh := make(chan connectionError, len(reachableIPs))
}
certErrs = fmt.Errorf("%sconnecting to '%s' failed: %s; ", certErr, u, err.Error()) type result struct {
continue u *url.URL
} cert *x509.Certificate
}
baseAddress = u successCh := make(chan result, 1) // as we close on the first success
break
for _, ip := range reachableIPs {
go func(ip net.IP) {
u := c.BuildURL(ip)
cert, err := getServerCert(u)
if err != nil {
errCh <- connectionError{ip: ip, err: fmt.Errorf("connecting to '%s' failed: %s", u, err.Error())}
return
}
// If no error, send the results to the success channel
successCh <- result{u: u, cert: cert}
}(ip)
}
// Keep track of error counts
errorCount := 0
var combinedErrors string
for {
select {
case success := <-successCh:
baseAddress = success.u
serverCert = success.cert
return baseAddress, serverCert, nil
case ipErr := <-errCh:
errorCount++
combinedErrors += fmt.Sprintf("IP %s: %s; ", ipErr.ip, ipErr.err)
if errorCount == len(reachableIPs) {
return nil, nil, fmt.Errorf(combinedErrors)
}
}
} }
return baseAddress, serverCert, certErrs
} }
// NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams // NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams
@ -71,13 +92,19 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
var serverCert *x509.Certificate var serverCert *x509.Certificate
var certErrs error var certErrs error
netIPs, err := server.FindReachableAddressesForPairingClient(c.netIPs)
if err != nil {
logger.Error("[local pair client] failed to find reachable addresses", zap.Error(err), zap.Any("netIPs", netIPs))
signal.SendLocalPairingEvent(Event{Type: EventConnectionError, Error: err.Error(), Action: ActionConnect})
return nil, err
}
maxRetries := 3 maxRetries := 3
for i := 0; i < maxRetries; i++ { for i := 0; i < maxRetries; i++ {
baseAddress, serverCert, certErrs = findServerCert(c) baseAddress, serverCert, certErrs = findServerCert(c, netIPs)
if serverCert == nil { if serverCert == nil {
certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs) certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs)
time.Sleep(1 * time.Second) logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs), zap.Any("netIPs", netIPs), zap.Int("retry", i+1))
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs))
} else { } else {
break break
} }
@ -92,7 +119,7 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
// No error on the dial out then the URL.Host is accessible // No error on the dial out then the URL.Host is accessible
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect}) signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})
err := verifyCert(serverCert, c.publicKey) err = verifyCert(serverCert, c.publicKey)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -45,7 +45,7 @@ func (p *PeerNotifier) handler(hello *peers.LocalPairingPeerHello) {
func (p *PeerNotifier) Search() error { func (p *PeerNotifier) Search() error {
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android // TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
if runtime.GOOS == common.AndroidPlatform { if common.OperatingSystemIs(common.AndroidPlatform) {
return nil return nil
} }

View File

@ -12,12 +12,13 @@ import (
"os" "os"
"path" "path"
"path/filepath" "path/filepath"
"runtime"
"strconv" "strconv"
"strings" "strings"
"testing" "testing"
"time" "time"
"github.com/status-im/status-go/common"
_ "github.com/stretchr/testify/suite" // required to register testify flags _ "github.com/stretchr/testify/suite" // required to register testify flags
"github.com/status-im/status-go/logutils" "github.com/status-im/status-go/logutils"
@ -222,7 +223,7 @@ func WaitClosed(c <-chan struct{}, d time.Duration) error {
func MakeTestNodeConfig(networkID int) (*params.NodeConfig, error) { func MakeTestNodeConfig(networkID int) (*params.NodeConfig, error) {
testDir := filepath.Join(TestDataDir, TestNetworkNames[networkID]) testDir := filepath.Join(TestDataDir, TestNetworkNames[networkID])
if runtime.GOOS == "windows" { if common.OperatingSystemIs(common.WindowsPlatform) {
testDir = filepath.ToSlash(testDir) testDir = filepath.ToSlash(testDir)
} }