105 lines
2.4 KiB
Go
105 lines
2.4 KiB
Go
|
// Copyright (c) 2021 Tulir Asokan
|
||
|
//
|
||
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
||
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||
|
|
||
|
package socket
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"crypto/cipher"
|
||
|
"encoding/binary"
|
||
|
"sync"
|
||
|
"sync/atomic"
|
||
|
|
||
|
"github.com/gorilla/websocket"
|
||
|
)
|
||
|
|
||
|
type NoiseSocket struct {
|
||
|
fs *FrameSocket
|
||
|
onFrame FrameHandler
|
||
|
writeKey cipher.AEAD
|
||
|
readKey cipher.AEAD
|
||
|
writeCounter uint32
|
||
|
readCounter uint32
|
||
|
writeLock sync.Mutex
|
||
|
destroyed uint32
|
||
|
stopConsumer chan struct{}
|
||
|
}
|
||
|
|
||
|
// DisconnectHandler is the callback newNoiseSocket wires into the underlying
// FrameSocket's OnDisconnect hook. It receives the NoiseSocket that was
// disconnected and the hook's `remote` flag unchanged.
// NOTE(review): `remote` presumably means the peer closed the connection —
// confirm against FrameSocket.
type DisconnectHandler func(socket *NoiseSocket, remote bool)
|
||
|
// FrameHandler is called with the decrypted payload of each incoming frame.
type FrameHandler func([]byte)
|
||
|
|
||
|
func newNoiseSocket(fs *FrameSocket, writeKey, readKey cipher.AEAD, frameHandler FrameHandler, disconnectHandler DisconnectHandler) (*NoiseSocket, error) {
|
||
|
ns := &NoiseSocket{
|
||
|
fs: fs,
|
||
|
writeKey: writeKey,
|
||
|
readKey: readKey,
|
||
|
onFrame: frameHandler,
|
||
|
stopConsumer: make(chan struct{}),
|
||
|
}
|
||
|
fs.OnDisconnect = func(remote bool) {
|
||
|
disconnectHandler(ns, remote)
|
||
|
}
|
||
|
go ns.consumeFrames(fs.ctx, fs.Frames)
|
||
|
return ns, nil
|
||
|
}
|
||
|
|
||
|
func (ns *NoiseSocket) consumeFrames(ctx context.Context, frames <-chan []byte) {
|
||
|
ctxDone := ctx.Done()
|
||
|
for {
|
||
|
select {
|
||
|
case frame := <-frames:
|
||
|
ns.receiveEncryptedFrame(frame)
|
||
|
case <-ctxDone:
|
||
|
return
|
||
|
case <-ns.stopConsumer:
|
||
|
return
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
// generateIV builds the 12-byte nonce for frame number count: eight zero
// bytes followed by count encoded as a big-endian uint32. A fresh slice is
// returned on every call.
func generateIV(count uint32) []byte {
	var nonce [12]byte
	binary.BigEndian.PutUint32(nonce[8:], count)
	return nonce[:]
}
|
||
|
|
||
|
func (ns *NoiseSocket) Context() context.Context {
|
||
|
return ns.fs.Context()
|
||
|
}
|
||
|
|
||
|
func (ns *NoiseSocket) Stop(disconnect bool) {
|
||
|
if atomic.CompareAndSwapUint32(&ns.destroyed, 0, 1) {
|
||
|
close(ns.stopConsumer)
|
||
|
ns.fs.OnDisconnect = nil
|
||
|
if disconnect {
|
||
|
ns.fs.Close(websocket.CloseNormalClosure)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (ns *NoiseSocket) SendFrame(plaintext []byte) error {
|
||
|
ns.writeLock.Lock()
|
||
|
ciphertext := ns.writeKey.Seal(nil, generateIV(ns.writeCounter), plaintext, nil)
|
||
|
ns.writeCounter++
|
||
|
err := ns.fs.SendFrame(ciphertext)
|
||
|
ns.writeLock.Unlock()
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
func (ns *NoiseSocket) receiveEncryptedFrame(ciphertext []byte) {
|
||
|
count := atomic.AddUint32(&ns.readCounter, 1) - 1
|
||
|
plaintext, err := ns.readKey.Open(nil, generateIV(count), ciphertext, nil)
|
||
|
if err != nil {
|
||
|
ns.fs.log.Warnf("Failed to decrypt frame: %v", err)
|
||
|
return
|
||
|
}
|
||
|
ns.onFrame(plaintext)
|
||
|
}
|
||
|
|
||
|
func (ns *NoiseSocket) IsConnected() bool {
|
||
|
return ns.fs.IsConnected()
|
||
|
}
|