diff --git a/agent/proxycfg/internal/watch/watchmap.go b/agent/proxycfg/internal/watch/watchmap.go new file mode 100644 index 0000000000..bbf42dc9af --- /dev/null +++ b/agent/proxycfg/internal/watch/watchmap.go @@ -0,0 +1,108 @@ +package watch + +import "context" + +// Map safely stores and retrieves values by validating that +// there is a live watch for a key. InitWatch must be called +// to associate a key with its cancel function before any +// Set's are called. +type Map[K comparable, V any] struct { + M map[K]watchedVal[V] +} + +type watchedVal[V any] struct { + Val *V + + // keeping cancel private has a beneficial side effect: + // copying Map with copystructure.Copy will zero out + // cancel, preventing it from being called by the + // receiver of a proxy config snapshot. + cancel context.CancelFunc +} + +func NewMap[K comparable, V any]() Map[K, V] { + return Map[K, V]{M: make(map[K]watchedVal[V])} +} + +// InitWatch associates a cancel function with a key, +// allowing Set to be called for the key. The cancel +// function is allowed to be nil. +// +// Any existing data for a key will be cancelled and +// overwritten. +func (m Map[K, V]) InitWatch(key K, cancel func()) { + if _, present := m.M[key]; present { + m.CancelWatch(key) + } + m.M[key] = watchedVal[V]{ + cancel: cancel, + } +} + +// CancelWatch first calls the cancel function +// associated with the key then deletes the key +// from the map. No-op if key is not present. +func (m Map[K, V]) CancelWatch(key K) { + if entry, ok := m.M[key]; ok { + if entry.cancel != nil { + entry.cancel() + } + delete(m.M, key) + } +} + +// IsWatched returns true if InitWatch has been +// called for key and has not been cancelled by +// CancelWatch. +func (m Map[K, V]) IsWatched(key K) bool { + if _, present := m.M[key]; present { + return true + } + return false +} + +// Set stores V if K exists in the map. +// No-op if the key never was initialized with InitWatch +// or if the entry got cancelled by CancelWatch. +func (m Map[K, V]) Set(key K, val V) bool { + if entry, ok := m.M[key]; ok { + entry.Val = &val + m.M[key] = entry + return true + } + return false +} + +// Get returns the underlying value for a key. +// If an entry has been set, returns (V, true). +// Otherwise, returns the zero value (V, false). +// +// Note that even if InitWatch has been called +// for a key, unless Set has been called this +// function will return false. +func (m Map[K, V]) Get(key K) (V, bool) { + if entry, ok := m.M[key]; ok { + if entry.Val != nil { + return *entry.Val, true + } + } + var empty V + return empty, false +} + +func (m Map[K, V]) Len() int { + return len(m.M) +} + +// ForEachKey iterates through the map, calling f +// for each iteration. It is up to the caller to +// Get the value and nil-check if required. +// Stops iterating if f returns false. +// Order of iteration is non-deterministic. +func (m Map[K, V]) ForEachKey(f func(K) bool) { + for k := range m.M { + if ok := f(k); !ok { + return + } + } +} diff --git a/agent/proxycfg/internal/watch/watchmap_test.go b/agent/proxycfg/internal/watch/watchmap_test.go new file mode 100644 index 0000000000..590351853e --- /dev/null +++ b/agent/proxycfg/internal/watch/watchmap_test.go @@ -0,0 +1,113 @@ +package watch + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestMap(t *testing.T) { + m := NewMap[string, string]() + + // Set without init is a no-op + { + m.Set("hello", "world") + require.Equal(t, 0, m.Len()) + } + + // Getting from empty map + { + got, ok := m.Get("hello") + require.False(t, ok) + require.Empty(t, got) + } + + var called bool + cancelMock := func() { + called = true + } + + // InitWatch successful + { + m.InitWatch("hello", cancelMock) + require.Equal(t, 1, m.Len()) + } + + // Get still returns false + { + got, ok := m.Get("hello") + require.False(t, ok) + require.Empty(t, got) + } + + // Set successful + { + require.True(t, m.Set("hello", "world")) + require.Equal(t, 1, m.Len()) + } + + // Get successful + { + got, ok := m.Get("hello") + require.True(t, ok) + require.Equal(t, "world", got) + } + + // CancelWatch successful + { + m.CancelWatch("hello") + require.Equal(t, 0, m.Len()) + require.True(t, called) + } + + // Get no-op + { + got, ok := m.Get("hello") + require.False(t, ok) + require.Empty(t, got) + } + + // Set no-op + { + require.False(t, m.Set("hello", "world")) + require.Equal(t, 0, m.Len()) + } +} + +func TestMap_ForEach(t *testing.T) { + type testType struct { + s string + } + + m := NewMap[string, any]() + inputs := map[string]any{ + "hello": 13, + "foo": struct{}{}, + "bar": &testType{s: "wow"}, + } + for k, v := range inputs { + m.InitWatch(k, nil) + m.Set(k, v) + } + require.Equal(t, 3, m.Len()) + + // returning true continues iteration + { + var count int + m.ForEachKey(func(k string) bool { + count++ + return true + }) + require.Equal(t, 3, count) + } + + // returning false exits loop + { + var count int + m.ForEachKey(func(k string) bool { + count++ + return false + }) + require.Equal(t, 1, count) + } +}