consul/command/snapshot/save/snapshot_save_test.go
hashicorp-copywrite[bot] 5fb9df1640
[COMPLIANCE] License changes (#18443)
* Adding explicit MPL license for sub-package

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
2023-08-11 09:12:13 -04:00

215 lines
4.7 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package save
import (
"crypto/rand"
"fmt"
"io"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"strings"
"sync/atomic"
"testing"
"github.com/mitchellh/cli"
"github.com/stretchr/testify/require"
"github.com/hashicorp/consul/agent"
"github.com/hashicorp/consul/api"
"github.com/hashicorp/consul/sdk/testutil"
"github.com/hashicorp/consul/sdk/testutil/retry"
)
func TestSnapshotSaveCommand_noTabs(t *testing.T) {
t.Parallel()
if strings.ContainsRune(New(cli.NewMockUi()).Help(), '\t') {
t.Fatal("help has tabs")
}
}
func TestSnapshotSaveCommand_Validation(t *testing.T) {
t.Parallel()
cases := map[string]struct {
args []string
output string
}{
"no file": {
[]string{},
"Missing FILE argument",
},
"extra args": {
[]string{"foo", "bar", "baz"},
"Too many arguments",
},
}
for name, tc := range cases {
ui := cli.NewMockUi()
c := New(ui)
// Ensure our buffer is always clear
if ui.ErrorWriter != nil {
ui.ErrorWriter.Reset()
}
if ui.OutputWriter != nil {
ui.OutputWriter.Reset()
}
code := c.Run(tc.args)
if code == 0 {
t.Errorf("%s: expected non-zero exit", name)
}
output := ui.ErrorWriter.String()
if !strings.Contains(output, tc.output) {
t.Errorf("%s: expected %q to contain %q", name, output, tc.output)
}
}
}
func TestSnapshotSaveCommand(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := agent.NewTestAgent(t, ``)
defer a.Shutdown()
client := a.Client()
ui := cli.NewMockUi()
c := New(ui)
dir := testutil.TempDir(t, "snapshot")
file := filepath.Join(dir, "backup.tgz")
args := []string{
"-http-addr=" + a.HTTPAddr(),
file,
}
code := c.Run(args)
if code != 0 {
t.Fatalf("bad: %d. %#v", code, ui.ErrorWriter.String())
}
fi, err := os.Stat(file)
require.NoError(t, err)
require.Equal(t, fi.Mode(), os.FileMode(0600))
f, err := os.Open(file)
if err != nil {
t.Fatalf("err: %v", err)
}
defer f.Close()
if err := client.Snapshot().Restore(nil, f); err != nil {
t.Fatalf("err: %v", err)
}
}
func TestSnapshotSaveCommand_TruncatedStream(t *testing.T) {
if testing.Short() {
t.Skip("too slow for testing.Short")
}
t.Parallel()
a := agent.NewTestAgent(t, ``)
defer a.Shutdown()
client := a.Client()
// Seed it with 64K of random data just so we have something to work with.
{
blob := make([]byte, 64*1024)
_, err := rand.Read(blob)
require.NoError(t, err)
_, err = client.KV().Put(&api.KVPair{Key: "blob", Value: blob}, nil)
require.NoError(t, err)
}
// Do a manual snapshot so we can send back roughly reasonable data.
var inputData []byte
{
rc, _, err := client.Snapshot().Save(nil)
require.NoError(t, err)
defer rc.Close()
inputData, err = io.ReadAll(rc)
require.NoError(t, err)
}
var fakeResult atomic.Value
// Run a fake webserver to pretend to be the snapshot API.
srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
if req.URL.Path != "/v1/snapshot" {
w.WriteHeader(http.StatusNotFound)
return
}
if req.Method != "GET" {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
raw := fakeResult.Load()
if raw == nil {
w.WriteHeader(http.StatusNotFound)
return
}
data := raw.([]byte)
_, _ = w.Write(data)
}))
t.Cleanup(srv.Close)
// Wait until the server is actually listening.
retry.Run(t, func(r *retry.R) {
resp, err := srv.Client().Get(srv.URL + "/not-real")
require.NoError(r, err)
require.Equal(r, http.StatusNotFound, resp.StatusCode)
})
dir := testutil.TempDir(t, "snapshot")
for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} {
t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) {
// Lop off part of the end.
data := inputData[0 : len(inputData)-removeBytes]
fakeResult.Store(data)
ui := cli.NewMockUi()
c := New(ui)
file := filepath.Join(dir, "backup.tgz")
args := []string{
"-http-addr=" + srv.Listener.Addr().String(), // point to the fake
file,
}
code := c.Run(args)
require.Equal(t, 1, code, "expected non-zero exit")
output := ui.ErrorWriter.String()
require.Contains(t, output, "Error verifying snapshot file")
require.Contains(t, output, "EOF")
// file should not have been created
_, err := os.Stat(file)
require.Error(t, err, "file is not supposed to exist")
require.True(t, os.IsNotExist(err), "file is not supposed to exist")
// also check that the unverified inputs are gone as well
_, err = os.Stat(file + ".unverified")
require.Error(t, err, "unverified file is not supposed to exist")
require.True(t, os.IsNotExist(err), "unverified file is not supposed to exist")
})
}
}