package agent

import (
	"bytes"
	"encoding/base64"
	"fmt"
	"net/http"
	"net/http/httptest"
	"strings"
	"testing"
	"time"

	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/api"
	"github.com/hashicorp/consul/testrpc"
	"github.com/hashicorp/raft"
	"github.com/stretchr/testify/assert"
)

// TestTxnEndpoint_Bad_JSON verifies that a transaction request whose body is
// not valid JSON is answered with a 400 and a parse-failure message.
func TestTxnEndpoint_Bad_JSON(t *testing.T) {
	t.Parallel()
	a := NewTestAgent(t, "")
	defer a.Shutdown()

	// A lone opening brace can never be decoded as a transaction.
	buf := bytes.NewBuffer([]byte("{"))
	req, err := http.NewRequest("PUT", "/v1/txn", buf)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	resp := httptest.NewRecorder()
	if _, err := a.srv.Txn(resp, req); err != nil {
		t.Fatalf("err: %v", err)
	}
	if resp.Code != 400 {
		t.Fatalf("expected 400, got %d", resp.Code)
	}
	if !bytes.Contains(resp.Body.Bytes(), []byte("Failed to parse")) {
		// Include the body so a mismatch can be diagnosed from the test log.
		t.Fatalf("expected parse failure in response body, got %q", resp.Body.String())
	}
}

// TestTxnEndpoint_Bad_Size_Item sends a transaction with a single KV "set"
// whose value is 3*raft.SuggestedMaxDataSize bytes (before base64 encoding)
// and checks, for several agent limit configurations, whether the request is
// rejected with a 413 or accepted with a 200.
func TestTxnEndpoint_Bad_Size_Item(t *testing.T) {
	t.Parallel()

	// testIt issues the oversized request against agent and asserts on the
	// status code: 200 when wantPass is true, 413 otherwise.
	testIt := func(t *testing.T, agent *TestAgent, wantPass bool) {
		t.Helper()

		value := strings.Repeat("X", 3*raft.SuggestedMaxDataSize)
		value = base64.StdEncoding.EncodeToString([]byte(value))
		buf := bytes.NewBuffer([]byte(fmt.Sprintf(`
 [
     {
         "KV": {
             "Verb": "set",
             "Key": "key",
             "Value": %q
         }
     }
 ]
 `, value)))
		req, err := http.NewRequest("PUT", "/v1/txn", buf)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		resp := httptest.NewRecorder()
		if _, err := agent.srv.Txn(resp, req); err != nil {
			t.Fatalf("err: %v", err)
		}
		if resp.Code != 413 && !wantPass {
			t.Fatalf("expected 413, got %d", resp.Code)
		}
		if resp.Code != 200 && wantPass {
			t.Fatalf("expected 200, got %d", resp.Code)
		}
	}

	// Shutdown is deferred in every subtest so the agent is released even
	// when testIt aborts the subtest through t.Fatalf.
	t.Run("exceeds default limits", func(t *testing.T) {
		a := NewTestAgent(t, "")
		defer a.Shutdown()
		testIt(t, a, false)
	})

	t.Run("exceeds configured max txn len", func(t *testing.T) {
		a := NewTestAgent(t, "limits = { txn_max_req_len = 700000 }")
		defer a.Shutdown()
		testIt(t, a, false)
	})

	t.Run("exceeds default max kv value size", func(t *testing.T) {
		a := NewTestAgent(t, "limits = { txn_max_req_len = 123456789 }")
		defer a.Shutdown()
		testIt(t, a, false)
	})

	t.Run("allowed", func(t *testing.T) {
		a := NewTestAgent(t, `
limits = {
	txn_max_req_len = 123456789
	kv_max_value_size = 123456789
}`)
		defer a.Shutdown()
		testIt(t, a, true)
	})
}

