diff --git a/core/network/context.go b/core/network/context.go index 01f3177d..7fabfb53 100644 --- a/core/network/context.go +++ b/core/network/context.go @@ -14,12 +14,13 @@ type noDialCtxKey struct{} type dialPeerTimeoutCtxKey struct{} type forceDirectDialCtxKey struct{} type useTransientCtxKey struct{} -type simConnectCtxKey struct{} +type simConnectCtxKey struct{ isClient bool } var noDial = noDialCtxKey{} var forceDirectDial = forceDirectDialCtxKey{} var useTransient = useTransientCtxKey{} -var simConnect = simConnectCtxKey{} +var simConnectIsServer = simConnectCtxKey{} +var simConnectIsClient = simConnectCtxKey{isClient: true} // EXPERIMENTAL // WithForceDirectDial constructs a new context with an option that instructs the network @@ -39,22 +40,26 @@ func GetForceDirectDial(ctx context.Context) (forceDirect bool, reason string) { return false, "" } -// EXPERIMENTAL // WithSimultaneousConnect constructs a new context with an option that instructs the transport // to apply hole punching logic where applicable. -func WithSimultaneousConnect(ctx context.Context, reason string) context.Context { - return context.WithValue(ctx, simConnect, reason) +// EXPERIMENTAL +func WithSimultaneousConnect(ctx context.Context, isClient bool, reason string) context.Context { + if isClient { + return context.WithValue(ctx, simConnectIsClient, reason) + } + return context.WithValue(ctx, simConnectIsServer, reason) } -// EXPERIMENTAL // GetSimultaneousConnect returns true if the simultaneous connect option is set in the context. -func GetSimultaneousConnect(ctx context.Context) (simconnect bool, reason string) { - v := ctx.Value(simConnect) - if v != nil { - return true, v.(string) +// EXPERIMENTAL +func GetSimultaneousConnect(ctx context.Context) (simconnect bool, isClient bool, reason string) { + if v := ctx.Value(simConnectIsClient); v != nil { + return true, true, v.(string) } - - return false, "" + if v := ctx.Value(simConnectIsServer); v != nil { + return true, false, v.(string) + } + return false, false, "" } // WithNoDial constructs a new context with an option that instructs the network diff --git a/core/network/context_test.go b/core/network/context_test.go index 09125516..b12def5e 100644 --- a/core/network/context_test.go +++ b/core/network/context_test.go @@ -4,6 +4,8 @@ import ( "context" "testing" "time" + + "github.com/stretchr/testify/require" ) func TestDefaultTimeout(t *testing.T) { @@ -38,3 +40,20 @@ func TestSettingTimeout(t *testing.T) { t.Fatal("peer timeout doesn't match set timeout") } } + +func TestSimultaneousConnect(t *testing.T) { + t.Run("for the server", func(t *testing.T) { + serverCtx := WithSimultaneousConnect(context.Background(), false, "foobar") + ok, isClient, reason := GetSimultaneousConnect(serverCtx) + require.True(t, ok) + require.False(t, isClient) + require.Equal(t, reason, "foobar") + }) + t.Run("for the client", func(t *testing.T) { + serverCtx := WithSimultaneousConnect(context.Background(), true, "foo") + ok, isClient, reason := GetSimultaneousConnect(serverCtx) + require.True(t, ok) + require.True(t, isClient) + require.Equal(t, reason, "foo") + }) +}