72 lines
1.7 KiB
Go
Raw Normal View History

2017-12-03 01:24:05 +01:00
package sshd
import (
"io"
"net"
"time"
"github.com/shazow/rateio"
)
// limitedConn is a net.Conn whose Read is served by a separate io.Reader
// rather than by the connection itself.
type limitedConn struct {
	// Conn is the wrapped connection; it supplies every net.Conn method
	// other than Read.
	net.Conn
	// Reader is our rate-limited io.Reader for the connection.
	io.Reader
}

// Read fills p from the embedded Reader. Both embedded fields declare a Read
// method, so this explicit method is what selects the Reader's version.
func (c *limitedConn) Read(p []byte) (n int, err error) {
	n, err = c.Reader.Read(p)
	return n, err
}
// ReadLimitConn wraps conn so that its Read calls go through a rateio reader
// governed by limiter. All other net.Conn methods are served by conn itself.
func ReadLimitConn(conn net.Conn, limiter rateio.Limiter) net.Conn {
	throttled := rateio.NewReader(conn, limiter)

	wrapped := new(limitedConn)
	wrapped.Conn = conn
	wrapped.Reader = throttled
	return wrapped
}
// inputLimiter counts reads inside a time window: each read counts as 1
// unless it exceeds readCap bytes, in which case it counts as its byte size.
type inputLimiter struct {
// TODO: Could do all kinds of fancy things here, like be more forgiving of
// connections that have been around for a while.

// Amount is the per-window budget; Count reports rateio.ErrRateExceeded once
// numRead grows past it.
Amount int
// Frequency is the length of each window; when a window expires the next
// one ends Frequency after the read that noticed the expiry.
Frequency time.Duration
// NOTE(review): remaining is not referenced anywhere in the code visible
// here; confirm whether it is still needed.
remaining int
// readCap is the largest read size (in bytes) that is still counted as 1.
readCap int
// numRead is the running total counted in the current window.
numRead int
// timeRead is the instant the current window ends; a Count call after this
// time resets numRead to 0.
timeRead time.Time
}
// NewInputLimiter returns a rateio.Limiter with sensible defaults for
// differentiating between humans typing and bots spamming.
//
// The defaults are a budget of 2<<14 (32768) counted units per one-minute
// window, with reads of up to 128 bytes counted as a single unit. The first
// window lasts only 3 seconds and starts with the counter at -1 MiB, which
// acts as a startup grace allowance.
func NewInputLimiter() rateio.Limiter {
	const (
		// Intended to be plenty for a high typing rate, copypasta and large
		// key handshakes.
		// NOTE(review): the original comment called this "~16kb", but
		// 2 << 14 evaluates to 32768.
		budget = 2 << 14
		// Length of every window after the initial grace window.
		window = time.Minute * 1
		// Reads up to this many bytes count as 1 (anecdotally, 1 character
		// = 52 bytes over ssh).
		perReadCap = 128
		// Negative starting count: a 1mb grace.
		initialCount = -1024 * 1024
		// Duration of the initial grace window.
		grace = time.Second * 3
	)

	l := new(inputLimiter)
	l.Amount = budget
	l.Frequency = window
	l.readCap = perReadCap
	l.numRead = initialCount
	l.timeRead = time.Now().Add(grace)
	return l
}
// Count charges a read of n bytes against the current window. A read of at
// most readCap bytes costs 1; a larger read costs n. It returns
// rateio.ErrRateExceeded when the window's running total is above Amount,
// and nil otherwise.
func (l *inputLimiter) Count(n int) error {
	// The current window has ended: clear the total and open a new window
	// that lasts Frequency from now.
	if now := time.Now(); l.timeRead.Before(now) {
		l.numRead = 0
		l.timeRead = now.Add(l.Frequency)
	}

	cost := n
	if n <= l.readCap {
		cost = 1
	}
	l.numRead += cost

	if l.numRead <= l.Amount {
		return nil
	}
	return rateio.ErrRateExceeded
}