diff --git a/agent/discovery_chain_endpoint.go b/agent/discovery_chain_endpoint.go index 1df2e39f40..c6dddd64e8 100644 --- a/agent/discovery_chain_endpoint.go +++ b/agent/discovery_chain_endpoint.go @@ -95,63 +95,6 @@ type discoveryChainReadRequest struct { OverrideConnectTimeout time.Duration } -func (t *discoveryChainReadRequest) UnmarshalJSON(data []byte) (err error) { - type Alias discoveryChainReadRequest - aux := &struct { - OverrideConnectTimeout interface{} - OverrideProtocol interface{} - OverrideMeshGateway *struct{ Mode interface{} } - - OverrideConnectTimeoutSnake interface{} `json:"override_connect_timeout"` - OverrideProtocolSnake interface{} `json:"override_protocol"` - OverrideMeshGatewaySnake *struct{ Mode interface{} } `json:"override_mesh_gateway"` - - *Alias - }{ - Alias: (*Alias)(t), - } - if err = lib.UnmarshalJSON(data, &aux); err != nil { - return err - } - - if aux.OverrideConnectTimeout == nil { - aux.OverrideConnectTimeout = aux.OverrideConnectTimeoutSnake - } - if aux.OverrideProtocol == nil { - aux.OverrideProtocol = aux.OverrideProtocolSnake - } - if aux.OverrideMeshGateway == nil { - aux.OverrideMeshGateway = aux.OverrideMeshGatewaySnake - } - - // weakly typed input - if aux.OverrideProtocol != nil { - switch v := aux.OverrideProtocol.(type) { - case string, float64, bool: - t.OverrideProtocol = fmt.Sprintf("%v", v) - default: - return fmt.Errorf("OverrideProtocol: invalid type %T", v) - } - } - if aux.OverrideMeshGateway != nil { - t.OverrideMeshGateway.Mode = structs.MeshGatewayMode(fmt.Sprintf("%v", aux.OverrideMeshGateway.Mode)) - } - - // duration - if aux.OverrideConnectTimeout != nil { - switch v := aux.OverrideConnectTimeout.(type) { - case string: - if t.OverrideConnectTimeout, err = time.ParseDuration(v); err != nil { - return err - } - case float64: - t.OverrideConnectTimeout = time.Duration(v) - } - } - - return nil -} - // discoveryChainReadResponse is the API variation of structs.DiscoveryChainResponse type discoveryChainReadResponse struct { Chain *structs.CompiledDiscoveryChain diff --git a/agent/http_decode_test.go b/agent/http_decode_test.go index 2fc79334a0..9993614439 100644 --- a/agent/http_decode_test.go +++ b/agent/http_decode_test.go @@ -1984,283 +1984,6 @@ func TestDecodeCatalogRegister(t *testing.T) { } } -// discoveryChainReadRequest: -// OverrideMeshGateway structs.MeshGatewayConfig -// Mode structs.MeshGatewayMode // string -// OverrideProtocol string -// OverrideConnectTimeout time.Duration -func TestDecodeDiscoveryChainRead(t *testing.T) { - var weaklyTypedDurationTCs = []translateValueTestCase{ - { - desc: "positive string integer (weakly typed)", - durations: &durationTC{ - in: `"2000"`, - }, - wantErr: true, - }, - { - desc: "negative string integer (weakly typed)", - durations: &durationTC{ - in: `"-50"`, - }, - wantErr: true, - }, - } - - for _, tc := range append(durationTestCases, weaklyTypedDurationTCs...) { - t.Run(tc.desc, func(t *testing.T) { - // set up request body - jsonStr := fmt.Sprintf(`{ - "OverrideConnectTimeout": %s - }`, tc.durations.in) - body := bytes.NewBuffer([]byte(jsonStr)) - - var out discoveryChainReadRequest - err := decodeBody(body, &out) - if err == nil && tc.wantErr { - t.Fatal("expected err, got nil") - } - if err != nil && !tc.wantErr { - t.Fatalf("expected nil error, got %v", err) - } - if out.OverrideConnectTimeout != tc.durations.want { - t.Fatalf("expected OverrideConnectTimeout to be %s, got %s", tc.durations.want, out.OverrideConnectTimeout) - } - }) - } - - // Other possibly weakly-typed inputs.. - var weaklyTypedStringTCs = []struct { - desc string - in, want string - wantErr bool - }{ - { - desc: "positive integer for string field (weakly typed)", - in: `200`, - want: "200", - }, - { - desc: "negative integer for string field (weakly typed)", - in: `-200`, - want: "-200", - }, - { - desc: "bool for string field (weakly typed)", - in: `true`, - want: "true", // previously: "1" - }, - { - desc: "float for string field (weakly typed)", - in: `1.2223`, - want: "1.2223", - }, - { - desc: "map for string field (weakly typed)", - in: `{}`, - wantErr: true, - }, - { - desc: "slice for string field (weakly typed)", - in: `[]`, - wantErr: true, // previously: want: "" - }, - } - - for _, tc := range weaklyTypedStringTCs { - t.Run(tc.desc, func(t *testing.T) { - // set up request body - jsonStr := fmt.Sprintf(`{ - "OverrideProtocol": %[1]s, - "OverrideMeshGateway": {"Mode": %[1]s} - }`, tc.in) - body := bytes.NewBuffer([]byte(jsonStr)) - - var out discoveryChainReadRequest - err := decodeBody(body, &out) - - if err == nil && tc.wantErr { - t.Fatal("expected err, got nil") - } - if err != nil && !tc.wantErr { - t.Fatalf("expected nil error, got %v", err) - } - if out.OverrideProtocol != tc.want { - t.Fatalf("expected OverrideProtocol to be %s, got %s", tc.want, out.OverrideProtocol) - } - if out.OverrideMeshGateway.Mode != structs.MeshGatewayMode(tc.want) { - t.Fatalf("expected OverrideMeshGateway.Mode to be %s, got %s", tc.want, out.OverrideMeshGateway.Mode) - } - }) - } - - // translate field tcs - - overrideMeshGatewayFields := []string{ - `"OverrideMeshGateway": {"Mode": %s}`, - `"override_mesh_gateway": {"Mode": %s}`, - } - - overrideMeshGatewayEqFn := func(out interface{}, want interface{}) error { - got := out.(discoveryChainReadRequest).OverrideMeshGateway.Mode - if got != structs.MeshGatewayMode(want.(string)) { - return fmt.Errorf("expected OverrideMeshGateway to be %s, got %s", want, got) - } - return nil - } - - var translateOverrideMeshGatewayTCs = []translateKeyTestCase{ - { - desc: "OverrideMeshGateway: both set", - in: []interface{}{`"one"`, `"two"`}, - want: "one", - jsonFmtStr: `{` + strings.Join(overrideMeshGatewayFields, ",") + `}`, - equalityFn: overrideMeshGatewayEqFn, - }, - { - desc: "OverrideMeshGateway: first set", - in: []interface{}{`"one"`}, - want: "one", - jsonFmtStr: `{` + overrideMeshGatewayFields[0] + `}`, - equalityFn: overrideMeshGatewayEqFn, - }, - { - desc: "OverrideMeshGateway: second set", - in: []interface{}{`"two"`}, - want: "two", - jsonFmtStr: `{` + overrideMeshGatewayFields[1] + `}`, - equalityFn: overrideMeshGatewayEqFn, - }, - { - desc: "OverrideMeshGateway: neither set", - in: []interface{}{}, - want: "", // zero value - jsonFmtStr: `{}`, - equalityFn: overrideMeshGatewayEqFn, - }, - } - - overrideProtocolFields := []string{ - `"OverrideProtocol": %s`, - `"override_protocol": %s`, - } - - overrideProtocolEqFn := func(out interface{}, want interface{}) error { - got := out.(discoveryChainReadRequest).OverrideProtocol - if got != want { - return fmt.Errorf("expected OverrideProtocol to be %s, got %s", want, got) - } - return nil - } - - var translateOverrideProtocolTCs = []translateKeyTestCase{ - { - desc: "OverrideProtocol: both set", - in: []interface{}{`"one"`, `"two"`}, - want: "one", - jsonFmtStr: `{` + strings.Join(overrideProtocolFields, ",") + `}`, - equalityFn: overrideProtocolEqFn, - }, - { - desc: "OverrideProtocol: first set", - in: []interface{}{`"one"`}, - want: "one", - jsonFmtStr: `{` + overrideProtocolFields[0] + `}`, - equalityFn: overrideProtocolEqFn, - }, - { - desc: "OverrideProtocol: second set", - in: []interface{}{`"two"`}, - want: "two", - jsonFmtStr: `{` + overrideProtocolFields[1] + `}`, - equalityFn: overrideProtocolEqFn, - }, - { - desc: "OverrideProtocol: neither set", - in: []interface{}{}, - want: "", // zero value - jsonFmtStr: `{}`, - equalityFn: overrideProtocolEqFn, - }, - } - - overrideConnectTimeoutFields := []string{ - `"OverrideConnectTimeout": %s`, - `"override_connect_timeout": %s`, - } - - overrideConnectTimeoutEqFn := func(out interface{}, want interface{}) error { - got := out.(discoveryChainReadRequest).OverrideConnectTimeout - if got != want { - return fmt.Errorf("expected OverrideConnectTimeout to be %s, got %s", want, got) - } - return nil - } - - var translateOverrideConnectTimeoutTCs = []translateKeyTestCase{ - { - desc: "OverrideConnectTimeout: both set", - in: []interface{}{`"2h0m"`, `"3h0m"`}, - want: 2 * time.Hour, - jsonFmtStr: "{" + strings.Join(overrideConnectTimeoutFields, ",") + "}", - equalityFn: overrideConnectTimeoutEqFn, - }, - { - desc: "OverrideConnectTimeout: first set", - in: []interface{}{`"2h0m"`}, - want: 2 * time.Hour, - jsonFmtStr: "{" + overrideConnectTimeoutFields[0] + "}", - equalityFn: overrideConnectTimeoutEqFn, - }, - { - desc: "OverrideConnectTimeout: second set", - in: []interface{}{`"3h0m"`}, - want: 3 * time.Hour, - jsonFmtStr: "{" + overrideConnectTimeoutFields[1] + "}", - equalityFn: overrideConnectTimeoutEqFn, - }, - { - desc: "OverrideConnectTimeout: neither set", - in: []interface{}{}, - want: time.Duration(0), - jsonFmtStr: "{}", - equalityFn: overrideConnectTimeoutEqFn, - }, - } - - // lib.TranslateKeys(raw, map[string]string{ - // "override_mesh_gateway": "overridemeshgateway", - // "override_protocol": "overrideprotocol", - // "override_connect_timeout": "overrideconnecttimeout", - // }) - - translateFieldTCs := [][]translateKeyTestCase{ - translateOverrideMeshGatewayTCs, - translateOverrideProtocolTCs, - translateOverrideConnectTimeoutTCs, - } - - for _, tcGroup := range translateFieldTCs { - for _, tc := range tcGroup { - t.Run(tc.desc, func(t *testing.T) { - jsonStr := fmt.Sprintf(tc.jsonFmtStr, tc.in...) - body := bytes.NewBuffer([]byte(jsonStr)) - - var out discoveryChainReadRequest - err := decodeBody(body, &out) - if err != nil { - t.Fatal(err) - } - - if err := tc.equalityFn(out, tc.want); err != nil { - t.Fatal(err) - } - }) - } - } - -} - // IntentionRequest: // Datacenter string // Op structs.IntentionOp