2022-11-09 17:39:17 +03:00

279 lines
7.3 KiB
Go

package wire
import (
"bytes"
"encoding/binary"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// ParseConnectionID extracts the destination connection ID from a raw packet.
// The returned connection ID is a sub-slice of data, not a copy, so it must
// not be used after the packet buffer is released.
// For short header packets the connection ID starts right after the first byte
// and is shortHeaderConnIDLen bytes long. For long header packets its length is
// taken from the byte at offset 5, and the connection ID follows that byte.
// io.EOF is returned if data is too short to contain the connection ID.
func ParseConnectionID(data []byte, shortHeaderConnIDLen int) (protocol.ConnectionID, error) {
	if len(data) == 0 {
		return nil, io.EOF
	}
	if IsLongHeaderPacket(data[0]) {
		// The length byte sits at offset 5, the connection ID starts at offset 6.
		if len(data) < 6 {
			return nil, io.EOF
		}
		end := 6 + int(data[5])
		if len(data) < end {
			return nil, io.EOF
		}
		return protocol.ConnectionID(data[6:end]), nil
	}
	end := 1 + shortHeaderConnIDLen
	if len(data) < end {
		return nil, io.EOF
	}
	return protocol.ConnectionID(data[1:end]), nil
}
// IsLongHeaderPacket reports whether firstByte, the first byte of a packet,
// marks the packet as a Long Header packet, i.e. whether its most significant
// bit is set.
func IsLongHeaderPacket(firstByte byte) bool {
	const longHeaderBit = 0x80
	return firstByte&longHeaderBit != 0
}
// IsVersionNegotiationPacket reports whether b is a version negotiation packet:
// a Long Header packet whose four version bytes (offsets 1 to 4) are all zero.
// Inputs shorter than 5 bytes are never version negotiation packets.
func IsVersionNegotiationPacket(b []byte) bool {
	if len(b) < 5 || !IsLongHeaderPacket(b[0]) {
		return false
	}
	for _, versionByte := range b[1:5] {
		if versionByte != 0 {
			return false
		}
	}
	return true
}
// Is0RTTPacket reports whether b is a 0-RTT packet.
// Only Long Header packets carrying a supported version can be 0-RTT packets:
// a packet sent with a version we don't understand is never reported as 0-RTT.
func Is0RTTPacket(b []byte) bool {
	if len(b) < 5 || !IsLongHeaderPacket(b[0]) {
		return false
	}
	version := protocol.VersionNumber(binary.BigEndian.Uint32(b[1:5]))
	if !protocol.IsSupportedVersion(protocol.SupportedVersions, version) {
		return false
	}
	// The type bits that denote 0-RTT differ between version 2 and the other versions.
	zeroRTTBits := byte(0b01)
	if version == protocol.Version2 {
		zeroRTTBits = 0b10
	}
	return (b[0]>>4)&0b11 == zeroRTTBits
}
// ErrUnsupportedVersion is returned when parsing a long header packet whose
// version is non-zero and is not one of protocol.SupportedVersions.
var ErrUnsupportedVersion = errors.New("unsupported version")
// Header is the version independent part of the header.
// Which fields are populated depends on the header form and, for long
// headers, on the version and the packet type.
type Header struct {
// IsLongHeader is true if the most significant bit of the first byte is set.
IsLongHeader bool
// typeByte is the first byte of the packet as it was read.
typeByte byte
// Type is only set for long header packets with a supported, non-zero version.
Type protocol.PacketType
// Version is read from long headers only; it is 0 for version negotiation packets.
Version protocol.VersionNumber
// SrcConnectionID is only parsed for long header packets.
SrcConnectionID protocol.ConnectionID
// DestConnectionID is parsed for both short and long header packets.
DestConnectionID protocol.ConnectionID
// Length is the value of the length field of a long header packet.
// It is not read for Retry, version negotiation or unsupported-version packets.
Length protocol.ByteCount
// Token is only set for Retry and Initial packets.
Token []byte
parsedLen protocol.ByteCount // how many bytes were read while parsing this header
}
// ParsePacket parses the header of the first packet contained in data.
// shortHeaderConnIDLen is the connection ID length used for short header packets.
// If the packet has a long header, the packet is cut according to the length field:
// the second return value is the packet itself, the third the bytes following it.
// For short header packets, all of data is returned as the packet and rest is nil.
// If we understand the version, the header is parsed up to the packet number.
// Otherwise, only the invariant part of the header is parsed, and that partially
// filled Header is returned together with ErrUnsupportedVersion.
// Any other parsing error is returned with a nil Header.
func ParsePacket(data []byte, shortHeaderConnIDLen int) (*Header, []byte /* packet data */, []byte /* rest */, error) {
	hdr, err := parseHeader(bytes.NewReader(data), shortHeaderConnIDLen)
	if err != nil {
		// errors.Is keeps recognizing the sentinel even if it is ever wrapped with context.
		if errors.Is(err, ErrUnsupportedVersion) {
			return hdr, nil, nil, ErrUnsupportedVersion
		}
		return nil, nil, nil, err
	}
	var rest []byte
	if hdr.IsLongHeader {
		if protocol.ByteCount(len(data)) < hdr.ParsedLen()+hdr.Length {
			return nil, nil, nil, fmt.Errorf("packet length (%d bytes) is smaller than the expected length (%d bytes)", len(data)-int(hdr.ParsedLen()), hdr.Length)
		}
		// The bounds check above guarantees that packetLen fits into data.
		packetLen := int(hdr.ParsedLen() + hdr.Length)
		rest = data[packetLen:]
		data = data[:packetLen]
	}
	return hdr, data, rest, nil
}
// parseHeader parses the header and records how many bytes were consumed.
// For short header packets: up to the packet number.
// For long header packets:
//   - if we understand the version: up to the packet number
//   - if not, only the invariant part of the header
//
// On error, whatever parseHeaderImpl returned is passed through unchanged and
// the parsed length is not recorded.
func parseHeader(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
	remainingBefore := b.Len()
	hdr, err := parseHeaderImpl(b, shortHeaderConnIDLen)
	if err == nil {
		hdr.parsedLen = protocol.ByteCount(remainingBefore - b.Len())
	}
	return hdr, err
}
// parseHeaderImpl reads the first byte and dispatches to the long or short
// header parser. For long headers the Header is returned even if parsing
// fails, so that callers can inspect the part that was parsed. For short
// headers a nil Header is returned on any error.
func parseHeaderImpl(b *bytes.Reader, shortHeaderConnIDLen int) (*Header, error) {
	firstByte, err := b.ReadByte()
	if err != nil {
		return nil, err
	}
	hdr := &Header{
		typeByte:     firstByte,
		IsLongHeader: IsLongHeaderPacket(firstByte),
	}
	if hdr.IsLongHeader {
		err := hdr.parseLongHeader(b)
		return hdr, err
	}
	// Short header packets must have the 0x40 bit set in the first byte.
	if firstByte&0x40 == 0 {
		return nil, errors.New("not a QUIC packet")
	}
	if err := hdr.parseShortHeader(b, shortHeaderConnIDLen); err != nil {
		return nil, err
	}
	return hdr, nil
}
// parseShortHeader reads the destination connection ID, which is
// shortHeaderConnIDLen bytes long, from b and stores it in the header.
func (h *Header) parseShortHeader(b *bytes.Reader, shortHeaderConnIDLen int) error {
	connID, err := protocol.ReadConnectionID(b, shortHeaderConnIDLen)
	h.DestConnectionID = connID
	return err
}
// parseLongHeader parses the part of a long header that follows the first byte.
// It reads the version and both connection IDs. For version negotiation packets
// (version 0) parsing stops there. For versions that are not supported,
// ErrUnsupportedVersion is returned, since the remaining bytes can't be interpreted.
// Otherwise the packet type is derived from the first byte, the token of Retry
// and Initial packets is read, and for all types except Retry the length field
// is read as well.
func (h *Header) parseLongHeader(b *bytes.Reader) error {
	rawVersion, err := utils.BigEndian.ReadUint32(b)
	if err != nil {
		return err
	}
	h.Version = protocol.VersionNumber(rawVersion)
	isVersionNegotiation := h.Version == 0
	// Except for version negotiation packets, the 0x40 bit of the first byte must be set.
	if !isVersionNegotiation && h.typeByte&0x40 == 0 {
		return errors.New("not a QUIC packet")
	}

	destLen, err := b.ReadByte()
	if err != nil {
		return err
	}
	if h.DestConnectionID, err = protocol.ReadConnectionID(b, int(destLen)); err != nil {
		return err
	}
	srcLen, err := b.ReadByte()
	if err != nil {
		return err
	}
	if h.SrcConnectionID, err = protocol.ReadConnectionID(b, int(srcLen)); err != nil {
		return err
	}

	if isVersionNegotiation {
		return nil
	}
	// If we don't understand the version, we have no idea how to interpret the rest of the bytes
	if !protocol.IsSupportedVersion(protocol.SupportedVersions, h.Version) {
		return ErrUnsupportedVersion
	}

	// Bits 4 and 5 of the first byte select the packet type.
	// Version 2 uses a different mapping than the other versions.
	packetTypes := [4]protocol.PacketType{
		0b00: protocol.PacketTypeInitial,
		0b01: protocol.PacketType0RTT,
		0b10: protocol.PacketTypeHandshake,
		0b11: protocol.PacketTypeRetry,
	}
	if h.Version == protocol.Version2 {
		packetTypes = [4]protocol.PacketType{
			0b00: protocol.PacketTypeRetry,
			0b01: protocol.PacketTypeInitial,
			0b10: protocol.PacketType0RTT,
			0b11: protocol.PacketTypeHandshake,
		}
	}
	h.Type = packetTypes[(h.typeByte>>4)&0b11]

	switch h.Type {
	case protocol.PacketTypeRetry:
		// Everything but the trailing 16 bytes is the token, which must not be empty.
		if b.Len() <= 16 {
			return io.EOF
		}
		h.Token = make([]byte, b.Len()-16)
		if _, err := io.ReadFull(b, h.Token); err != nil {
			return err
		}
		// Skip the trailing 16 bytes. Retry packets have no length field.
		_, err = b.Seek(16, io.SeekCurrent)
		return err
	case protocol.PacketTypeInitial:
		tokenLen, err := quicvarint.Read(b)
		if err != nil {
			return err
		}
		// Reject the token length before allocating if the data can't possibly hold it.
		if tokenLen > uint64(b.Len()) {
			return io.EOF
		}
		h.Token = make([]byte, tokenLen)
		if _, err := io.ReadFull(b, h.Token); err != nil {
			return err
		}
	}

	length, err := quicvarint.Read(b)
	if err != nil {
		return err
	}
	h.Length = protocol.ByteCount(length)
	return nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the header.
// The value is recorded by parseHeader, and only when parsing succeeded.
func (h *Header) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// ParseExtended parses the version dependent part of the header.
// The Reader has to be set such that it points to the first byte of the header.
// If parsing fails, a nil ExtendedHeader is returned along with the error.
// If parsing succeeds but the reserved bits are invalid, the parsed
// ExtendedHeader is returned together with ErrInvalidReservedBits.
func (h *Header) ParseExtended(b *bytes.Reader, ver protocol.VersionNumber) (*ExtendedHeader, error) {
	extended := h.toExtendedHeader()
	reservedBitsOK, err := extended.parse(b, ver)
	switch {
	case err != nil:
		return nil, err
	case !reservedBitsOK:
		return extended, ErrInvalidReservedBits
	default:
		return extended, nil
	}
}
// toExtendedHeader returns a new ExtendedHeader whose Header field is a copy of *h.
// All other fields of the ExtendedHeader are left at their zero values.
func (h *Header) toExtendedHeader() *ExtendedHeader {
return &ExtendedHeader{Header: *h}
}
// PacketType returns the type of the packet as a string, for logging purposes.
// Short header packets are reported as "1-RTT"; for long header packets the
// string representation of the parsed Type is returned.
func (h *Header) PacketType() string {
	if !h.IsLongHeader {
		return "1-RTT"
	}
	return h.Type.String()
}