status-go/server/connection.go
2022-10-05 12:58:32 +01:00

195 lines
4.8 KiB
Go

package server
import (
"crypto/ecdsa"
"crypto/elliptic"
"fmt"
"math/big"
"net"
"net/url"
"strings"
"github.com/btcsuite/btcutil/base58"
)
// ConnectionParamVersion is the version number encoded into a connection string.
type ConnectionParamVersion int

// Mode is the server mode encoded into a connection string.
type Mode int

// Version1 is the connection string version accepted by validateVersion.
const Version1 ConnectionParamVersion = 1

// Server modes. Their numeric values are serialised into the connection
// string, so they are written out explicitly rather than derived from iota.
const (
	Receiving Mode = 1
	Sending   Mode = 2
)

// connectionStringID is the literal prefix every connection string starts with.
const connectionStringID = "cs"
// ConnectionParams holds everything needed to connect securely to another
// Status device. It is serialised with ToString and parsed with FromString.
type ConnectionParams struct {
	// version is the connection string format version (see validateVersion).
	version ConnectionParamVersion
	// netIP is the server's IP address.
	netIP net.IP
	// port is the server's port; validatePort requires 1 - 65535.
	port int
	// publicKey is the server's ECDSA public key; validatePublicKey requires
	// the elliptic.P256 curve. ToString encodes it in compressed form.
	publicKey *ecdsa.PublicKey
	// aesKey is the AES encryption key; validateAESKey requires 32 bytes.
	aesKey []byte
	// serverMode is Receiving or Sending (see validateServerMode).
	serverMode Mode
}
// NewConnectionParams returns ConnectionParams at Version1 holding the given
// address, keys and server mode. The arguments are stored as-is (aesKey is
// not copied) and are not validated here; URL and FromString validate.
func NewConnectionParams(netIP net.IP, port int, publicKey *ecdsa.PublicKey, aesKey []byte, mode Mode) *ConnectionParams {
	return &ConnectionParams{
		version:    Version1,
		netIP:      netIP,
		port:       port,
		publicKey:  publicKey,
		aesKey:     aesKey,
		serverMode: mode,
	}
}
// ToString generates a string required for generating a secure connection to
// another Status device. FromString parses the result back.
//
// The string is the identifier "cs" immediately followed (no delimiter) by
// six base58 fields joined with ":":
//   - version
//   - net.IP
//   - port
//   - ecdsa CompressedPublicKey
//   - AES encryption key
//   - server mode
//
// The integer fields (version, port, server mode) are base58 encodings of the
// bytes returned by big.Int.Bytes. Example:
//
//	cs2:4FHRnp:H6G:uqnnMwVUfJc2Fkcaojet8F1ufKC3hZdGEt47joyBx9yd:BbnZ7Gc66t54a9kEFCf7FW8SGQuYypwHVeNkRYeNoqV6:2
//
// cp.publicKey is dereferenced unconditionally, so it must be non-nil.
func (cp *ConnectionParams) ToString() string {
	encodeInt := func(n int64) string {
		return base58.Encode(big.NewInt(n).Bytes())
	}

	pk := cp.publicKey
	fields := []string{
		encodeInt(int64(cp.version)),
		base58.Encode(cp.netIP),
		encodeInt(int64(cp.port)),
		base58.Encode(elliptic.MarshalCompressed(pk.Curve, pk.X, pk.Y)),
		base58.Encode(cp.aesKey),
		encodeInt(int64(cp.serverMode)),
	}
	return connectionStringID + strings.Join(fields, ":")
}
// FromString parses a connection string produced by ToString, the string
// required to connect securely to another Status device.
//
// s must start with the "cs" identifier followed by exactly six
// ":"-separated base58 fields: version, net.IP, port, compressed P-256 public
// key, AES key and server mode. The decoded values are checked with validate.
// cp is updated only on success; on any error it is left unchanged.
func (cp *ConnectionParams) FromString(s string) error {
	// HasPrefix rather than s[:2]: inputs shorter than the identifier must
	// yield an error, not a slice-bounds panic.
	if !strings.HasPrefix(s, connectionStringID) {
		return fmt.Errorf("connection string doesn't begin with identifier '%s'", connectionStringID)
	}

	const requiredParams = 6
	sData := strings.Split(strings.TrimPrefix(s, connectionStringID), ":")
	if len(sData) != requiredParams {
		return fmt.Errorf("expected data '%s' to have length of '%d', received '%d'", s, requiredParams, len(sData))
	}

	// decodeInt reverses ToString's integer encoding (big.Int bytes -> base58).
	decodeInt := func(field string) int64 {
		return new(big.Int).SetBytes(base58.Decode(field)).Int64()
	}

	// If the key bytes are malformed the coordinates come back unset;
	// validatePublicKey rejects that case.
	x, y := elliptic.UnmarshalCompressed(elliptic.P256(), base58.Decode(sData[3]))

	// Decode into a scratch value so a failed parse never leaves cp
	// half-overwritten.
	parsed := &ConnectionParams{
		version:    ConnectionParamVersion(decodeInt(sData[0])),
		netIP:      base58.Decode(sData[1]),
		port:       int(decodeInt(sData[2])),
		publicKey:  &ecdsa.PublicKey{Curve: elliptic.P256(), X: x, Y: y},
		aesKey:     base58.Decode(sData[4]),
		serverMode: Mode(decodeInt(sData[5])),
	}
	if err := parsed.validate(); err != nil {
		return err
	}

	*cp = *parsed
	return nil
}
// validate runs every field check in a fixed order (version, net IP, port,
// public key, AES key, server mode) and returns the first error found, or nil
// when all checks pass.
func (cp *ConnectionParams) validate() error {
	checks := []func() error{
		cp.validateVersion,
		cp.validateNetIP,
		cp.validatePort,
		cp.validatePublicKey,
		cp.validateAESKey,
		cp.validateServerMode,
	}
	for _, check := range checks {
		if err := check(); err != nil {
			return err
		}
	}
	return nil
}
// validateVersion accepts only Version1 and rejects any other version.
func (cp *ConnectionParams) validateVersion() error {
	if cp.version != Version1 {
		return fmt.Errorf("unsupported version '%d'", cp.version)
	}
	return nil
}
// validateNetIP accepts cp.netIP only if its textual form (net.IP.String)
// parses back with net.ParseIP; otherwise it returns an error.
func (cp *ConnectionParams) validateNetIP() error {
	if reparsed := net.ParseIP(cp.netIP.String()); reparsed != nil {
		return nil
	}
	return fmt.Errorf("invalid net ip '%s'", cp.netIP)
}
// validatePort requires cp.port to be in the inclusive range 1 - 65535.
func (cp *ConnectionParams) validatePort() error {
	if cp.port <= 0 || cp.port > 0xFFFF {
		return fmt.Errorf("port '%d' outside of bounds of 1 - 65535", cp.port)
	}
	return nil
}
// validatePublicKey checks that cp.publicKey is present, uses elliptic.P256
// as its Curve (the curve FromString decodes with), and has non-nil,
// non-zero X and Y coordinates.
func (cp *ConnectionParams) validatePublicKey() error {
	pk := cp.publicKey
	switch {
	case pk == nil:
		// NewConnectionParams accepts a nil key and URL() validates it; report
		// the problem instead of panicking on the field dereferences below.
		return fmt.Errorf("public key not set")
	case pk.Curve == nil, pk.Curve != elliptic.P256():
		return fmt.Errorf("public key Curve not `elliptic.P256`")
	case pk.X == nil, pk.X.Sign() == 0:
		return fmt.Errorf("public key X not set")
	case pk.Y == nil, pk.Y.Sign() == 0:
		return fmt.Errorf("public key Y not set")
	default:
		return nil
	}
}
// validateAESKey requires cp.aesKey to be exactly 32 bytes long.
func (cp *ConnectionParams) validateAESKey() error {
	if keyLen := len(cp.aesKey); keyLen != 32 {
		return fmt.Errorf("AES key invalid length, expect length 32, received length '%d'", keyLen)
	}
	return nil
}
// validateServerMode accepts only the Receiving and Sending modes.
func (cp *ConnectionParams) validateServerMode() error {
	if m := cp.serverMode; m == Receiving || m == Sending {
		return nil
	}
	return fmt.Errorf("invalid server mode '%d'", cp.serverMode)
}
// URL validates cp and returns the https URL of the server it describes,
// with host "<netIP>:<port>". It returns the validation error, if any, and a
// nil URL in that case.
func (cp *ConnectionParams) URL() (*url.URL, error) {
	if err := cp.validate(); err != nil {
		return nil, err
	}

	// JoinHostPort brackets IPv6 literals ("[::1]:8080"); a plain "%s:%d"
	// would produce an ambiguous, unparseable host for IPv6 addresses.
	// IPv4 output is unchanged ("192.168.1.2:8080").
	return &url.URL{
		Scheme: "https",
		Host:   net.JoinHostPort(cp.netIP.String(), fmt.Sprint(cp.port)),
	}, nil
}