// Copyright (c) HashiCorp, Inc. // SPDX-License-Identifier: BUSL-1.1 package snapshot import ( "bytes" "crypto/rand" "fmt" "io" "path/filepath" "strings" "sync" "testing" "time" "github.com/hashicorp/consul-net-rpc/go-msgpack/codec" "github.com/hashicorp/consul/agent/structs" "github.com/hashicorp/consul/sdk/testutil" "github.com/hashicorp/raft" "github.com/stretchr/testify/require" ) // MockFSM is a simple FSM for testing that simply stores its logs in a slice of // byte slices. type MockFSM struct { sync.Mutex logs [][]byte } // MockSnapshot is a snapshot sink for testing that encodes the contents of a // MockFSM using msgpack. type MockSnapshot struct { logs [][]byte maxIndex int } // See raft.FSM. func (m *MockFSM) Apply(log *raft.Log) interface{} { m.Lock() defer m.Unlock() m.logs = append(m.logs, log.Data) return len(m.logs) } // See raft.FSM. func (m *MockFSM) Snapshot() (raft.FSMSnapshot, error) { m.Lock() defer m.Unlock() return &MockSnapshot{m.logs, len(m.logs)}, nil } // See raft.FSM. func (m *MockFSM) Restore(in io.ReadCloser) error { m.Lock() defer m.Unlock() defer in.Close() dec := codec.NewDecoder(in, structs.MsgpackHandle) m.logs = nil return dec.Decode(&m.logs) } // See raft.SnapshotSink. func (m *MockSnapshot) Persist(sink raft.SnapshotSink) error { enc := codec.NewEncoder(sink, structs.MsgpackHandle) if err := enc.Encode(m.logs[:m.maxIndex]); err != nil { sink.Cancel() return err } sink.Close() return nil } // See raft.SnapshotSink. func (m *MockSnapshot) Release() { } // makeRaft returns a Raft and its FSM, with snapshots based in the given dir. func makeRaft(t *testing.T, dir string) (*raft.Raft, *MockFSM) { snaps, err := raft.NewFileSnapshotStore(dir, 5, nil) if err != nil { t.Fatalf("err: %v", err) } fsm := &MockFSM{} store := raft.NewInmemStore() addr, trans := raft.NewInmemTransport("") config := raft.DefaultConfig() config.LocalID = raft.ServerID(fmt.Sprintf("server-%s", addr)) var members raft.Configuration members.Servers = append(members.Servers, raft.Server{ Suffrage: raft.Voter, ID: config.LocalID, Address: addr, }) err = raft.BootstrapCluster(config, store, store, snaps, trans, members) if err != nil { t.Fatalf("err: %v", err) } raft, err := raft.NewRaft(config, fsm, store, store, snaps, trans) if err != nil { t.Fatalf("err: %v", err) } timeout := time.After(10 * time.Second) for { if raft.Leader() != "" { break } select { case <-raft.LeaderCh(): case <-time.After(1 * time.Second): // Need to poll because we might have missed the first // go with the leader channel. case <-timeout: t.Fatalf("timed out waiting for leader") } } return raft, fsm } func TestSnapshot(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } dir := testutil.TempDir(t, "snapshot") // Make a Raft and populate it with some data. We tee everything we // apply off to a buffer for checking post-snapshot. var expected []bytes.Buffer entries := 64 * 1024 before, _ := makeRaft(t, filepath.Join(dir, "before")) defer before.Shutdown() for i := 0; i < entries; i++ { var log bytes.Buffer var copy bytes.Buffer both := io.MultiWriter(&log, ©) if _, err := io.CopyN(both, rand.Reader, 256); err != nil { t.Fatalf("err: %v", err) } future := before.Apply(log.Bytes(), time.Second) if err := future.Error(); err != nil { t.Fatalf("err: %v", err) } expected = append(expected, copy) } // Take a snapshot. logger := testutil.Logger(t) snap, err := New(logger, before) if err != nil { t.Fatalf("err: %v", err) } defer snap.Close() // Verify the snapshot. We have to rewind it after for the restore. metadata, err := Verify(snap) if err != nil { t.Fatalf("err: %v", err) } if _, err := snap.file.Seek(0, 0); err != nil { t.Fatalf("err: %v", err) } if int(metadata.Index) != entries+2 { t.Fatalf("bad: %d", metadata.Index) } if metadata.Term != 2 { t.Fatalf("bad: %d", metadata.Index) } if metadata.Version != raft.SnapshotVersionMax { t.Fatalf("bad: %d", metadata.Version) } // Make a new, independent Raft. after, fsm := makeRaft(t, filepath.Join(dir, "after")) defer after.Shutdown() // Put some initial data in there that the snapshot should overwrite. for i := 0; i < 16; i++ { var log bytes.Buffer if _, err := io.CopyN(&log, rand.Reader, 256); err != nil { t.Fatalf("err: %v", err) } future := after.Apply(log.Bytes(), time.Second) if err := future.Error(); err != nil { t.Fatalf("err: %v", err) } } // Restore the snapshot. if err := Restore(logger, snap, after); err != nil { t.Fatalf("err: %v", err) } // Compare the contents. fsm.Lock() defer fsm.Unlock() if len(fsm.logs) != len(expected) { t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected)) } for i := range fsm.logs { if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) { t.Fatalf("bad: log %d doesn't match", i) } } } func TestSnapshot_Nil(t *testing.T) { var snap *Snapshot if idx := snap.Index(); idx != 0 { t.Fatalf("bad: %d", idx) } n, err := snap.Read(make([]byte, 16)) if n != 0 || err != io.EOF { t.Fatalf("bad: %d %v", n, err) } if err := snap.Close(); err != nil { t.Fatalf("err: %v", err) } } func TestSnapshot_BadVerify(t *testing.T) { buf := bytes.NewBuffer([]byte("nope")) _, err := Verify(buf) if err == nil || !strings.Contains(err.Error(), "unexpected EOF") { t.Fatalf("err: %v", err) } } func TestSnapshot_TruncatedVerify(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } dir := testutil.TempDir(t, "snapshot") // Make a Raft and populate it with some data. We tee everything we // apply off to a buffer for checking post-snapshot. entries := 64 * 1024 before, _ := makeRaft(t, filepath.Join(dir, "before")) defer before.Shutdown() for i := 0; i < entries; i++ { var log bytes.Buffer var copy bytes.Buffer both := io.MultiWriter(&log, ©) _, err := io.CopyN(both, rand.Reader, 256) require.NoError(t, err) future := before.Apply(log.Bytes(), time.Second) require.NoError(t, future.Error()) } // Take a snapshot. logger := testutil.Logger(t) snap, err := New(logger, before) require.NoError(t, err) defer snap.Close() var data []byte { var buf bytes.Buffer _, err = io.Copy(&buf, snap) require.NoError(t, err) data = buf.Bytes() } for _, removeBytes := range []int{200, 16, 8, 4, 2, 1} { t.Run(fmt.Sprintf("truncate %d bytes from end", removeBytes), func(t *testing.T) { // Lop off part of the end. buf := bytes.NewReader(data[0 : len(data)-removeBytes]) _, err = Verify(buf) require.Error(t, err) }) } } func TestSnapshot_BadRestore(t *testing.T) { if testing.Short() { t.Skip("too slow for testing.Short") } dir := testutil.TempDir(t, "snapshot") // Make a Raft and populate it with some data. before, _ := makeRaft(t, filepath.Join(dir, "before")) defer before.Shutdown() for i := 0; i < 16*1024; i++ { var log bytes.Buffer if _, err := io.CopyN(&log, rand.Reader, 256); err != nil { t.Fatalf("err: %v", err) } future := before.Apply(log.Bytes(), time.Second) if err := future.Error(); err != nil { t.Fatalf("err: %v", err) } } // Take a snapshot. logger := testutil.Logger(t) snap, err := New(logger, before) if err != nil { t.Fatalf("err: %v", err) } // Make a new, independent Raft. after, fsm := makeRaft(t, filepath.Join(dir, "after")) defer after.Shutdown() // Put some initial data in there that should not be harmed by the // failed restore attempt. var expected []bytes.Buffer for i := 0; i < 16; i++ { var log bytes.Buffer var copy bytes.Buffer both := io.MultiWriter(&log, ©) if _, err := io.CopyN(both, rand.Reader, 256); err != nil { t.Fatalf("err: %v", err) } future := after.Apply(log.Bytes(), time.Second) if err := future.Error(); err != nil { t.Fatalf("err: %v", err) } expected = append(expected, copy) } // Attempt to restore a truncated version of the snapshot. This is // expected to fail. err = Restore(logger, io.LimitReader(snap, 512), after) if err == nil || !strings.Contains(err.Error(), "unexpected EOF") { t.Fatalf("err: %v", err) } // Compare the contents to make sure the aborted restore didn't harm // anything. fsm.Lock() defer fsm.Unlock() if len(fsm.logs) != len(expected) { t.Fatalf("bad: %d vs. %d", len(fsm.logs), len(expected)) } for i := range fsm.logs { if !bytes.Equal(fsm.logs[i], expected[i].Bytes()) { t.Fatalf("bad: log %d doesn't match", i) } } }