package log import ( "fmt" "io" "sync" "sync/atomic" ) // MaxWriterBuffer specifies how big the writer buffer can get before // killing the writer. var MaxWriterBuffer = 512 * 1024 // MirrorWriter implements a WriteCloser which syncs incoming bytes to multiple // [buffered] WriteClosers. They can be added with AddWriter(). type MirrorWriter struct { active uint32 // channel for incoming writers writerAdd chan *writerAdd // slices of writer/sync-channel pairs writers []*bufWriter // synchronization channel for incoming writes msgSync chan []byte } type writerSync struct { w io.WriteCloser br chan []byte } // NewMirrorWriter initializes and returns a MirrorWriter. func NewMirrorWriter() *MirrorWriter { mw := &MirrorWriter{ msgSync: make(chan []byte, 64), // sufficiently large buffer to avoid callers waiting writerAdd: make(chan *writerAdd), } go mw.logRoutine() return mw } // Write broadcasts the written bytes to all Writers. func (mw *MirrorWriter) Write(b []byte) (int, error) { mycopy := make([]byte, len(b)) copy(mycopy, b) mw.msgSync <- mycopy return len(b), nil } // Close closes the MirrorWriter func (mw *MirrorWriter) Close() error { // it is up to the caller to ensure that write is not called during or // after close is called. close(mw.msgSync) return nil } func (mw *MirrorWriter) doClose() { for _, w := range mw.writers { w.writer.Close() } } func (mw *MirrorWriter) logRoutine() { // rebind to avoid races on nilling out struct fields msgSync := mw.msgSync writerAdd := mw.writerAdd defer mw.doClose() for { select { case b, ok := <-msgSync: if !ok { return } // write to all writers dropped := mw.broadcastMessage(b) // consolidate the slice if dropped { mw.clearDeadWriters() } case wa := <-writerAdd: mw.writers = append(mw.writers, newBufWriter(wa.w)) atomic.StoreUint32(&mw.active, 1) close(wa.done) } } } // broadcastMessage sends the given message to every writer // if any writer is killed during the send, 'true' is returned func (mw *MirrorWriter) broadcastMessage(b []byte) bool { var dropped bool for i, w := range mw.writers { _, err := w.Write(b) if err != nil { mw.writers[i] = nil dropped = true } } return dropped } func (mw *MirrorWriter) clearDeadWriters() { writers := mw.writers mw.writers = nil for _, w := range writers { if w != nil { mw.writers = append(mw.writers, w) } } if len(mw.writers) == 0 { atomic.StoreUint32(&mw.active, 0) } } type writerAdd struct { w io.WriteCloser done chan struct{} } // AddWriter attaches a new WriteCloser to this MirrorWriter. // The new writer will start getting any bytes written to the mirror. func (mw *MirrorWriter) AddWriter(w io.WriteCloser) { wa := &writerAdd{ w: w, done: make(chan struct{}), } mw.writerAdd <- wa <-wa.done } // Active returns if there is at least one Writer // attached to this MirrorWriter func (mw *MirrorWriter) Active() (active bool) { return atomic.LoadUint32(&mw.active) == 1 } func newBufWriter(w io.WriteCloser) *bufWriter { bw := &bufWriter{ writer: w, incoming: make(chan []byte, 1), } go bw.loop() return bw } // writes incoming messages to a buffer and when it fills // up, writes them to the writer type bufWriter struct { writer io.WriteCloser incoming chan []byte deathLock sync.Mutex dead bool } var errDeadWriter = fmt.Errorf("writer is dead") func (bw *bufWriter) Write(b []byte) (int, error) { bw.deathLock.Lock() dead := bw.dead bw.deathLock.Unlock() if dead { if bw.incoming != nil { close(bw.incoming) bw.incoming = nil } return 0, errDeadWriter } bw.incoming <- b return len(b), nil } func (bw *bufWriter) die() { bw.deathLock.Lock() bw.dead = true bw.writer.Close() bw.deathLock.Unlock() } func (bw *bufWriter) loop() { bufsize := 0 bufBase := make([][]byte, 0, 16) // some initial memory buffered := bufBase nextCh := make(chan []byte) var nextMsg []byte go func() { for b := range nextCh { _, err := bw.writer.Write(b) if err != nil { // TODO: need a way to notify there was an error here // wouldn't want to log here as it could casue an infinite loop bw.die() return } } }() // collect and buffer messages incoming := bw.incoming for { if nextMsg == nil || nextCh == nil { // nextCh == nil implies we are 'dead' and draining the incoming channel // until the caller notices and closes it for us select { case b, ok := <-incoming: if !ok { return } nextMsg = b } } select { case b, ok := <-incoming: if !ok { return } bufsize += len(b) buffered = append(buffered, b) if bufsize > MaxWriterBuffer { // if we have too many messages buffered, kill the writer bw.die() if nextCh != nil { close(nextCh) } nextCh = nil // explicity keep going here to drain incoming } case nextCh <- nextMsg: nextMsg = nil if len(buffered) > 0 { nextMsg = buffered[0] buffered = buffered[1:] bufsize -= len(nextMsg) } if len(buffered) == 0 { // reset slice position buffered = bufBase[:0] } } } }