294 lines
6.3 KiB
Go
Raw Normal View History

2022-01-31 00:27:37 +01:00
package binary
import (
"fmt"
"math"
"strconv"
"go.mau.fi/whatsmeow/binary/token"
"go.mau.fi/whatsmeow/types"
)
// binaryEncoder builds up a binary-encoded node stream in memory.
type binaryEncoder struct {
	// data holds every byte pushed so far, in write order.
	data []byte
}
// newEncoder returns an encoder whose buffer already contains one leading
// zero byte; everything pushed afterwards is appended after it.
func newEncoder() *binaryEncoder {
	enc := &binaryEncoder{data: []byte{0}}
	return enc
}
// getData returns the bytes accumulated so far. The result aliases the
// encoder's internal buffer; it is not a copy.
func (w *binaryEncoder) getData() []byte {
	out := w.data
	return out
}
// pushByte appends a single byte to the end of the output buffer.
func (w *binaryEncoder) pushByte(b byte) {
	buf := w.data
	buf = append(buf, b)
	w.data = buf
}
// pushBytes appends all of the given bytes, in order, to the output buffer.
func (w *binaryEncoder) pushBytes(bytes []byte) {
	buf := w.data
	w.data = append(buf, bytes...)
}
// pushIntN appends the low n bytes of value. Bytes are emitted most
// significant first, unless littleEndian is set, in which case the least
// significant byte comes first.
func (w *binaryEncoder) pushIntN(value, n int, littleEndian bool) {
	for idx := 0; idx < n; idx++ {
		shift := n - 1 - idx
		if littleEndian {
			shift = idx
		}
		// The byte conversion keeps only the low 8 bits of the shifted value.
		w.pushByte(byte(value >> uint(8*shift)))
	}
}
// pushInt20 appends value as a 20-bit big-endian integer spread over three
// bytes; only the low nibble of the first byte is used.
func (w *binaryEncoder) pushInt20(value int) {
	hi := byte(value>>16) & 0x0F
	mid := byte(value >> 8)
	lo := byte(value)
	w.pushBytes([]byte{hi, mid, lo})
}
// pushInt8 appends the low 8 bits of value as a single byte.
func (w *binaryEncoder) pushInt8(value int) {
	w.pushByte(byte(value))
}
// pushInt16 appends the low 16 bits of value, most significant byte first.
func (w *binaryEncoder) pushInt16(value int) {
	w.pushBytes([]byte{byte(value >> 8), byte(value)})
}
// pushInt32 appends the low 32 bits of value, most significant byte first.
func (w *binaryEncoder) pushInt32(value int) {
	w.pushBytes([]byte{
		byte(value >> 24),
		byte(value >> 16),
		byte(value >> 8),
		byte(value),
	})
}
// pushString appends the raw bytes of value (no length prefix).
func (w *binaryEncoder) pushString(value string) {
	w.data = append(w.data, value...)
}
// writeByteLength writes a length prefix using the smallest available form:
// Binary8 followed by 1 byte for lengths below 256, Binary20 followed by a
// 20-bit value below 1<<20, or Binary32 followed by 4 bytes below MaxInt32.
// Anything larger panics.
func (w *binaryEncoder) writeByteLength(length int) {
	switch {
	case length < 256:
		w.pushByte(token.Binary8)
		w.pushInt8(length)
	case length < 1<<20:
		w.pushByte(token.Binary20)
		w.pushInt20(length)
	case length < math.MaxInt32:
		w.pushByte(token.Binary32)
		w.pushInt32(length)
	default:
		panic(fmt.Errorf("length is too large: %d", length))
	}
}
// tagSize is the number of list slots occupied by a node's tag.
const tagSize = 1