// TestTxnEndpoint_Bad_Size_Net sends a transaction made of three KV "set"
// operations, each carrying a value of 3*raft.SuggestedMaxDataSize bytes
// (before base64 encoding), and checks, for several agent limit
// configurations, whether the request is rejected with a 413 or accepted
// with a 200.
func TestTxnEndpoint_Bad_Size_Net(t *testing.T) {
	t.Parallel()

	// testIt issues the oversized request against agent and asserts on the
	// status code: 200 when wantPass is true, 413 otherwise. It takes the
	// subtest's own *testing.T: failures must be reported on the test that is
	// actually running, because calling Fatalf on the parent test from inside
	// a subtest is not supported by the testing package.
	testIt := func(t *testing.T, agent *TestAgent, wantPass bool) {
		t.Helper()

		value := strings.Repeat("X", 3*raft.SuggestedMaxDataSize)
		value = base64.StdEncoding.EncodeToString([]byte(value))
		buf := bytes.NewBuffer([]byte(fmt.Sprintf(`
 [
     {
         "KV": {
             "Verb": "set",
             "Key": "key1",
             "Value": %q
         }
     },
     {
         "KV": {
             "Verb": "set",
             "Key": "key1",
             "Value": %q
         }
     },
     {
         "KV": {
             "Verb": "set",
             "Key": "key1",
             "Value": %q
         }
     }
 ]
 `, value, value, value)))
		req, err := http.NewRequest("PUT", "/v1/txn", buf)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		resp := httptest.NewRecorder()
		if _, err := agent.srv.Txn(resp, req); err != nil {
			t.Fatalf("err: %v", err)
		}
		if resp.Code != 413 && !wantPass {
			t.Fatalf("expected 413, got %d", resp.Code)
		}
		if resp.Code != 200 && wantPass {
			t.Fatalf("expected 200, got %d", resp.Code)
		}
	}

	// Shutdown is deferred in every subtest so the agent is released even
	// when testIt aborts the subtest through t.Fatalf.
	t.Run("exceeds default limits", func(t *testing.T) {
		a := NewTestAgent(t, "")
		defer a.Shutdown()
		testIt(t, a, false)
	})

	t.Run("exceeds configured max txn len", func(t *testing.T) {
		a := NewTestAgent(t, "limits = { txn_max_req_len = 700000 }")
		defer a.Shutdown()
		testIt(t, a, false)
	})

	t.Run("exceeds default max kv value size", func(t *testing.T) {
		a := NewTestAgent(t, "limits = { txn_max_req_len = 123456789 }")
		defer a.Shutdown()
		testIt(t, a, false)
	})

	t.Run("allowed", func(t *testing.T) {
		a := NewTestAgent(t, `
limits = {
	txn_max_req_len = 123456789
	kv_max_value_size = 123456789
}`)
		defer a.Shutdown()
		testIt(t, a, true)
	})

	// Only kv_max_value_size is raised here; the test asserts that the
	// request is still accepted without an explicit txn_max_req_len.
	t.Run("allowed kv max backward compatible", func(t *testing.T) {
		a := NewTestAgent(t, "limits = { kv_max_value_size = 123456789 }")
		defer a.Shutdown()
		testIt(t, a, true)
	})
}

// TestTxnEndpoint_Bad_Size_Ops checks that a transaction listing
// 2*maxTxnOps "get" operations plus one trailing "set" is answered with a 413.
func TestTxnEndpoint_Bad_Size_Ops(t *testing.T) {
	t.Parallel()
	a := NewTestAgent(t, "")
	defer a.Shutdown()

	// Build the run of repeated "get" operations first; each one ends with a
	// comma so the final "set" operation closes the JSON array cleanly.
	getOps := strings.Repeat(`{ "KV": { "Verb": "get", "Key": "key" } },`, 2*maxTxnOps)
	payload := fmt.Sprintf(`
 [
     %s
     {
         "KV": {
             "Verb": "set",
             "Key": "key",
             "Value": ""
         }
     }
 ]
 `, getOps)

	req, _ := http.NewRequest("PUT", "/v1/txn", bytes.NewBufferString(payload))
	rec := httptest.NewRecorder()
	_, err := a.srv.Txn(rec, req)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if got := rec.Code; got != 413 {
		t.Fatalf("expected 413, got %d", got)
	}
}

