optimised finding server cert (#4148)

* optimised finding server cert

* make sure `close(done)` invoked only once

* remove sleep

* resolve IDE warning

* refactor for findServerCert
This commit is contained in:
frank 2023-10-18 14:17:49 +08:00 committed by GitHub
parent 0881d8cdb0
commit 3326362b90
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 70 additions and 34 deletions

View File

@ -1,5 +1,12 @@
package common
import "runtime"
const (
AndroidPlatform = "android"
WindowsPlatform = "windows"
)
func OperatingSystemIs(targetOS string) bool {
return runtime.GOOS == targetOS
}

View File

@ -1,9 +1,10 @@
package common
import (
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
"github.com/status-im/status-go/protocol/protobuf"
"github.com/status-im/status-go/protocol/transport"
"github.com/waku-org/go-waku/waku/v2/protocol/relay"
)
const MainStatusShardCluster = 16

View File

@ -2,7 +2,6 @@ package server
import (
"net"
"runtime"
"go.uber.org/zap"
@ -98,7 +97,7 @@ func getAndroidLocalIP() ([][]net.IP, error) {
func getLocalAddresses() ([][]net.IP, error) {
// TODO until we can resolve Android errors when calling net.Interfaces() just return the outbound local address.
// Sorry Android
if runtime.GOOS == common.AndroidPlatform {
if common.OperatingSystemIs(common.AndroidPlatform) {
return getAndroidLocalIP()
}
@ -192,7 +191,7 @@ func getAllAvailableNetworks() ([]net.IPNet, error) {
// that returns a reachable server's address to be used by local pairing client.
func FindReachableAddressesForPairingClient(serverIps []net.IP) ([]net.IP, error) {
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
if runtime.GOOS == common.AndroidPlatform {
if common.OperatingSystemIs(common.AndroidPlatform) {
return serverIps, nil
}

View File

@ -114,7 +114,8 @@ func getServerCert(URL *url.URL) (*x509.Certificate, error) {
InsecureSkipVerify: true, // nolint: gosec // Only skip verify to get the server's TLS cert. DO NOT skip for any other reason.
}
conn, err := tls.Dial("tcp", URL.Host, conf)
// one second should be enough to get the server's TLS cert in LAN?
conn, err := tls.DialWithDialer(&net.Dialer{Timeout: time.Second}, "tcp", URL.Host, conf)
if err != nil {
return nil, err
}

View File

@ -8,10 +8,10 @@ import (
"encoding/pem"
"fmt"
"io"
"net"
"net/http"
"net/http/cookiejar"
"net/url"
"time"
"go.uber.org/zap"
@ -38,31 +38,52 @@ type BaseClient struct {
challengeTaker *ChallengeTaker
}
func findServerCert(c *ConnectionParams) (*url.URL, *x509.Certificate, error) {
netIps, err := server.FindReachableAddressesForPairingClient(c.netIPs)
if err != nil {
return nil, nil, err
}
func findServerCert(c *ConnectionParams, reachableIPs []net.IP) (*url.URL, *x509.Certificate, error) {
var baseAddress *url.URL
var serverCert *x509.Certificate
var certErrs error
for _, ip := range netIps {
type connectionError struct {
ip net.IP
err error
}
errCh := make(chan connectionError, len(reachableIPs))
type result struct {
u *url.URL
cert *x509.Certificate
}
successCh := make(chan result, 1) // as we close on the first success
for _, ip := range reachableIPs {
go func(ip net.IP) {
u := c.BuildURL(ip)
serverCert, err = getServerCert(u)
cert, err := getServerCert(u)
if err != nil {
var certErr string
if certErrs != nil {
certErr = certErrs.Error()
errCh <- connectionError{ip: ip, err: fmt.Errorf("connecting to '%s' failed: %s", u, err.Error())}
return
}
certErrs = fmt.Errorf("%sconnecting to '%s' failed: %s; ", certErr, u, err.Error())
continue
// If no error, send the results to the success channel
successCh <- result{u: u, cert: cert}
}(ip)
}
baseAddress = u
break
// Keep track of error counts
errorCount := 0
var combinedErrors string
for {
select {
case success := <-successCh:
baseAddress = success.u
serverCert = success.cert
return baseAddress, serverCert, nil
case ipErr := <-errCh:
errorCount++
combinedErrors += fmt.Sprintf("IP %s: %s; ", ipErr.ip, ipErr.err)
if errorCount == len(reachableIPs) {
return nil, nil, fmt.Errorf(combinedErrors)
}
}
}
return baseAddress, serverCert, certErrs
}
// NewBaseClient returns a fully qualified BaseClient from the given ConnectionParams
@ -71,13 +92,19 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
var serverCert *x509.Certificate
var certErrs error
netIPs, err := server.FindReachableAddressesForPairingClient(c.netIPs)
if err != nil {
logger.Error("[local pair client] failed to find reachable addresses", zap.Error(err), zap.Any("netIPs", netIPs))
signal.SendLocalPairingEvent(Event{Type: EventConnectionError, Error: err.Error(), Action: ActionConnect})
return nil, err
}
maxRetries := 3
for i := 0; i < maxRetries; i++ {
baseAddress, serverCert, certErrs = findServerCert(c)
baseAddress, serverCert, certErrs = findServerCert(c, netIPs)
if serverCert == nil {
certErrs = fmt.Errorf("failed to connect to any of given addresses. %w", certErrs)
time.Sleep(1 * time.Second)
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs))
logger.Warn("failed to connect to any of given addresses. Retrying...", zap.Error(certErrs), zap.Any("netIPs", netIPs), zap.Int("retry", i+1))
} else {
break
}
@ -92,7 +119,7 @@ func NewBaseClient(c *ConnectionParams, logger *zap.Logger) (*BaseClient, error)
// No error on the dial out then the URL.Host is accessible
signal.SendLocalPairingEvent(Event{Type: EventConnectionSuccess, Action: ActionConnect})
err := verifyCert(serverCert, c.publicKey)
err = verifyCert(serverCert, c.publicKey)
if err != nil {
return nil, err
}

View File

@ -45,7 +45,7 @@ func (p *PeerNotifier) handler(hello *peers.LocalPairingPeerHello) {
func (p *PeerNotifier) Search() error {
// TODO until we can resolve Android errors when calling net.Interfaces() just noop. Sorry Android
if runtime.GOOS == common.AndroidPlatform {
if common.OperatingSystemIs(common.AndroidPlatform) {
return nil
}

View File

@ -12,12 +12,13 @@ import (
"os"
"path"
"path/filepath"
"runtime"
"strconv"
"strings"
"testing"
"time"
"github.com/status-im/status-go/common"
_ "github.com/stretchr/testify/suite" // required to register testify flags
"github.com/status-im/status-go/logutils"
@ -222,7 +223,7 @@ func WaitClosed(c <-chan struct{}, d time.Duration) error {
func MakeTestNodeConfig(networkID int) (*params.NodeConfig, error) {
testDir := filepath.Join(TestDataDir, TestNetworkNames[networkID])
if runtime.GOOS == "windows" {
if common.OperatingSystemIs(common.WindowsPlatform) {
testDir = filepath.ToSlash(testDir)
}