cleanup code

This commit is contained in:
noot 2019-08-24 15:43:27 -04:00
parent e40f6b283e
commit c01c852045
8 changed files with 29 additions and 59 deletions

View File

@ -10,18 +10,18 @@ func (s *secureSession) Encrypt(plaintext []byte) (ciphertext []byte, err error)
if s.xx_complete {
if s.initiator {
cs := s.xx_ns.CS1()
cs, ciphertext = xx.EncryptWithAd(cs, nil, plaintext)
_, ciphertext = xx.EncryptWithAd(cs, nil, plaintext)
} else {
cs := s.xx_ns.CS2()
cs, ciphertext = xx.EncryptWithAd(cs, nil, plaintext)
_, ciphertext = xx.EncryptWithAd(cs, nil, plaintext)
}
} else if s.ik_complete {
if s.initiator {
cs := s.ik_ns.CS1()
cs, ciphertext = ik.EncryptWithAd(cs, nil, plaintext)
_, ciphertext = ik.EncryptWithAd(cs, nil, plaintext)
} else {
cs := s.ik_ns.CS2()
cs, ciphertext = ik.EncryptWithAd(cs, nil, plaintext)
_, ciphertext = ik.EncryptWithAd(cs, nil, plaintext)
}
} else {
return nil, errors.New("encrypt err: haven't completed handshake")
@ -35,18 +35,18 @@ func (s *secureSession) Decrypt(ciphertext []byte) (plaintext []byte, err error)
if s.xx_complete {
if s.initiator {
cs := s.xx_ns.CS2()
cs, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext)
_, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext)
} else {
cs := s.xx_ns.CS1()
cs, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext)
_, plaintext, ok = xx.DecryptWithAd(cs, nil, ciphertext)
}
} else if s.ik_complete {
if s.initiator {
cs := s.ik_ns.CS2()
cs, plaintext, ok = ik.DecryptWithAd(cs, nil, ciphertext)
_, plaintext, ok = ik.DecryptWithAd(cs, nil, ciphertext)
} else {
cs := s.ik_ns.CS1()
cs, plaintext, ok = ik.DecryptWithAd(cs, nil, ciphertext)
_, plaintext, ok = ik.DecryptWithAd(cs, nil, ciphertext)
}
} else {
return nil, errors.New("decrypt err: haven't completed handshake")

View File

@ -27,6 +27,19 @@ func TestEncryptAndDecrypt_InitToResp(t *testing.T) {
} else if err != nil {
t.Fatal(err)
}
plaintext = []byte("goodbye")
ciphertext, err = initConn.Encrypt(plaintext)
if err != nil {
t.Fatal(err)
}
result, err = respConn.Decrypt(ciphertext)
if !bytes.Equal(plaintext, result) {
t.Fatalf("got %x expected %x", result, plaintext)
} else if err != nil {
t.Fatal(err)
}
}
func TestEncryptAndDecrypt_RespToInit(t *testing.T) {

View File

@ -170,9 +170,6 @@ func (mb *MessageBuffer) Encode1() []byte {
enc = append(enc, mb.ne[:]...)
enc = append(enc, mb.ciphertext...)
// log.Debug("XX_Encode1", "ne", mb.ne)
// log.Debug("XX_Encode1", "ns", mb.ns)
return enc
}
@ -182,12 +179,10 @@ func Decode0(in []byte) (*MessageBuffer, error) {
return nil, errors.New("cannot decode stage 0 MessageBuffer: length less than 32 bytes")
}
//log.Debug("XX_Decode0", "in", in)
mb := new(MessageBuffer)
copy(mb.ne[:], in[:32])
mb.ns = in[32:80]
mb.ciphertext = in[80:]
//log.Debug("XX_Decode0", "mb", mb)
return mb, nil
}
@ -195,18 +190,12 @@ func Decode0(in []byte) (*MessageBuffer, error) {
// Decodes messages at stage 1 into MessageBuffer
func Decode1(in []byte) (*MessageBuffer, error) {
if len(in) < 80 {
return nil, errors.New("cannot decode stage 1/2 MessageBuffer: length less than 96 bytes")
return nil, errors.New("cannot decode stage 1 MessageBuffer: length less than 96 bytes")
}
// log.Debug("XX_Decode1", "in", in)
// log.Debug("XX_Decode1", "ns", in[32:80])
mb := new(MessageBuffer)
copy(mb.ne[:], in[:32])
//mb.ns = in[32:80]
mb.ciphertext = in[32:]
// copy(mb.ns,)
// copy(mb.ciphertext,)
return mb, nil
}
@ -523,7 +512,6 @@ func SendMessage(session *NoiseSession, message []byte) (*NoiseSession, MessageB
}
if session.mc == 1 {
session.h, messageBuffer, session.cs1, session.cs2 = writeMessageB(&session.hs, message)
//session.hs = handshakestate{}
}
session.mc = session.mc + 1
return session, messageBuffer
@ -537,7 +525,6 @@ func RecvMessage(session *NoiseSession, message *MessageBuffer) (*NoiseSession,
}
if session.mc == 1 {
session.h, plaintext, valid, session.cs1, session.cs2 = readMessageB(&session.hs, message)
//session.hs = handshakestate{}
}
session.mc = session.mc + 1
return session, plaintext, valid

View File

@ -23,7 +23,6 @@ func (s *secureSession) ik_sendHandshakeMessage(payload []byte, initial_stage bo
}
log.Debugf("ik_sendHandshakeMessage", "initiator", s.initiator, "msgbuf", msgbuf)
log.Debugf("ik_sendHandshakeMessage", "initiator", s.initiator, "encMsgBuf", encMsgBuf, "ns_len", len(msgbuf.NS()), "enc_len", len(encMsgBuf))
err := s.WriteLength(len(encMsgBuf))
if err != nil {

View File

@ -4,15 +4,12 @@ import (
"context"
"encoding/binary"
"fmt"
//"io"
"net"
"time"
logging "github.com/ipfs/go-log"
//proto "github.com/gogo/protobuf/proto"
"github.com/libp2p/go-libp2p-core/crypto"
"github.com/libp2p/go-libp2p-core/peer"
//"github.com/libp2p/go-libp2p-core/sec"
ik "github.com/ChainSafe/go-libp2p-noise/ik"
pb "github.com/ChainSafe/go-libp2p-noise/pb"
@ -126,21 +123,15 @@ func (s *secureSession) verifyPayload(payload *pb.NoiseHandshakePayload, noiseKe
}
func (s *secureSession) runHandshake(ctx context.Context) error {
log.Debugf("runHandshake", "cache", s.noiseStaticKeyCache)
// if we have the peer's noise static key and we support noise pipes, we can try IK
if s.noiseStaticKeyCache[s.remotePeer] != [32]byte{} || s.noisePipesSupport {
log.Debugf("runHandshake_ik")
// known static key for peer, try IK //
buf, err := s.runHandshake_ik(ctx)
if err != nil {
log.Error("runHandshake_ik", "err", err)
// TODO: PIPE TO XX
// IK failed, pipe to XXfallback
err = s.runHandshake_xx(ctx, true, buf)
if err != nil {
log.Error("runHandshake_xx", "err", err)
@ -153,7 +144,6 @@ func (s *secureSession) runHandshake(ctx context.Context) error {
s.ik_complete = true
} else {
// unknown static key for peer, try XX //
err := s.runHandshake_xx(ctx, false, nil)

View File

@ -9,10 +9,13 @@ import (
"github.com/libp2p/go-libp2p-core/sec"
)
// ID is the protocol ID for noise
const ID = "/noise/0.0.1"
var _ sec.SecureTransport = &Transport{}
// Transport implements the interface sec.SecureTransport
// https://godoc.org/github.com/libp2p/go-libp2p-core/sec#SecureConn
type Transport struct {
LocalID peer.ID
PrivateKey crypto.PrivKey

View File

@ -26,7 +26,6 @@ import (
"hash"
"io"
"math"
//log "github.com/ChainSafe/log15"
)
/* ---------------------------------------------------------------- *
@ -171,9 +170,6 @@ func (mb *MessageBuffer) Encode1() []byte {
enc = append(enc, mb.ns...)
enc = append(enc, mb.ciphertext...)
// log.Debug("XX_Encode1", "ne", mb.ne)
// log.Debug("XX_Encode1", "ns", mb.ns)
return enc
}
@ -183,11 +179,9 @@ func Decode0(in []byte) (*MessageBuffer, error) {
return nil, errors.New("cannot decode stage 0 MessageBuffer: length less than 32 bytes")
}
//log.Debug("XX_Decode0", "in", in)
mb := new(MessageBuffer)
copy(mb.ne[:], in[:32])
mb.ciphertext = in[32:]
//log.Debug("XX_Decode0", "mb", mb)
return mb, nil
}
@ -198,15 +192,10 @@ func Decode1(in []byte) (*MessageBuffer, error) {
return nil, errors.New("cannot decode stage 1/2 MessageBuffer: length less than 96 bytes")
}
// log.Debug("XX_Decode1", "in", in)
// log.Debug("XX_Decode1", "ns", in[32:80])
mb := new(MessageBuffer)
copy(mb.ne[:], in[:32])
mb.ns = in[32:80]
mb.ciphertext = in[80:]
// copy(mb.ns,)
// copy(mb.ciphertext,)
return mb, nil
}

View File

@ -3,7 +3,7 @@ package noise
import (
"context"
"fmt"
//log "github.com/ChainSafe/log15"
proto "github.com/gogo/protobuf/proto"
"github.com/libp2p/go-libp2p-core/peer"
@ -25,7 +25,6 @@ func (s *secureSession) xx_sendHandshakeMessage(payload []byte, initial_stage bo
}
log.Debugf("xx_sendHandshakeMessage", "initiator", s.initiator, "msgbuf", msgbuf, "initial_stage", initial_stage)
//log.Debugf("xx_sendHandshakeMessage", "initiator", s.initiator, "encMsgBuf", encMsgBuf, "ns_len", len(msgbuf.NS()), "enc_len", len(encMsgBuf), "initial_stage", initial_stage)
err := s.WriteLength(len(encMsgBuf))
if err != nil {
@ -63,8 +62,6 @@ func (s *secureSession) xx_recvHandshakeMessage(initial_stage bool) (buf []byte,
msgbuf, err = xx.Decode1(buf)
}
//log.Debugf("xx_recvHandshakeMessage", "initiator", s.initiator, "msgbuf", msgbuf, "buf len", len(buf), "initial_stage", initial_stage)
if err != nil {
log.Debugf("xx_recvHandshakeMessage decode", "initiator", s.initiator, "error", err)
return buf, nil, false, fmt.Errorf("decode msg fail: %s", err)
@ -181,14 +178,8 @@ func (s *secureSession) runHandshake_xx(ctx context.Context, fallback bool, init
return fmt.Errorf("validation fail")
}
//log.Debugf("stage 1 xx_recvHandshakeMessage", "initiator", s.initiator, "msgbuf", msgbuf, "payload len", len(plaintext))
}
log.Debugf("stage 1 initiator", "payload", plaintext)
log.Debugf("stage 1 initiator", "remote key", s.xx_ns.RemoteKey())
// stage 2 //
if !fallback {
@ -243,7 +234,7 @@ func (s *secureSession) runHandshake_xx(ctx context.Context, fallback bool, init
if !fallback {
// read message
buf, plaintext, valid, err = s.xx_recvHandshakeMessage(true)
_, plaintext, valid, err = s.xx_recvHandshakeMessage(true)
if err != nil {
return fmt.Errorf("stage 0 responder fail: %s", err)
}
@ -288,8 +279,6 @@ func (s *secureSession) runHandshake_xx(ctx context.Context, fallback bool, init
log.Error("xx_recvHandshakeMessage", "initiator", s.initiator, "error", "validation fail")
return fmt.Errorf("validation fail")
}
//log.Debugf("xx_recvHandshakeMessage", "initiator", s.initiator, "msgbuf", msgbuf, "payload len", len(plaintext))
}
log.Debugf("stage 0 responder", "plaintext", plaintext, "plaintext len", len(plaintext))
@ -304,7 +293,7 @@ func (s *secureSession) runHandshake_xx(ctx context.Context, fallback bool, init
// stage 2 //
// read message
buf, plaintext, valid, err = s.xx_recvHandshakeMessage(false)
_, plaintext, valid, err = s.xx_recvHandshakeMessage(false)
if err != nil {
return fmt.Errorf("stage 2 responder fail: %s", err)
}