250 lines
7.4 KiB
Go
Raw Normal View History

2022-04-01 12:16:46 -04:00
package wire
import (
"bytes"
"errors"
"fmt"
"io"
"github.com/lucas-clemente/quic-go/internal/protocol"
"github.com/lucas-clemente/quic-go/internal/utils"
"github.com/lucas-clemente/quic-go/quicvarint"
)
// ErrInvalidReservedBits is returned when the reserved bits are incorrect.
// When this error is returned, parsing continues, and an ExtendedHeader is
// still returned. This is necessary because the packet must be decrypted in
// that case as well, in order to avoid a timing side-channel.
//
// NOTE(review): the parse methods in this file report invalid reserved bits
// through their boolean result rather than by returning this error directly;
// presumably a caller elsewhere maps that result to this error — confirm.
var ErrInvalidReservedBits = errors.New("invalid reserved bits")
// ExtendedHeader is the header of a QUIC packet.
// It embeds Header and adds the fields that parse reads from the
// (unencrypted) first byte and the packet number bytes.
type ExtendedHeader struct {
Header
// typeByte is the first byte of the header as read by parse.
typeByte byte
// KeyPhase is set by parseShortHeader from bit 0x4 of typeByte.
KeyPhase protocol.KeyPhaseBit
// PacketNumberLen is the number of bytes used to encode the packet number.
// When parsing it is derived from the two lowest bits of typeByte, plus one.
PacketNumberLen protocol.PacketNumberLen
// PacketNumber is the packet number as read from / written to the wire.
PacketNumber protocol.PacketNumber
// parsedLen is the number of bytes consumed by parse; see ParsedLen.
parsedLen protocol.ByteCount
}
// parse reads the extended part of the header from b.
//
// It re-reads the (now unencrypted) first byte, skips forward by the embedded
// Header's ParsedLen minus that one byte, and then hands over to the long or
// short header parser depending on h.IsLongHeader. On success it records how
// many bytes of b were consumed in total.
//
// The boolean result reports whether the reserved bits were valid. It is
// always false when a non-nil error is returned.
func (h *ExtendedHeader) parse(b *bytes.Reader, v protocol.VersionNumber) (bool /* reserved bits valid */, error) {
	remainingBefore := b.Len()

	// read the (now unencrypted) first byte
	firstByte, err := b.ReadByte()
	h.typeByte = firstByte
	if err != nil {
		return false, err
	}

	// One byte has just been consumed, hence the -1 on the relative seek.
	skip := int64(h.Header.ParsedLen()) - 1
	if _, err = b.Seek(skip, io.SeekCurrent); err != nil {
		return false, err
	}

	var bitsOK bool
	if h.IsLongHeader {
		bitsOK, err = h.parseLongHeader(b, v)
	} else {
		bitsOK, err = h.parseShortHeader(b, v)
	}
	if err != nil {
		return false, err
	}

	h.parsedLen = protocol.ByteCount(remainingBefore - b.Len())
	return bitsOK, nil
}
// parseLongHeader reads the packet number of a long header packet from b.
// The boolean result is true when the bits selected by mask 0xc in the first
// byte are all zero, i.e. the reserved bits are valid.
func (h *ExtendedHeader) parseLongHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
	const reservedBitsMask = 0xc

	err := h.readPacketNumber(b)
	if err != nil {
		return false, err
	}
	return h.typeByte&reservedBitsMask == 0, nil
}
// parseShortHeader derives the key phase from bit 0x4 of the first byte and
// reads the packet number of a short header packet from b.
// The boolean result is true when the bits selected by mask 0x18 in the first
// byte are all zero, i.e. the reserved bits are valid.
func (h *ExtendedHeader) parseShortHeader(b *bytes.Reader, _ protocol.VersionNumber) (bool /* reserved bits valid */, error) {
	const (
		keyPhaseMask     = 0x4
		reservedBitsMask = 0x18
	)

	if h.typeByte&keyPhaseMask != 0 {
		h.KeyPhase = protocol.KeyPhaseOne
	} else {
		h.KeyPhase = protocol.KeyPhaseZero
	}

	err := h.readPacketNumber(b)
	if err != nil {
		return false, err
	}
	return h.typeByte&reservedBitsMask == 0, nil
}
// readPacketNumber decodes the packet number from b.
//
// The encoded length is taken from the two lowest bits of the first byte
// (stored value + 1) and saved in h.PacketNumberLen; that many bytes are then
// read via utils.BigEndian. h.PacketNumber is only updated when the read
// succeeds.
func (h *ExtendedHeader) readPacketNumber(b *bytes.Reader) error {
	h.PacketNumberLen = protocol.PacketNumberLen(h.typeByte&0x3) + 1

	var (
		pn  protocol.PacketNumber
		err error
	)
	switch h.PacketNumberLen {
	case protocol.PacketNumberLen1:
		v, e := b.ReadByte()
		pn, err = protocol.PacketNumber(v), e
	case protocol.PacketNumberLen2:
		v, e := utils.BigEndian.ReadUint16(b)
		pn, err = protocol.PacketNumber(v), e
	case protocol.PacketNumberLen3:
		v, e := utils.BigEndian.ReadUint24(b)
		pn, err = protocol.PacketNumber(v), e
	case protocol.PacketNumberLen4:
		v, e := utils.BigEndian.ReadUint32(b)
		pn, err = protocol.PacketNumber(v), e
	default:
		return fmt.Errorf("invalid packet number length: %d", h.PacketNumberLen)
	}
	if err != nil {
		return err
	}

	h.PacketNumber = pn
	return nil
}
// Write serializes the header into b.
// It returns an error, without writing anything, if either connection ID is
// longer than protocol.MaxConnIDLen. Otherwise it delegates to the long or
// short header writer depending on h.IsLongHeader.
func (h *ExtendedHeader) Write(b *bytes.Buffer, ver protocol.VersionNumber) error {
	// The destination connection ID is checked first, then the source one.
	switch {
	case h.DestConnectionID.Len() > protocol.MaxConnIDLen:
		return fmt.Errorf("invalid connection ID length: %d bytes", h.DestConnectionID.Len())
	case h.SrcConnectionID.Len() > protocol.MaxConnIDLen:
		return fmt.Errorf("invalid connection ID length: %d bytes", h.SrcConnectionID.Len())
	}

	if !h.IsLongHeader {
		return h.writeShortHeader(b, ver)
	}
	return h.writeLongHeader(b, ver)
}
func (h *ExtendedHeader) writeLongHeader(b *bytes.Buffer, version protocol.VersionNumber) error {
2022-04-01 12:16:46 -04:00
var packetType uint8
if version == protocol.Version2 {
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0b01
case protocol.PacketType0RTT:
packetType = 0b10
case protocol.PacketTypeHandshake:
packetType = 0b11
case protocol.PacketTypeRetry:
packetType = 0b00
}
} else {
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeInitial:
packetType = 0b00
case protocol.PacketType0RTT:
packetType = 0b01
case protocol.PacketTypeHandshake:
packetType = 0b10
case protocol.PacketTypeRetry:
packetType = 0b11
}
2022-04-01 12:16:46 -04:00
}
firstByte := 0xc0 | packetType<<4
if h.Type != protocol.PacketTypeRetry {
// Retry packets don't have a packet number
firstByte |= uint8(h.PacketNumberLen - 1)
}
b.WriteByte(firstByte)
utils.BigEndian.WriteUint32(b, uint32(h.Version))
b.WriteByte(uint8(h.DestConnectionID.Len()))
b.Write(h.DestConnectionID.Bytes())
b.WriteByte(uint8(h.SrcConnectionID.Len()))
b.Write(h.SrcConnectionID.Bytes())
//nolint:exhaustive
switch h.Type {
case protocol.PacketTypeRetry:
b.Write(h.Token)
return nil
case protocol.PacketTypeInitial:
quicvarint.Write(b, uint64(len(h.Token)))
b.Write(h.Token)
}
quicvarint.WriteWithLen(b, uint64(h.Length), 2)
return h.writePacketNumber(b)
}
// writeShortHeader serializes h as a short header packet into b: the first
// byte, the destination connection ID bytes, and then the packet number.
//
// The first byte is 0x40 with PacketNumberLen-1 in the low bits and bit 0x4
// set when h.KeyPhase is protocol.KeyPhaseOne.
// The returned error is whatever writePacketNumber reports.
func (h *ExtendedHeader) writeShortHeader(b *bytes.Buffer, _ protocol.VersionNumber) error {
	var keyPhaseBit uint8
	if h.KeyPhase == protocol.KeyPhaseOne {
		keyPhaseBit = 1 << 2
	}

	b.WriteByte(0x40 | uint8(h.PacketNumberLen-1) | keyPhaseBit)
	b.Write(h.DestConnectionID.Bytes())
	return h.writePacketNumber(b)
}
// writePacketNumber appends h.PacketNumber to b, truncated to
// h.PacketNumberLen bytes and written via utils.BigEndian for lengths above
// one byte. It returns an error, writing nothing, when h.PacketNumberLen is
// not one of the four supported lengths.
func (h *ExtendedHeader) writePacketNumber(b *bytes.Buffer) error {
	pn := h.PacketNumber

	switch pnLen := h.PacketNumberLen; pnLen {
	case protocol.PacketNumberLen4:
		utils.BigEndian.WriteUint32(b, uint32(pn))
	case protocol.PacketNumberLen3:
		utils.BigEndian.WriteUint24(b, uint32(pn))
	case protocol.PacketNumberLen2:
		utils.BigEndian.WriteUint16(b, uint16(pn))
	case protocol.PacketNumberLen1:
		b.WriteByte(uint8(pn))
	default:
		return fmt.Errorf("invalid packet number length: %d", pnLen)
	}
	return nil
}
// ParsedLen returns the number of bytes that were consumed when parsing the
// header. The value is recorded by parse on success; it is the difference in
// the reader's remaining length before and after parsing.
func (h *ExtendedHeader) ParsedLen() protocol.ByteCount {
return h.parsedLen
}
// GetLength determines the length of the Header.
//
// For a short header this is the type byte, the destination connection ID and
// the packet number. For a long header it is the fixed-size fields (type
// byte, version, both connection ID length bytes and a 2-byte length field),
// both connection IDs and the packet number; Initial packets additionally
// count the varint-encoded token length and the token itself.
// The version argument is not used.
func (h *ExtendedHeader) GetLength(v protocol.VersionNumber) protocol.ByteCount {
	pnLen := protocol.ByteCount(h.PacketNumberLen)
	destLen := protocol.ByteCount(h.DestConnectionID.Len())

	if !h.IsLongHeader {
		return 1 /* type byte */ + destLen + pnLen
	}

	const fixedLen protocol.ByteCount = 1 /* type byte */ + 4 /* version */ + 1 /* dest conn ID len */ + 1 /* src conn ID len */ + 2 /* length */

	total := fixedLen + destLen + protocol.ByteCount(h.SrcConnectionID.Len()) + pnLen
	if h.Type == protocol.PacketTypeInitial {
		tokenLen := len(h.Token)
		total += quicvarint.Len(uint64(tokenLen)) + protocol.ByteCount(tokenLen)
	}
	return total
}
// Log writes a one-line debug representation of the header to logger.
//
// Short headers log the destination connection ID, packet number, packet
// number length and key phase. Long headers log type, connection IDs,
// version and, except for Retry packets, packet number, packet number length
// and length. Initial and Retry packets also include the token.
func (h *ExtendedHeader) Log(logger utils.Logger) {
	if !h.IsLongHeader {
		logger.Debugf("\tShort Header{DestConnectionID: %s, PacketNumber: %d, PacketNumberLen: %d, KeyPhase: %s}", h.DestConnectionID, h.PacketNumber, h.PacketNumberLen, h.KeyPhase)
		return
	}

	// Only Initial and Retry packets get a token fragment; for all other
	// types it stays empty and contributes nothing to the output.
	var token string
	if h.Type == protocol.PacketTypeInitial || h.Type == protocol.PacketTypeRetry {
		token = "Token: (empty), "
		if len(h.Token) > 0 {
			token = fmt.Sprintf("Token: %#x, ", h.Token)
		}
	}

	if h.Type == protocol.PacketTypeRetry {
		logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sVersion: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.Version)
		return
	}
	logger.Debugf("\tLong Header{Type: %s, DestConnectionID: %s, SrcConnectionID: %s, %sPacketNumber: %d, PacketNumberLen: %d, Length: %d, Version: %s}", h.Type, h.DestConnectionID, h.SrcConnectionID, token, h.PacketNumber, h.PacketNumberLen, h.Length, h.Version)
}