Add new watch.Map type to refactor proxycfg

This commit is contained in:
Chris S. Kim 2022-07-13 12:12:31 -04:00 committed by Chris S. Kim
parent b4ffa9ae0c
commit 7f32cba735
2 changed files with 221 additions and 0 deletions

View File

@ -0,0 +1,108 @@
package watch
import "context"
// Map safely stores and retrieves values by validating that
// there is a live watch for a key. InitWatch must be called
// to associate a key with its cancel function before any
// Set's are called.
type Map[K comparable, V any] struct {
M map[K]watchedVal[V]
}
type watchedVal[V any] struct {
Val *V
// keeping cancel private has a beneficial side effect:
// copying Map with copystructure.Copy will zero out
// cancel, preventing it from being called by the
// receiver of a proxy config snapshot.
cancel context.CancelFunc
}
func NewMap[K comparable, V any]() Map[K, V] {
return Map[K, V]{M: make(map[K]watchedVal[V])}
}
// InitWatch associates a cancel function with a key,
// allowing Set to be called for the key. The cancel
// function is allowed to be nil.
//
// Any existing data for a key will be cancelled and
// overwritten.
func (m Map[K, V]) InitWatch(key K, cancel func()) {
if _, present := m.M[key]; present {
m.CancelWatch(key)
}
m.M[key] = watchedVal[V]{
cancel: cancel,
}
}
// CancelWatch first calls the cancel function
// associated with the key then deletes the key
// from the map. No-op if key is not present.
func (m Map[K, V]) CancelWatch(key K) {
if entry, ok := m.M[key]; ok {
if entry.cancel != nil {
entry.cancel()
}
delete(m.M, key)
}
}
// IsWatched returns true if InitWatch has been
// called for key and has not been cancelled by
// CancelWatch.
func (m Map[K, V]) IsWatched(key K) bool {
if _, present := m.M[key]; present {
return true
}
return false
}
// Set stores V if K exists in the map.
// No-op if the key never was initialized with InitWatch
// or if the entry got cancelled by CancelWatch.
func (m Map[K, V]) Set(key K, val V) bool {
if entry, ok := m.M[key]; ok {
entry.Val = &val
m.M[key] = entry
return true
}
return false
}
// Get returns the underlying value for a key.
// If an entry has been set, returns (V, true).
// Otherwise, returns the zero value (V, false).
//
// Note that even if InitWatch has been called
// for a key, unless Set has been called this
// function will return false.
func (m Map[K, V]) Get(key K) (V, bool) {
if entry, ok := m.M[key]; ok {
if entry.Val != nil {
return *entry.Val, true
}
}
var empty V
return empty, false
}
func (m Map[K, V]) Len() int {
return len(m.M)
}
// ForEachKey iterates through the map, calling f
// for each iteration. It is up to the caller to
// Get the value and nil-check if required.
// Stops iterating if f returns false.
// Order of iteration is non-deterministic.
func (m Map[K, V]) ForEachKey(f func(K) bool) {
for k := range m.M {
if ok := f(k); !ok {
return
}
}
}

View File

@ -0,0 +1,113 @@
package watch
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMap(t *testing.T) {
m := NewMap[string, string]()
// Set without init is a no-op
{
m.Set("hello", "world")
require.Equal(t, 0, m.Len())
}
// Getting from empty map
{
got, ok := m.Get("hello")
require.False(t, ok)
require.Empty(t, got)
}
var called bool
cancelMock := func() {
called = true
}
// InitWatch successful
{
m.InitWatch("hello", cancelMock)
require.Equal(t, 1, m.Len())
}
// Get still returns false
{
got, ok := m.Get("hello")
require.False(t, ok)
require.Empty(t, got)
}
// Set successful
{
require.True(t, m.Set("hello", "world"))
require.Equal(t, 1, m.Len())
}
// Get successful
{
got, ok := m.Get("hello")
require.True(t, ok)
require.Equal(t, "world", got)
}
// CancelWatch successful
{
m.CancelWatch("hello")
require.Equal(t, 0, m.Len())
require.True(t, called)
}
// Get no-op
{
got, ok := m.Get("hello")
require.False(t, ok)
require.Empty(t, got)
}
// Set no-op
{
require.False(t, m.Set("hello", "world"))
require.Equal(t, 0, m.Len())
}
}
func TestMap_ForEach(t *testing.T) {
type testType struct {
s string
}
m := NewMap[string, any]()
inputs := map[string]any{
"hello": 13,
"foo": struct{}{},
"bar": &testType{s: "wow"},
}
for k, v := range inputs {
m.InitWatch(k, nil)
m.Set(k, v)
}
require.Equal(t, 3, m.Len())
// returning true continues iteration
{
var count int
m.ForEachKey(func(k string) bool {
count++
return true
})
require.Equal(t, 3, count)
}
// returning false exits loop
{
var count int
m.ForEachKey(func(k string) bool {
count++
return false
})
require.Equal(t, 1, count)
}
}