package sctp import ( "errors" "io" "sort" "sync/atomic" ) func sortChunksByTSN(a []*chunkPayloadData) { sort.Slice(a, func(i, j int) bool { return sna32LT(a[i].tsn, a[j].tsn) }) } func sortChunksBySSN(a []*chunkSet) { sort.Slice(a, func(i, j int) bool { return sna16LT(a[i].ssn, a[j].ssn) }) } // chunkSet is a set of chunks that share the same SSN type chunkSet struct { ssn uint16 // used only with the ordered chunks ppi PayloadProtocolIdentifier chunks []*chunkPayloadData } func newChunkSet(ssn uint16, ppi PayloadProtocolIdentifier) *chunkSet { return &chunkSet{ ssn: ssn, ppi: ppi, chunks: []*chunkPayloadData{}, } } func (set *chunkSet) push(chunk *chunkPayloadData) bool { // check if dup for _, c := range set.chunks { if c.tsn == chunk.tsn { return false } } // append and sort set.chunks = append(set.chunks, chunk) sortChunksByTSN(set.chunks) // Check if we now have a complete set complete := set.isComplete() return complete } func (set *chunkSet) isComplete() bool { // Condition for complete set // 0. Has at least one chunk. // 1. Begins with beginningFragment set to true // 2. Ends with endingFragment set to true // 3. TSN monotinically increase by 1 from beginning to end // 0. nChunks := len(set.chunks) if nChunks == 0 { return false } // 1. if !set.chunks[0].beginningFragment { return false } // 2. if !set.chunks[nChunks-1].endingFragment { return false } // 3. var lastTSN uint32 for i, c := range set.chunks { if i > 0 { // Fragments must have contiguous TSN // From RFC 4960 Section 3.3.1: // When a user message is fragmented into multiple chunks, the TSNs are // used by the receiver to reassemble the message. This means that the // TSNs for each fragment of a fragmented user message MUST be strictly // sequential. if c.tsn != lastTSN+1 { // mid or end fragment is missing return false } } lastTSN = c.tsn } return true } type reassemblyQueue struct { si uint16 nextSSN uint16 // expected SSN for next ordered chunk ordered []*chunkSet unordered []*chunkSet unorderedChunks []*chunkPayloadData nBytes uint64 } var errTryAgain = errors.New("try again") func newReassemblyQueue(si uint16) *reassemblyQueue { // From RFC 4960 Sec 6.5: // The Stream Sequence Number in all the streams MUST start from 0 when // the association is established. Also, when the Stream Sequence // Number reaches the value 65535 the next Stream Sequence Number MUST // be set to 0. return &reassemblyQueue{ si: si, nextSSN: 0, // From RFC 4960 Sec 6.5: ordered: make([]*chunkSet, 0), unordered: make([]*chunkSet, 0), } } func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { var cset *chunkSet if chunk.streamIdentifier != r.si { return false } if chunk.unordered { // First, insert into unorderedChunks array r.unorderedChunks = append(r.unorderedChunks, chunk) atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) sortChunksByTSN(r.unorderedChunks) // Scan unorderedChunks that are contiguous (in TSN) cset = r.findCompleteUnorderedChunkSet() // If found, append the complete set to the unordered array if cset != nil { r.unordered = append(r.unordered, cset) return true } return false } // This is an ordered chunk if sna16LT(chunk.streamSequenceNumber, r.nextSSN) { return false } // Check if a chunkSet with the SSN already exists for _, set := range r.ordered { if set.ssn == chunk.streamSequenceNumber { cset = set break } } // If not found, create a new chunkSet if cset == nil { cset = newChunkSet(chunk.streamSequenceNumber, chunk.payloadType) r.ordered = append(r.ordered, cset) if !chunk.unordered { sortChunksBySSN(r.ordered) } } atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) return cset.push(chunk) } func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { startIdx := -1 nChunks := 0 var lastTSN uint32 var found bool for i, c := range r.unorderedChunks { // seek beigining if c.beginningFragment { startIdx = i nChunks = 1 lastTSN = c.tsn if c.endingFragment { found = true break } continue } if startIdx < 0 { continue } // Check if contiguous in TSN if c.tsn != lastTSN+1 { startIdx = -1 continue } lastTSN = c.tsn nChunks++ if c.endingFragment { found = true break } } if !found { return nil } // Extract the range of chunks var chunks []*chunkPayloadData chunks = append(chunks, r.unorderedChunks[startIdx:startIdx+nChunks]...) r.unorderedChunks = append( r.unorderedChunks[:startIdx], r.unorderedChunks[startIdx+nChunks:]...) chunkSet := newChunkSet(0, chunks[0].payloadType) chunkSet.chunks = chunks return chunkSet } func (r *reassemblyQueue) isReadable() bool { // Check unordered first if len(r.unordered) > 0 { // The chunk sets in r.unordered should all be complete. return true } // Check ordered sets if len(r.ordered) > 0 { cset := r.ordered[0] if cset.isComplete() { if sna16LTE(cset.ssn, r.nextSSN) { return true } } } return false } func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) { var cset *chunkSet // Check unordered first switch { case len(r.unordered) > 0: cset = r.unordered[0] r.unordered = r.unordered[1:] case len(r.ordered) > 0: // Now, check ordered cset = r.ordered[0] if !cset.isComplete() { return 0, 0, errTryAgain } if sna16GT(cset.ssn, r.nextSSN) { return 0, 0, errTryAgain } r.ordered = r.ordered[1:] if cset.ssn == r.nextSSN { r.nextSSN++ } default: return 0, 0, errTryAgain } // Concat all fragments into the buffer nWritten := 0 ppi := cset.ppi var err error for _, c := range cset.chunks { toCopy := len(c.userData) r.subtractNumBytes(toCopy) if err == nil { n := copy(buf[nWritten:], c.userData) nWritten += n if n < toCopy { err = io.ErrShortBuffer } } } return nWritten, ppi, err } func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) { // Use lastSSN to locate a chunkSet then remove it if the set has // not been complete keep := []*chunkSet{} for _, set := range r.ordered { if sna16LTE(set.ssn, lastSSN) { if !set.isComplete() { // drop the set for _, c := range set.chunks { r.subtractNumBytes(len(c.userData)) } continue } } keep = append(keep, set) } r.ordered = keep // Finally, forward nextSSN if sna16LTE(r.nextSSN, lastSSN) { r.nextSSN = lastSSN + 1 } } func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) { // Remove all fragments in the unordered sets that contains chunks // equal to or older than `newCumulativeTSN`. // We know all sets in the r.unordered are complete ones. // Just remove chunks that are equal to or older than newCumulativeTSN // from the unorderedChunks lastIdx := -1 for i, c := range r.unorderedChunks { if sna32GT(c.tsn, newCumulativeTSN) { break } lastIdx = i } if lastIdx >= 0 { for _, c := range r.unorderedChunks[0 : lastIdx+1] { r.subtractNumBytes(len(c.userData)) } r.unorderedChunks = r.unorderedChunks[lastIdx+1:] } } func (r *reassemblyQueue) subtractNumBytes(nBytes int) { cur := atomic.LoadUint64(&r.nBytes) if int(cur) >= nBytes { atomic.AddUint64(&r.nBytes, -uint64(nBytes)) } else { atomic.StoreUint64(&r.nBytes, 0) } } func (r *reassemblyQueue) getNumBytes() int { return int(atomic.LoadUint64(&r.nBytes)) }