// TestTxnEndpoint_KV_Actions exercises KV operations through the transaction
// endpoint. The first subtest runs a lock+get, a read-only get/get-tree and a
// cas+get against the same key and compares each response with the expected
// entries. The second subtest checks that a lock with an unknown session is
// reported as a 409.
func TestTxnEndpoint_KV_Actions(t *testing.T) {
	t.Parallel()
	t.Run("", func(t *testing.T) {
		a := NewTestAgent(t, "")
		defer a.Shutdown()
		testrpc.WaitForTestAgent(t, a.RPC, "dc1")

		// Make sure all incoming fields get converted properly to the internal
		// RPC format.
		var index uint64
		id := makeTestSession(t, a.srv)
		{
			buf := bytes.NewBuffer([]byte(fmt.Sprintf(`
 [
     {
         "KV": {
             "Verb": "lock",
             "Key": "key",
             "Value": "aGVsbG8gd29ybGQ=",
             "Flags": 23,
             "Session": %q
         }
     },
     {
         "KV": {
             "Verb": "get",
             "Key": "key"
         }
     }
 ]
 `, id)))
			req, err := http.NewRequest("PUT", "/v1/txn", buf)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			resp := httptest.NewRecorder()
			obj, err := a.srv.Txn(resp, req)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if resp.Code != 200 {
				t.Fatalf("expected 200, got %d", resp.Code)
			}

			txnResp, ok := obj.(structs.TxnResponse)
			if !ok {
				t.Fatalf("bad type: %T", obj)
			}
			if len(txnResp.Results) != 2 {
				t.Fatalf("bad: %v", txnResp)
			}

			// The index and enterprise metadata are taken from the response
			// itself; they are not predictable from the request.
			index = txnResp.Results[0].KV.ModifyIndex
			entMeta := txnResp.Results[0].KV.EnterpriseMeta

			// The lock result is expected without a value, while the get
			// result carries the decoded payload.
			expected := structs.TxnResponse{
				Results: structs.TxnResults{
					&structs.TxnResult{
						KV: &structs.DirEntry{
							Key:       "key",
							Value:     nil,
							Flags:     23,
							Session:   id,
							LockIndex: 1,
							RaftIndex: structs.RaftIndex{
								CreateIndex: index,
								ModifyIndex: index,
							},
							EnterpriseMeta: entMeta,
						},
					},
					&structs.TxnResult{
						KV: &structs.DirEntry{
							Key:       "key",
							Value:     []byte("hello world"),
							Flags:     23,
							Session:   id,
							LockIndex: 1,
							RaftIndex: structs.RaftIndex{
								CreateIndex: index,
								ModifyIndex: index,
							},
							EnterpriseMeta: entMeta,
						},
					},
				},
			}
			assert.Equal(t, expected, txnResp)
		}

		// Do a read-only transaction that should get routed to the
		// fast-path endpoint.
		{
			buf := bytes.NewBuffer([]byte(`
 [
     {
         "KV": {
             "Verb": "get",
             "Key": "key"
         }
     },
     {
         "KV": {
             "Verb": "get-tree",
             "Key": "key"
         }
     }
 ]
 `))
			req, err := http.NewRequest("PUT", "/v1/txn", buf)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			resp := httptest.NewRecorder()
			obj, err := a.srv.Txn(resp, req)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if resp.Code != 200 {
				t.Fatalf("expected 200, got %d", resp.Code)
			}

			header := resp.Header().Get("X-Consul-KnownLeader")
			if header != "true" {
				t.Fatalf("bad: %v", header)
			}
			header = resp.Header().Get("X-Consul-LastContact")
			if header != "0" {
				t.Fatalf("bad: %v", header)
			}

			txnResp, ok := obj.(structs.TxnReadResponse)
			if !ok {
				t.Fatalf("bad type: %T", obj)
			}
			// Guard the indexing below so a short result fails the test
			// cleanly instead of panicking.
			if len(txnResp.Results) != 2 {
				t.Fatalf("bad: %v", txnResp)
			}
			entMeta := txnResp.Results[0].KV.EnterpriseMeta
			expected := structs.TxnReadResponse{
				TxnResponse: structs.TxnResponse{
					Results: structs.TxnResults{
						&structs.TxnResult{
							KV: &structs.DirEntry{
								Key:       "key",
								Value:     []byte("hello world"),
								Flags:     23,
								Session:   id,
								LockIndex: 1,
								RaftIndex: structs.RaftIndex{
									CreateIndex: index,
									ModifyIndex: index,
								},
								EnterpriseMeta: entMeta,
							},
						},
						&structs.TxnResult{
							KV: &structs.DirEntry{
								Key:       "key",
								Value:     []byte("hello world"),
								Flags:     23,
								Session:   id,
								LockIndex: 1,
								RaftIndex: structs.RaftIndex{
									CreateIndex: index,
									ModifyIndex: index,
								},
								EnterpriseMeta: entMeta,
							},
						},
					},
				},
				QueryMeta: structs.QueryMeta{
					KnownLeader: true,
				},
			}
			assert.Equal(t, expected, txnResp)
		}

		// Now that we have an index we can do a CAS to make sure the
		// index field gets translated to the RPC format.
		{
			buf := bytes.NewBuffer([]byte(fmt.Sprintf(`
 [
     {
         "KV": {
             "Verb": "cas",
             "Key": "key",
             "Value": "Z29vZGJ5ZSB3b3JsZA==",
             "Index": %d
         }
     },
     {
         "KV": {
             "Verb": "get",
             "Key": "key"
         }
     }
 ]
 `, index)))
			req, err := http.NewRequest("PUT", "/v1/txn", buf)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			resp := httptest.NewRecorder()
			obj, err := a.srv.Txn(resp, req)
			if err != nil {
				t.Fatalf("err: %v", err)
			}
			if resp.Code != 200 {
				t.Fatalf("expected 200, got %d", resp.Code)
			}

			txnResp, ok := obj.(structs.TxnResponse)
			if !ok {
				t.Fatalf("bad type: %T", obj)
			}
			if len(txnResp.Results) != 2 {
				t.Fatalf("bad: %v", txnResp)
			}

			// The create index must still match the original write, while the
			// modify index is whatever the CAS produced.
			modIndex := txnResp.Results[0].KV.ModifyIndex
			entMeta := txnResp.Results[0].KV.EnterpriseMeta

			expected := structs.TxnResponse{
				Results: structs.TxnResults{
					&structs.TxnResult{
						KV: &structs.DirEntry{
							Key:     "key",
							Value:   nil,
							Session: id,
							RaftIndex: structs.RaftIndex{
								CreateIndex: index,
								ModifyIndex: modIndex,
							},
							EnterpriseMeta: entMeta,
						},
					},
					&structs.TxnResult{
						KV: &structs.DirEntry{
							Key:     "key",
							Value:   []byte("goodbye world"),
							Session: id,
							RaftIndex: structs.RaftIndex{
								CreateIndex: index,
								ModifyIndex: modIndex,
							},
							EnterpriseMeta: entMeta,
						},
					},
				},
			}
			assert.Equal(t, expected, txnResp)
		}
	})

	// Verify an error inside a transaction.
	t.Run("", func(t *testing.T) {
		a := NewTestAgent(t, "")
		defer a.Shutdown()

		// "nope" is not a session created by this test, so the lock is
		// expected to fail the whole transaction.
		buf := bytes.NewBuffer([]byte(`
 [
     {
         "KV": {
             "Verb": "lock",
             "Key": "key",
             "Value": "aGVsbG8gd29ybGQ=",
             "Session": "nope"
         }
     },
     {
         "KV": {
             "Verb": "get",
             "Key": "key"
         }
     }
 ]
 `))
		req, err := http.NewRequest("PUT", "/v1/txn", buf)
		if err != nil {
			t.Fatalf("err: %v", err)
		}
		resp := httptest.NewRecorder()
		if _, err := a.srv.Txn(resp, req); err != nil {
			t.Fatalf("err: %v", err)
		}
		if resp.Code != 409 {
			t.Fatalf("expected 409, got %d", resp.Code)
		}
		if !bytes.Contains(resp.Body.Bytes(), []byte("failed session lookup")) {
			t.Fatalf("bad: %s", resp.Body.String())
		}
	})
}