// writeNode encodes n as a list. The list holds the tag, then a key/value pair
// for every attribute that writeAttributes actually emits, then the content
// if there is any.
//
// A node whose tag is "0" is written as the fixed sequence List8, ListEmpty.
func (w *binaryEncoder) writeNode(n Node) {
	if n.Tag == "0" {
		w.pushByte(token.List8)
		w.pushByte(token.ListEmpty)
		return
	}

	// writeAttributes skips attributes whose value is "" or nil, so they must
	// not be counted in the list header either. Otherwise the declared list
	// size would exceed the number of items actually encoded. Keep this filter
	// in sync with the one in writeAttributes.
	attrCount := 0
	for _, val := range n.Attrs {
		if val == "" || val == nil {
			continue
		}
		attrCount++
	}

	hasContent := 0
	if n.Content != nil {
		hasContent = 1
	}

	w.writeListStart(2*attrCount + tagSize + hasContent)
	w.writeString(n.Tag)
	w.writeAttributes(n.Attrs)
	if n.Content != nil {
		w.write(n.Content)
	}
}
// write encodes a single value, used both for attribute values and for node
// content.
//
// Encoding by type:
//   - nil: the ListEmpty token.
//   - types.JID: via writeJID.
//   - string: via writeString.
//   - bool and integers of any width (signed or unsigned): their decimal or
//     "true"/"false" text form, via writeString.
//   - []byte: a length-prefixed byte string, via writeBytes.
//   - []Node: a list header followed by each child node.
//
// Any other type panics with an error wrapping ErrInvalidType.
func (w *binaryEncoder) write(data interface{}) {
	switch typedData := data.(type) {
	case nil:
		w.pushByte(token.ListEmpty)
	case types.JID:
		w.writeJID(typedData)
	case string:
		w.writeString(typedData)
	case int:
		w.writeString(strconv.Itoa(typedData))
	case int8:
		w.writeString(strconv.FormatInt(int64(typedData), 10))
	case int16:
		w.writeString(strconv.FormatInt(int64(typedData), 10))
	case int32:
		w.writeString(strconv.FormatInt(int64(typedData), 10))
	case int64:
		w.writeString(strconv.FormatInt(typedData, 10))
	case uint:
		w.writeString(strconv.FormatUint(uint64(typedData), 10))
	case uint8:
		w.writeString(strconv.FormatUint(uint64(typedData), 10))
	case uint16:
		w.writeString(strconv.FormatUint(uint64(typedData), 10))
	case uint32:
		w.writeString(strconv.FormatUint(uint64(typedData), 10))
	case uint64:
		w.writeString(strconv.FormatUint(typedData, 10))
	case bool:
		w.writeString(strconv.FormatBool(typedData))
	case []byte:
		w.writeBytes(typedData)
	case []Node:
		w.writeListStart(len(typedData))
		for _, n := range typedData {
			w.writeNode(n)
		}
	default:
		panic(fmt.Errorf("%w: %T", ErrInvalidType, typedData))
	}
}
// writeString encodes data using the first representation that applies:
//  1. a single-byte dictionary token,
//  2. a double-byte dictionary token (dictionary selector + index),
//  3. nibble packing (digits, '-' and '.'),
//  4. hex packing,
//  5. a raw length-prefixed string.
func (w *binaryEncoder) writeString(data string) {
	if idx, ok := token.IndexOfSingleToken(data); ok {
		w.pushByte(idx)
		return
	}
	if dict, idx, ok := token.IndexOfDoubleByteToken(data); ok {
		w.pushByte(token.Dictionary0 + dict)
		w.pushByte(idx)
		return
	}
	switch {
	case validateNibble(data):
		w.writePackedBytes(data, token.Nibble8)
	case validateHex(data):
		w.writePackedBytes(data, token.Hex8)
	default:
		w.writeStringRaw(data)
	}
}
// writeBytes writes value as a length prefix followed by its raw bytes.
func (w *binaryEncoder) writeBytes(value []byte) {
	size := len(value)
	w.writeByteLength(size)
	w.pushBytes(value)
}
// writeStringRaw writes value verbatim, preceded by its byte length.
func (w *binaryEncoder) writeStringRaw(value string) {
	w.writeByteLength(len(value))
	w.pushBytes([]byte(value))
}
// writeJID encodes a JID in one of two forms:
//   - jid.AD set: the ADJID token, the agent byte, the device byte, then the
//     user string.
//   - otherwise: the JIDPair token, the user (or ListEmpty when the user is
//     empty), then the server.
func (w *binaryEncoder) writeJID(jid types.JID) {
	if !jid.AD {
		w.pushByte(token.JIDPair)
		if jid.User != "" {
			w.write(jid.User)
		} else {
			w.pushByte(token.ListEmpty)
		}
		w.write(jid.Server)
		return
	}
	w.pushByte(token.ADJID)
	w.pushByte(jid.Agent)
	w.pushByte(jid.Device)
	w.writeString(jid.User)
}
// writeAttributes writes each attribute as a key string followed by its
// encoded value. Attributes whose value is "" or nil are omitted. Entries are
// emitted in map iteration order, which Go does not fix.
func (w *binaryEncoder) writeAttributes(attributes Attrs) {
	// Ranging over a nil map is a no-op, so no separate nil check is needed.
	for key, val := range attributes {
		if val != "" && val != nil {
			w.writeString(key)
			w.write(val)
		}
	}
}
// writeListStart writes a list header for listSize items, in one of three
// forms:
//   - ListEmpty for 0 items.
//   - List8 followed by one size byte, for fewer than 256 items.
//   - List16 followed by two size bytes (most significant first), for up to
//     math.MaxUint16 items.
//
// Larger sizes cannot be represented. They panic, as writeByteLength does for
// oversize lengths, instead of silently truncating the size to 16 bits and
// producing a corrupt header.
func (w *binaryEncoder) writeListStart(listSize int) {
	switch {
	case listSize == 0:
		w.pushByte(byte(token.ListEmpty))
	case listSize < 256:
		w.pushByte(byte(token.List8))
		w.pushInt8(listSize)
	case listSize <= math.MaxUint16:
		w.pushByte(byte(token.List16))
		w.pushInt16(listSize)
	default:
		panic(fmt.Errorf("list is too large: %d", listSize))
	}
}
// writePackedBytes writes value packed two characters per byte, in this order:
//  1. The dataType token (token.Nibble8 or token.Hex8).
//  2. A header byte holding the number of packed bytes, with the high bit set
//     when the character count is odd.
//  3. The packed pairs. A trailing odd character is paired with '\x00'.
//
// It panics if value exceeds token.PackedMax characters or if dataType is not
// a packed type. Both checks run before anything is written, so a panic never
// leaves a half-written token in the buffer.
func (w *binaryEncoder) writePackedBytes(value string, dataType int) {
	if len(value) > token.PackedMax {
		panic(fmt.Errorf("too many bytes to pack: %d", len(value)))
	}

	var packer func(byte) byte
	switch dataType {
	case token.Nibble8:
		packer = packNibble
	case token.Hex8:
		packer = packHex
	default:
		// This should only be called with the correct values
		panic(fmt.Errorf("invalid packed byte data type %v", dataType))
	}

	odd := len(value)%2 != 0
	// Integer ceiling of len/2; no floating point needed.
	roundedLength := byte((len(value) + 1) / 2)
	if odd {
		roundedLength |= 128
	}

	w.pushByte(byte(dataType))
	w.pushByte(roundedLength)
	for i := 0; i+1 < len(value); i += 2 {
		w.pushByte(w.packBytePair(packer, value[i], value[i+1]))
	}
	if odd {
		w.pushByte(w.packBytePair(packer, value[len(value)-1], '\x00'))
	}
}
// packBytePair packs two characters into one byte. packer(part1) goes in the
// high nibble and packer(part2) in the low nibble; packer(part1) is evaluated
// first.
func (w *binaryEncoder) packBytePair(packer func(byte) byte, part1, part2 byte) byte {
	hi := packer(part1)
	lo := packer(part2)
	return hi<<4 | lo
}
// validateNibble reports whether value can be nibble-packed. This requires at
// most token.PackedMax characters, each an ASCII digit, '-' or '.'.
func validateNibble(value string) bool {
	if len(value) > token.PackedMax {
		return false
	}
	// A byte-wise scan is equivalent to a rune-wise one here: every allowed
	// character is ASCII, and every byte of a multi-byte rune is >= 0x80.
	for i := 0; i < len(value); i++ {
		switch c := value[i]; {
		case c >= '0' && c <= '9', c == '-', c == '.':
			// allowed
		default:
			return false
		}
	}
	return true
}
// packNibble maps one nibble-packable character to its 4-bit code:
//   - '0'-'9' map to 0-9.
//   - '-' maps to 10 and '.' maps to 11.
//   - The zero padding byte maps to 15.
//
// Any other byte panics; callers are expected to have run validateNibble
// first.
func packNibble(value byte) byte {
	if value >= '0' && value <= '9' {
		return value - '0'
	}
	switch value {
	case '-':
		return 10
	case '.':
		return 11
	case 0:
		return 15
	}
	// This should be validated beforehand
	panic(fmt.Errorf("invalid string to pack as nibble: %d / '%s'", value, string(value)))
}
// validateHex reports whether value can be hex-packed without losing
// information. This requires at most token.PackedMax characters, each a digit
// '0'-'9' or an uppercase 'A'-'F'.
//
// Lowercase 'a'-'f' is rejected on purpose. packHex maps 'a' and 'A' to the
// same nibble, so packing lowercase input would make, for example, "ab" and
// "AB" encode to identical bytes, and the original case could not be
// recovered. Such strings fall through to raw string encoding, which preserves
// them exactly.
func validateHex(value string) bool {
	if len(value) > token.PackedMax {
		return false
	}
	for _, char := range value {
		if !(char >= '0' && char <= '9') && !(char >= 'A' && char <= 'F') {
			return false
		}
	}
	return true
}
// packHex maps one hex character to its 4-bit value:
//   - '0'-'9' map to 0-9.
//   - 'A'-'F' and 'a'-'f' both map to 10-15, so case is not preserved.
//   - The zero padding byte maps to 15.
//
// Any other byte panics; callers are expected to have validated the input
// first.
func packHex(value byte) byte {
	if value >= '0' && value <= '9' {
		return value - '0'
	}
	if value >= 'A' && value <= 'F' {
		return value - 'A' + 10
	}
	if value >= 'a' && value <= 'f' {
		return value - 'a' + 10
	}
	if value == 0 {
		return 15
	}
	// This should be validated beforehand
	panic(fmt.Errorf("invalid string to pack as hex: %d / '%s'", value, string(value)))
}