diff --git a/connect/proxy/conn.go b/connect/proxy/conn.go index 70019e55cb..fe52853f0e 100644 --- a/connect/proxy/conn.go +++ b/connect/proxy/conn.go @@ -8,8 +8,9 @@ import ( // Conn represents a single proxied TCP connection. type Conn struct { - src, dst net.Conn - stopping int32 + src, dst net.Conn + srcW, dstW countWriter + stopping int32 } // NewConn returns a conn joining the two given net.Conn @@ -17,6 +18,8 @@ func NewConn(src, dst net.Conn) *Conn { return &Conn{ src: src, dst: dst, + srcW: countWriter{w: src}, + dstW: countWriter{w: dst}, stopping: 0, } } @@ -47,10 +50,10 @@ func (c *Conn) CopyBytes() error { // causing this goroutine to exit but not the outer one. See // TestConnSrcClosing which will fail if you comment the defer below. defer c.Close() - io.Copy(c.dst, c.src) + io.Copy(&c.dstW, c.src) }() - _, err := io.Copy(c.src, c.dst) + _, err := io.Copy(&c.srcW, c.dst) // Note that we don't wait for the other goroutine to finish because it either // already has due to it's src conn closing, or it will once our defer fires // and closes the source conn. No need for the extra coordination. @@ -59,3 +62,37 @@ func (c *Conn) CopyBytes() error { } return err } + +// Stats returns number of bytes transmitted and recieved. Transmit means bytes +// written to dst, receive means bytes written to src. +func (c *Conn) Stats() (txBytes, rxBytes uint64) { + return c.srcW.Written(), c.dstW.Written() +} + +// countWriter is an io.Writer that counts the number of bytes being written +// before passing them through. We use it to gather metrics for bytes +// sent/received. Note that since we are always copying between a net.TCPConn +// and a tls.Conn, none of the optimisations using syscalls like splice and +// ReaderTo/WriterFrom can be used anyway and io.Copy falls back to a generic +// buffered read/write loop. +// +// We use atomic updates to synchronize reads and writes here. It's the cheapest +// uncontended option based on +// https://gist.github.com/banks/e76b40c0cc4b01503f0a0e4e0af231d5. Further +// optimization can be made when if/when identified as a real overhead. +type countWriter struct { + written uint64 + w io.Writer +} + +// Write implements io.Writer +func (cw *countWriter) Write(p []byte) (n int, err error) { + n, err = cw.w.Write(p) + atomic.AddUint64(&cw.written, uint64(n)) + return +} + +// Written returns how many bytes have been written to w. +func (cw *countWriter) Written() uint64 { + return atomic.LoadUint64(&cw.written) +} diff --git a/connect/proxy/conn_test.go b/connect/proxy/conn_test.go index a37720ea0a..4de428ad08 100644 --- a/connect/proxy/conn_test.go +++ b/connect/proxy/conn_test.go @@ -7,6 +7,7 @@ import ( "testing" "time" + "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -88,6 +89,10 @@ func TestConn(t *testing.T) { require.Nil(t, err) require.Equal(t, "ping 2\n", got) + tx, rx := c.Stats() + assert.Equal(t, uint64(7), tx) + assert.Equal(t, uint64(7), rx) + _, err = src.Write([]byte("pong 1\n")) require.Nil(t, err) _, err = dst.Write([]byte("pong 2\n")) @@ -101,6 +106,10 @@ func TestConn(t *testing.T) { require.Nil(t, err) require.Equal(t, "pong 2\n", got) + tx, rx = c.Stats() + assert.Equal(t, uint64(14), tx) + assert.Equal(t, uint64(14), rx) + c.Close() ret := <-retCh