// TestTxnEndpoint_UpdateCheck submits three "set" operations for the same
// node check in one transaction and compares the three returned checks with
// the expected definitions. The first two operations use the Interval /
// Timeout / DeregisterCriticalServiceAfter field names and the third uses the
// *Duration variants, so both spellings of the duration fields are covered.
func TestTxnEndpoint_UpdateCheck(t *testing.T) {
	t.Parallel()
	a := NewTestAgent(t, "")
	defer a.Shutdown()
	testrpc.WaitForTestAgent(t, a.RPC, "dc1")

	node := a.config.NodeName

	// Make sure the fields of a check are handled correctly when both creating and
	// updating, and test both sets of duration fields to ensure backwards compatibility.
	payload := fmt.Sprintf(`
[
	{
		"Check": {
			"Verb": "set",
			"Check": {
				"Node": "%s",
				"CheckID": "nodecheck",
				"Name": "Node http check",
				"Status": "critical",
				"Notes": "Http based health check",
				"Output": "",
				"ServiceID": "",
				"ServiceName": "",
				"Definition": {
					"Interval": "6s",
					"Timeout": "6s",
					"DeregisterCriticalServiceAfter": "6s",
					"HTTP": "http://localhost:8000",
					"TLSSkipVerify": true
				}
			}
		}
	},
	{
		"Check": {
			"Verb": "set",
			"Check": {
				"Node": "%s",
				"CheckID": "nodecheck",
				"Name": "Node http check",
				"Status": "passing",
				"Notes": "Http based health check",
				"Output": "success",
				"ServiceID": "",
				"ServiceName": "",
				"Definition": {
					"Interval": "10s",
					"Timeout": "10s",
					"DeregisterCriticalServiceAfter": "15m",
					"HTTP": "http://localhost:9000",
					"TLSSkipVerify": false
				}
			}
		}
	},
	{
		"Check": {
			"Verb": "set",
			"Check": {
				"Node": "%s",
				"CheckID": "nodecheck",
				"Name": "Node http check",
				"Status": "passing",
				"Notes": "Http based health check",
				"Output": "success",
				"ServiceID": "",
				"ServiceName": "",
				"Definition": {
					"IntervalDuration": "15s",
					"TimeoutDuration": "15s",
					"DeregisterCriticalServiceAfterDuration": "30m",
					"HTTP": "http://localhost:9000",
					"TLSSkipVerify": false
				}
			}
		}
	}
]
`, node, node, node)

	req, _ := http.NewRequest("PUT", "/v1/txn", bytes.NewBufferString(payload))
	rec := httptest.NewRecorder()
	raw, err := a.srv.Txn(rec, req)
	if err != nil {
		t.Fatalf("err: %v", err)
	}
	if rec.Code != 200 {
		t.Fatalf("expected 200, got %d", rec.Code)
	}

	got, isTxn := raw.(structs.TxnResponse)
	if !isTxn {
		t.Fatalf("bad type: %T", raw)
	}
	if len(got.Results) != 3 {
		t.Fatalf("bad: %v", got)
	}

	// All three results are expected to share the index reported by the
	// first one.
	idx := got.Results[0].Check.ModifyIndex

	// wantResult builds the expected result for one operation; only the
	// status, output and definition differ between the three operations.
	wantResult := func(status, output string, def structs.HealthCheckDefinition) *structs.TxnResult {
		return &structs.TxnResult{
			Check: &structs.HealthCheck{
				Node:       node,
				CheckID:    "nodecheck",
				Name:       "Node http check",
				Status:     status,
				Notes:      "Http based health check",
				Output:     output,
				Definition: def,
				RaftIndex: structs.RaftIndex{
					CreateIndex: idx,
					ModifyIndex: idx,
				},
				EnterpriseMeta: *structs.DefaultEnterpriseMeta(),
			},
		}
	}

	want := structs.TxnResponse{
		Results: structs.TxnResults{
			wantResult(api.HealthCritical, "", structs.HealthCheckDefinition{
				Interval:                       6 * time.Second,
				Timeout:                        6 * time.Second,
				DeregisterCriticalServiceAfter: 6 * time.Second,
				HTTP:                           "http://localhost:8000",
				TLSSkipVerify:                  true,
			}),
			wantResult(api.HealthPassing, "success", structs.HealthCheckDefinition{
				Interval:                       10 * time.Second,
				Timeout:                        10 * time.Second,
				DeregisterCriticalServiceAfter: 15 * time.Minute,
				HTTP:                           "http://localhost:9000",
				TLSSkipVerify:                  false,
			}),
			wantResult(api.HealthPassing, "success", structs.HealthCheckDefinition{
				Interval:                       15 * time.Second,
				Timeout:                        15 * time.Second,
				DeregisterCriticalServiceAfter: 30 * time.Minute,
				HTTP:                           "http://localhost:9000",
				TLSSkipVerify:                  false,
			}),
		},
	}
	assert.Equal(t, want, got)
}