status-go/vendor/github.com/pion/sctp/reassembly_queue.go
2024-06-05 16:10:03 -04:00

367 lines
8.1 KiB
Go

// SPDX-FileCopyrightText: 2023 The Pion community <https://pion.ly>
// SPDX-License-Identifier: MIT
package sctp
import (
"errors"
"io"
"sort"
"sync/atomic"
)
// sortChunksByTSN sorts a in place by ascending TSN, comparing with serial
// number arithmetic (sna32LT) so that TSN wrap-around is respected.
func sortChunksByTSN(a []*chunkPayloadData) {
	less := func(i, j int) bool {
		return sna32LT(a[i].tsn, a[j].tsn)
	}
	sort.Slice(a, less)
}
// sortChunksBySSN sorts a in place by ascending SSN, comparing with serial
// number arithmetic (sna16LT) so that SSN wrap-around is respected.
func sortChunksBySSN(a []*chunkSet) {
	less := func(i, j int) bool {
		return sna16LT(a[i].ssn, a[j].ssn)
	}
	sort.Slice(a, less)
}
// chunkSet is a set of chunks that share the same SSN.
// It holds the fragment(s) of a single message; chunks are kept sorted by
// TSN so isComplete can verify the beginning/ending fragments and TSN
// contiguity.
type chunkSet struct {
	ssn    uint16 // used only with the ordered chunks (unordered sets use 0)
	ppi    PayloadProtocolIdentifier // taken from the chunk that created the set
	chunks []*chunkPayloadData // member chunks, sorted by TSN
}
// newChunkSet returns an empty chunkSet tagged with the given SSN and
// payload protocol identifier. The chunk slice is non-nil but empty.
func newChunkSet(ssn uint16, ppi PayloadProtocolIdentifier) *chunkSet {
	set := new(chunkSet)
	set.ssn = ssn
	set.ppi = ppi
	set.chunks = make([]*chunkPayloadData, 0)
	return set
}
// push adds chunk to the set, keeps the set sorted by TSN, and reports
// whether the set now forms a complete message. A chunk whose TSN is already
// present is ignored and push returns false.
func (set *chunkSet) push(chunk *chunkPayloadData) bool {
	for i := range set.chunks {
		if set.chunks[i].tsn == chunk.tsn {
			return false // duplicate TSN: nothing to add
		}
	}

	set.chunks = append(set.chunks, chunk)
	sortChunksByTSN(set.chunks)

	return set.isComplete()
}
// isComplete reports whether the set holds an entire message:
//   - it contains at least one chunk,
//   - the first chunk has beginningFragment set,
//   - the last chunk has endingFragment set, and
//   - TSNs increase monotonically by exactly 1 from first to last.
//
// The chunks are expected to already be sorted by TSN.
func (set *chunkSet) isComplete() bool {
	n := len(set.chunks)
	if n == 0 {
		return false
	}

	first, last := set.chunks[0], set.chunks[n-1]
	if !first.beginningFragment || !last.endingFragment {
		return false
	}

	// RFC 4960 Section 3.3.1: the TSNs of the fragments of one user message
	// MUST be strictly sequential, so any gap means a fragment is missing.
	for i := 1; i < n; i++ {
		if set.chunks[i].tsn != set.chunks[i-1].tsn+1 {
			return false
		}
	}

	return true
}
// reassemblyQueue buffers DATA chunks for a single stream (si) and
// reassembles them into complete messages that read hands out.
type reassemblyQueue struct {
	si              uint16              // stream identifier; push ignores chunks for any other stream
	nextSSN         uint16              // expected SSN for next ordered chunk
	ordered         []*chunkSet         // ordered sets, kept sorted by SSN; may still be incomplete
	unordered       []*chunkSet         // complete unordered sets waiting to be read
	unorderedChunks []*chunkPayloadData // unordered chunks not yet part of a complete set, sorted by TSN
	// nBytes is the total user-data bytes currently buffered; it is accessed
	// with sync/atomic. NOTE(review): 64-bit atomic operations require 8-byte
	// alignment on 32-bit platforms; re-check if fields are reordered.
	nBytes uint64
}
// errTryAgain is returned by read when no complete, deliverable message is
// available yet.
var (
	errTryAgain = errors.New("try again")
)
// newReassemblyQueue creates an empty reassembly queue for stream si.
//
// RFC 4960 Section 6.5: the Stream Sequence Number in all streams MUST start
// from 0 when the association is established and wraps from 65535 back to
// 0, so nextSSN is left at its zero value.
func newReassemblyQueue(si uint16) *reassemblyQueue {
	return &reassemblyQueue{
		si:        si,
		ordered:   []*chunkSet{},
		unordered: []*chunkSet{},
	}
}
// push routes chunk into the queue and reports whether it completed a
// message (unordered: a new complete set was formed; ordered: the chunk's
// set is now complete).
//
// Chunks for another stream are ignored. Unordered chunks are staged in
// r.unorderedChunks until a contiguous beginning..ending run forms a
// complete set. Ordered chunks with an SSN older than nextSSN are dropped;
// the rest are grouped by SSN in r.ordered, which is kept sorted by SSN.
//
// For ordered chunks the buffered-byte counter is increased only when the
// chunk is actually stored. A duplicate TSN rejected by chunkSet.push is not
// counted, because it would otherwise never be subtracted again and nBytes
// would drift upward.
func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool {
	if chunk.streamIdentifier != r.si {
		return false
	}

	if chunk.unordered {
		// First, insert into unorderedChunks array
		r.unorderedChunks = append(r.unorderedChunks, chunk)
		atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData)))
		sortChunksByTSN(r.unorderedChunks)

		// Scan unorderedChunks for a run that is contiguous (in TSN) and complete.
		cset := r.findCompleteUnorderedChunkSet()
		if cset == nil {
			return false
		}

		r.unordered = append(r.unordered, cset)
		return true
	}

	// This is an ordered chunk; anything older than the next expected SSN
	// is not accepted.
	if sna16LT(chunk.streamSequenceNumber, r.nextSSN) {
		return false
	}

	// Check if a fragmented chunkSet with the fragmented SSN already exists
	var cset *chunkSet
	if chunk.isFragmented() {
		for _, set := range r.ordered {
			// nolint:godox
			// TODO: add caution around SSN wrapping here... this helps only a little bit
			// by ensuring we don't add to an unfragmented cset (1 chunk). There's
			// a case where if the SSN does wrap around, we may see the same SSN
			// for a different chunk.

			// nolint:godox
			// TODO: this slice can get pretty big; it may be worth maintaining a map
			// for O(1) lookups at the cost of 2x memory.
			if set.ssn == chunk.streamSequenceNumber && set.chunks[0].isFragmented() {
				cset = set
				break
			}
		}
	}

	// If not found, create a new chunkSet and keep r.ordered sorted by SSN.
	if cset == nil {
		cset = newChunkSet(chunk.streamSequenceNumber, chunk.payloadType)
		r.ordered = append(r.ordered, cset)
		sortChunksBySSN(r.ordered)
	}

	before := len(cset.chunks)
	complete := cset.push(chunk)
	if len(cset.chunks) > before {
		// The chunk was stored (not a duplicate TSN), so account for its bytes.
		atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData)))
	}

	return complete
}
// findCompleteUnorderedChunkSet scans r.unorderedChunks (sorted by TSN) for
// the first run that:
//   - starts with a beginning fragment,
//   - ends with an ending fragment, and
//   - has consecutive TSNs in between.
//
// If such a run exists, it is removed from r.unorderedChunks and returned
// as a new chunkSet with SSN 0 and the PPI of its first chunk. Otherwise
// nil is returned and r.unorderedChunks is left unchanged.
func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet {
	begin, end := -1, -1 // [begin, end) delimits the complete run once found
	var prevTSN uint32

	for i, c := range r.unorderedChunks {
		switch {
		case c.beginningFragment:
			// A beginning fragment always (re)starts a candidate run.
			begin = i
			prevTSN = c.tsn
		case begin < 0:
			// Not inside a candidate run; wait for the next beginning.
			continue
		case c.tsn != prevTSN+1:
			// TSN gap: a middle fragment is missing, so abandon the run.
			begin = -1
			continue
		default:
			prevTSN = c.tsn
		}

		if c.endingFragment {
			end = i + 1
			break
		}
	}

	if end < 0 {
		return nil
	}

	// Copy the run out before compacting the backing array in place.
	run := make([]*chunkPayloadData, end-begin)
	copy(run, r.unorderedChunks[begin:end])
	r.unorderedChunks = append(r.unorderedChunks[:begin], r.unorderedChunks[end:]...)

	set := newChunkSet(0, run[0].payloadType)
	set.chunks = run
	return set
}
// isReadable reports whether read would currently return a message. That is
// the case when a complete unordered set is queued, or when the first
// ordered set is complete and its SSN is not ahead of nextSSN.
func (r *reassemblyQueue) isReadable() bool {
	// Every set in r.unordered is complete by construction.
	if len(r.unordered) != 0 {
		return true
	}

	if len(r.ordered) == 0 {
		return false
	}

	head := r.ordered[0]
	return head.isComplete() && sna16LTE(head.ssn, r.nextSSN)
}
// read dequeues the next deliverable message and copies its fragments, in
// TSN order, into buf. It returns the number of bytes written and the
// message's PPI.
//
// Unordered messages are delivered before ordered ones. An ordered message
// is delivered only when it is complete and its SSN is not ahead of
// nextSSN; nextSSN advances when the delivered SSN equals it. If nothing is
// deliverable, errTryAgain is returned.
//
// If buf is too small, the message is still consumed: copying stops at the
// first fragment that does not fit, io.ErrShortBuffer is returned, and the
// bytes of every fragment are still released from the buffered-byte count.
func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) {
	var set *chunkSet

	switch {
	case len(r.unordered) > 0:
		set = r.unordered[0]
		r.unordered = r.unordered[1:]
	case len(r.ordered) > 0:
		head := r.ordered[0]
		if !head.isComplete() || sna16GT(head.ssn, r.nextSSN) {
			return 0, 0, errTryAgain
		}
		r.ordered = r.ordered[1:]
		if head.ssn == r.nextSSN {
			r.nextSSN++
		}
		set = head
	default:
		return 0, 0, errTryAgain
	}

	written := 0
	var err error
	for _, c := range set.chunks {
		size := len(c.userData)
		r.subtractNumBytes(size)
		if err != nil {
			continue // buffer already full: keep releasing byte accounting only
		}
		n := copy(buf[written:], c.userData)
		written += n
		if n < size {
			err = io.ErrShortBuffer
		}
	}

	return written, set.ppi, err
}
// forwardTSNForOrdered skips ordered messages up to and including lastSSN.
// Sets with SSN <= lastSSN that are still incomplete are dropped and their
// bytes released. Complete sets are kept so they can still be read. Finally,
// nextSSN is moved to lastSSN+1 unless it is already beyond lastSSN.
func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) {
	remaining := make([]*chunkSet, 0, len(r.ordered))
	for _, set := range r.ordered {
		stale := sna16LTE(set.ssn, lastSSN) && !set.isComplete()
		if !stale {
			remaining = append(remaining, set)
			continue
		}
		for _, c := range set.chunks {
			r.subtractNumBytes(len(c.userData))
		}
	}
	r.ordered = remaining

	if sna16LTE(r.nextSSN, lastSSN) {
		r.nextSSN = lastSSN + 1
	}
}
// forwardTSNForUnordered discards pending unordered fragments whose TSN is
// less than or equal to newCumulativeTSN and releases their bytes. Complete
// sets already moved to r.unordered are left untouched. Because
// r.unorderedChunks is sorted by TSN, the scan stops at the first newer chunk.
func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) {
	cut := 0 // number of leading chunks to drop
	for cut < len(r.unorderedChunks) && !sna32GT(r.unorderedChunks[cut].tsn, newCumulativeTSN) {
		cut++
	}
	if cut == 0 {
		return
	}

	for _, c := range r.unorderedChunks[:cut] {
		r.subtractNumBytes(len(c.userData))
	}
	r.unorderedChunks = r.unorderedChunks[cut:]
}
// subtractNumBytes decreases the buffered-byte counter by nBytes, clamping
// at zero instead of wrapping around.
//
// The clamp is performed with a compare-and-swap loop, so the read-modify-
// write is atomic as a whole, consistent with the other atomic accesses to
// r.nBytes. The comparison is done in uint64, so a large counter cannot be
// misread as negative through an int conversion on 32-bit platforms.
// Non-positive nBytes is a no-op.
func (r *reassemblyQueue) subtractNumBytes(nBytes int) {
	if nBytes <= 0 {
		return
	}
	delta := uint64(nBytes)

	for {
		cur := atomic.LoadUint64(&r.nBytes)
		var next uint64
		if cur > delta {
			next = cur - delta
		}
		if atomic.CompareAndSwapUint64(&r.nBytes, cur, next) {
			return
		}
	}
}
// getNumBytes returns the number of user-data bytes currently buffered in
// the queue. The counter is read atomically.
func (r *reassemblyQueue) getNumBytes() int {
	n := atomic.LoadUint64(&r.nBytes)
	return int(n)
}