229 lines
5.4 KiB
Go
229 lines
5.4 KiB
Go
// Copyright (c) 2021 Tulir Asokan
|
|
//
|
|
// This Source Code Form is subject to the terms of the Mozilla Public
|
|
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
|
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
|
|
|
package socket
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
|
|
waLog "go.mau.fi/whatsmeow/util/log"
|
|
)
|
|
|
|
type FrameSocket struct {
|
|
conn *websocket.Conn
|
|
ctx context.Context
|
|
cancel func()
|
|
log waLog.Logger
|
|
lock sync.Mutex
|
|
|
|
Frames chan []byte
|
|
OnDisconnect func(remote bool)
|
|
WriteTimeout time.Duration
|
|
|
|
Header []byte
|
|
|
|
incomingLength int
|
|
receivedLength int
|
|
incoming []byte
|
|
partialHeader []byte
|
|
}
|
|
|
|
func NewFrameSocket(log waLog.Logger, header []byte) *FrameSocket {
|
|
return &FrameSocket{
|
|
conn: nil,
|
|
log: log,
|
|
Header: header,
|
|
Frames: make(chan []byte),
|
|
}
|
|
}
|
|
|
|
func (fs *FrameSocket) IsConnected() bool {
|
|
return fs.conn != nil
|
|
}
|
|
|
|
func (fs *FrameSocket) Context() context.Context {
|
|
return fs.ctx
|
|
}
|
|
|
|
func (fs *FrameSocket) Close(code int) {
|
|
fs.lock.Lock()
|
|
defer fs.lock.Unlock()
|
|
|
|
if fs.conn == nil {
|
|
return
|
|
}
|
|
|
|
if code > 0 {
|
|
message := websocket.FormatCloseMessage(code, "")
|
|
err := fs.conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
|
|
if err != nil {
|
|
fs.log.Warnf("Error sending close message: %v", err)
|
|
}
|
|
}
|
|
|
|
fs.cancel()
|
|
err := fs.conn.Close()
|
|
if err != nil {
|
|
fs.log.Errorf("Error closing websocket: %v", err)
|
|
}
|
|
fs.conn = nil
|
|
fs.ctx = nil
|
|
fs.cancel = nil
|
|
if fs.OnDisconnect != nil {
|
|
go fs.OnDisconnect(code == 0)
|
|
}
|
|
}
|
|
|
|
func (fs *FrameSocket) Connect() error {
|
|
fs.lock.Lock()
|
|
defer fs.lock.Unlock()
|
|
|
|
if fs.conn != nil {
|
|
return ErrSocketAlreadyOpen
|
|
}
|
|
ctx, cancel := context.WithCancel(context.Background())
|
|
dialer := websocket.Dialer{}
|
|
|
|
headers := http.Header{"Origin": []string{Origin}}
|
|
fs.log.Debugf("Dialing %s", URL)
|
|
conn, _, err := dialer.Dial(URL, headers)
|
|
if err != nil {
|
|
cancel()
|
|
return fmt.Errorf("couldn't dial whatsapp web websocket: %w", err)
|
|
}
|
|
|
|
fs.ctx, fs.cancel = ctx, cancel
|
|
fs.conn = conn
|
|
conn.SetCloseHandler(func(code int, text string) error {
|
|
fs.log.Debugf("Server closed websocket with status %d/%s", code, text)
|
|
cancel()
|
|
// from default CloseHandler
|
|
message := websocket.FormatCloseMessage(code, "")
|
|
_ = conn.WriteControl(websocket.CloseMessage, message, time.Now().Add(time.Second))
|
|
return nil
|
|
})
|
|
|
|
go fs.readPump(conn, ctx)
|
|
return nil
|
|
}
|
|
|
|
func (fs *FrameSocket) SendFrame(data []byte) error {
|
|
conn := fs.conn
|
|
if conn == nil {
|
|
return ErrSocketClosed
|
|
}
|
|
dataLength := len(data)
|
|
if dataLength >= FrameMaxSize {
|
|
return fmt.Errorf("%w (got %d bytes, max %d bytes)", ErrFrameTooLarge, len(data), FrameMaxSize)
|
|
}
|
|
|
|
headerLength := len(fs.Header)
|
|
// Whole frame is header + 3 bytes for length + data
|
|
wholeFrame := make([]byte, headerLength+FrameLengthSize+dataLength)
|
|
|
|
// Copy the header if it's there
|
|
if fs.Header != nil {
|
|
copy(wholeFrame[:headerLength], fs.Header)
|
|
// We only want to send the header once
|
|
fs.Header = nil
|
|
}
|
|
|
|
// Encode length of frame
|
|
wholeFrame[headerLength] = byte(dataLength >> 16)
|
|
wholeFrame[headerLength+1] = byte(dataLength >> 8)
|
|
wholeFrame[headerLength+2] = byte(dataLength)
|
|
|
|
// Copy actual frame data
|
|
copy(wholeFrame[headerLength+FrameLengthSize:], data)
|
|
|
|
if fs.WriteTimeout > 0 {
|
|
err := conn.SetWriteDeadline(time.Now().Add(fs.WriteTimeout))
|
|
if err != nil {
|
|
fs.log.Warnf("Failed to set write deadline: %v", err)
|
|
}
|
|
}
|
|
return conn.WriteMessage(websocket.BinaryMessage, wholeFrame)
|
|
}
|
|
|
|
func (fs *FrameSocket) frameComplete() {
|
|
data := fs.incoming
|
|
fs.incoming = nil
|
|
fs.partialHeader = nil
|
|
fs.incomingLength = 0
|
|
fs.receivedLength = 0
|
|
fs.Frames <- data
|
|
}
|
|
|
|
func (fs *FrameSocket) processData(msg []byte) {
|
|
for len(msg) > 0 {
|
|
// This probably doesn't happen a lot (if at all), so the code is unoptimized
|
|
if fs.partialHeader != nil {
|
|
msg = append(fs.partialHeader, msg...)
|
|
fs.partialHeader = nil
|
|
}
|
|
if fs.incoming == nil {
|
|
if len(msg) >= FrameLengthSize {
|
|
length := (int(msg[0]) << 16) + (int(msg[1]) << 8) + int(msg[2])
|
|
fs.incomingLength = length
|
|
fs.receivedLength = len(msg)
|
|
msg = msg[FrameLengthSize:]
|
|
if len(msg) >= length {
|
|
fs.incoming = msg[:length]
|
|
msg = msg[length:]
|
|
fs.frameComplete()
|
|
} else {
|
|
fs.incoming = make([]byte, length)
|
|
copy(fs.incoming, msg)
|
|
msg = nil
|
|
}
|
|
} else {
|
|
fs.log.Warnf("Received partial header (report if this happens often)")
|
|
fs.partialHeader = msg
|
|
msg = nil
|
|
}
|
|
} else {
|
|
if len(fs.incoming)+len(msg) >= fs.incomingLength {
|
|
copy(fs.incoming[fs.receivedLength:], msg[:fs.incomingLength-fs.receivedLength])
|
|
msg = msg[fs.incomingLength-fs.receivedLength:]
|
|
fs.frameComplete()
|
|
} else {
|
|
copy(fs.incoming[fs.receivedLength:], msg)
|
|
fs.receivedLength += len(msg)
|
|
msg = nil
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (fs *FrameSocket) readPump(conn *websocket.Conn, ctx context.Context) {
|
|
fs.log.Debugf("Frame websocket read pump starting %p", fs)
|
|
defer func() {
|
|
fs.log.Debugf("Frame websocket read pump exiting %p", fs)
|
|
go fs.Close(0)
|
|
}()
|
|
for {
|
|
msgType, data, err := conn.ReadMessage()
|
|
if err != nil {
|
|
// Ignore the error if the context has been closed
|
|
if !errors.Is(ctx.Err(), context.Canceled) {
|
|
fs.log.Errorf("Error reading from websocket: %v", err)
|
|
}
|
|
return
|
|
} else if msgType != websocket.BinaryMessage {
|
|
fs.log.Warnf("Got unexpected websocket message type %d", msgType)
|
|
continue
|
|
}
|
|
fs.processData(data)
|
|
}
|
|
}
|