consul/internal/storage/raft/forwarding.go

269 lines
6.5 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0
package raft
import (
"context"
"errors"
"net"
"sync"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/emptypb"
"github.com/hashicorp/go-hclog"
grpcinternal "github.com/hashicorp/consul/agent/grpc-internal"
"github.com/hashicorp/consul/internal/storage"
pbstorage "github.com/hashicorp/consul/proto/private/pbstorage"
)
// forwardingServer is the server half of the forwarding mechanism. It
// implements pbstorage.ForwardingServiceServer on top of a Backend.
type forwardingServer struct {
	// backend answers reads and lists (leaderRead / leaderList) and exposes
	// the handle through which writes and deletes are applied.
	backend *Backend

	// listener is the listener the gRPC server is served on in run.
	listener *grpcinternal.Listener
}

// Compile-time check that forwardingServer satisfies the generated interface.
var _ pbstorage.ForwardingServiceServer = (*forwardingServer)(nil)
// newForwardingServer returns a forwardingServer backed by the given Backend
// and bound to a newly created grpcinternal.Listener.
func newForwardingServer(backend *Backend) *forwardingServer {
	// The address here doesn't actually matter. gRPC uses it as an identifier
	// internally, but we only bind the server to a single listener.
	addr := &net.TCPAddr{
		IP:   net.ParseIP("0.0.0.0"),
		Port: 0,
	}

	srv := new(forwardingServer)
	srv.backend = backend
	srv.listener = grpcinternal.NewListener(addr)
	return srv
}
// Write handles a forwarded write request by submitting it as a
// LOG_TYPE_WRITE entry via raftApply and returning the write portion of the
// resulting log response. Errors from raftApply are returned as-is.
func (s *forwardingServer) Write(ctx context.Context, req *pbstorage.WriteRequest) (*pbstorage.WriteResponse, error) {
	entry := &pbstorage.Log{
		Type:    pbstorage.LogType_LOG_TYPE_WRITE,
		Request: &pbstorage.Log_Write{Write: req},
	}

	applied, err := s.raftApply(ctx, entry)
	if err != nil {
		return nil, err
	}
	return applied.GetWrite(), nil
}
// Delete handles a forwarded delete request by submitting it as a
// LOG_TYPE_DELETE entry via raftApply. On success it returns an empty
// message; errors from raftApply are returned as-is.
func (s *forwardingServer) Delete(ctx context.Context, req *pbstorage.DeleteRequest) (*emptypb.Empty, error) {
	entry := &pbstorage.Log{
		Type:    pbstorage.LogType_LOG_TYPE_DELETE,
		Request: &pbstorage.Log_Delete{Delete: req},
	}

	// The log response itself is not needed for a delete, only the error.
	if _, err := s.raftApply(ctx, entry); err != nil {
		return nil, err
	}
	return new(emptypb.Empty), nil
}
// Read handles a forwarded read request by calling the backend's leaderRead
// with the requested ID. Backend errors are converted to gRPC status errors
// with wrapError before being returned.
func (s *forwardingServer) Read(ctx context.Context, req *pbstorage.ReadRequest) (*pbstorage.ReadResponse, error) {
	resource, err := s.backend.leaderRead(ctx, req.Id)
	if err != nil {
		return nil, wrapError(err)
	}

	rsp := &pbstorage.ReadResponse{Resource: resource}
	return rsp, nil
}
// List handles a forwarded list request by calling the backend's leaderList
// with the unversioned form of the requested type, the tenancy, and the name
// prefix. Backend errors are converted to gRPC status errors with wrapError.
func (s *forwardingServer) List(ctx context.Context, req *pbstorage.ListRequest) (*pbstorage.ListResponse, error) {
	resourceType := storage.UnversionedTypeFrom(req.Type)

	resources, err := s.backend.leaderList(ctx, resourceType, req.Tenancy, req.NamePrefix)
	if err != nil {
		return nil, wrapError(err)
	}

	rsp := &pbstorage.ListResponse{Resources: resources}
	return rsp, nil
}
// raftApply marshals the given log entry, applies it through the backend's
// handle, and returns the result as a *pbstorage.LogResponse.
//
// Marshalling and apply errors are converted with wrapError. A result of any
// other type is reported as a codes.Internal status error. The context
// argument is currently unused.
func (s *forwardingServer) raftApply(_ context.Context, req *pbstorage.Log) (*pbstorage.LogResponse, error) {
	payload, err := req.MarshalBinary()
	if err != nil {
		return nil, wrapError(err)
	}

	result, err := s.backend.handle.Apply(payload)
	if err != nil {
		return nil, wrapError(err)
	}

	logRsp, ok := result.(*pbstorage.LogResponse)
	if !ok {
		return nil, status.Errorf(codes.Internal, "unexpected response from Raft apply: %T", result)
	}
	return logRsp, nil
}
// run registers the forwarding service on a new gRPC server and serves it on
// s.listener, blocking until Serve returns and passing its error back to the
// caller. Cancelling ctx stops the server.
func (s *forwardingServer) run(ctx context.Context) error {
	srv := grpc.NewServer()
	pbstorage.RegisterForwardingServiceServer(srv, s)

	// Serve blocks, so cancellation has to be watched for in the background.
	stopOnCancel := func() {
		<-ctx.Done()
		srv.Stop()
	}
	go stopOnCancel()

	return srv.Serve(s.listener)
}
// forwardingClient is used to forward operations to the leader. It keeps a
// single gRPC connection that is dialed on demand and dropped when the leader
// changes.
type forwardingClient struct {
	// handle is used to dial the leader.
	handle Handle

	// logger records dial failures and errors closing a stale connection.
	logger hclog.Logger

	// mu guards conn.
	mu sync.RWMutex

	// conn is the cached connection to the leader; nil when none has been
	// dialed yet or after leaderChanged has discarded it.
	conn *grpc.ClientConn
}
// newForwardingClient returns a forwardingClient that dials through h and logs
// to l. No connection is established until one is first needed.
func newForwardingClient(h Handle, l hclog.Logger) *forwardingClient {
	client := new(forwardingClient)
	client.handle = h
	client.logger = l
	return client
}
// leaderChanged discards the cached connection, if any, so that the next
// request dials again. A failure to close the old connection is logged and
// otherwise ignored.
func (c *forwardingClient) leaderChanged() {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn != nil {
		closeErr := c.conn.Close()
		if closeErr != nil {
			c.logger.Error("failed to close connection to previous leader", "error", closeErr)
		}
		c.conn = nil
	}
}
// getConn returns the cached leader connection, dialing one through the
// handle first if none is cached. A dial failure is logged and returned, and
// nothing is cached in that case.
func (c *forwardingClient) getConn() (*grpc.ClientConn, error) {
	c.mu.Lock()
	defer c.mu.Unlock()

	if c.conn == nil {
		dialed, err := c.handle.DialLeader()
		if err != nil {
			c.logger.Error("failed to dial leader", "error", err)
			return nil, err
		}
		c.conn = dialed
	}
	return c.conn, nil
}
// getClient returns a forwarding service client bound to the connection
// obtained from getConn, or the error getConn reported.
func (c *forwardingClient) getClient() (pbstorage.ForwardingServiceClient, error) {
	leaderConn, err := c.getConn()
	if err != nil {
		return nil, err
	}

	svc := pbstorage.NewForwardingServiceClient(leaderConn)
	return svc, nil
}
// delete forwards a delete request. Errors obtaining the client are returned
// untouched; the error from the RPC itself is passed through unwrapError.
func (c *forwardingClient) delete(ctx context.Context, req *pbstorage.DeleteRequest) error {
	leader, err := c.getClient()
	if err != nil {
		return err
	}

	_, callErr := leader.Delete(ctx, req)
	return unwrapError(callErr)
}
// write forwards a write request and returns the RPC's response. Errors
// obtaining the client are returned untouched; the error from the RPC itself
// is passed through unwrapError.
func (c *forwardingClient) write(ctx context.Context, req *pbstorage.WriteRequest) (*pbstorage.WriteResponse, error) {
	leader, err := c.getClient()
	if err != nil {
		return nil, err
	}

	result, callErr := leader.Write(ctx, req)
	return result, unwrapError(callErr)
}
// read forwards a read request and returns the RPC's response. Errors
// obtaining the client are returned untouched; the error from the RPC itself
// is passed through unwrapError.
func (c *forwardingClient) read(ctx context.Context, req *pbstorage.ReadRequest) (*pbstorage.ReadResponse, error) {
	leader, err := c.getClient()
	if err != nil {
		return nil, err
	}

	result, callErr := leader.Read(ctx, req)
	return result, unwrapError(callErr)
}
// list forwards a list request and returns the RPC's response. Errors
// obtaining the client are returned untouched; the error from the RPC itself
// is passed through unwrapError.
func (c *forwardingClient) list(ctx context.Context, req *pbstorage.ListRequest) (*pbstorage.ListResponse, error) {
	leader, err := c.getClient()
	if err != nil {
		return nil, err
	}

	result, callErr := leader.List(ctx, req)
	return result, unwrapError(callErr)
}
var (
	// errorToCode maps each storage sentinel error to the gRPC status code
	// used to carry it over the wire.
	//
	// Note: OutOfRange is used to represent GroupVersionMismatchError, but is
	// handled specially in wrapError and unwrapError because it has extra
	// details.
	errorToCode = map[error]codes.Code{
		storage.ErrNotFound:     codes.NotFound,
		storage.ErrCASFailure:   codes.Aborted,
		storage.ErrWrongUid:     codes.AlreadyExists,
		storage.ErrInconsistent: codes.FailedPrecondition,
	}

	// codeToError is the inverse of errorToCode, computed once when the
	// package's variables are initialized.
	codeToError = func() map[codes.Code]error {
		byCode := make(map[codes.Code]error, len(errorToCode))
		for sentinel, code := range errorToCode {
			byCode[code] = sentinel
		}
		return byCode
	}()
)
// wrapError converts the given error to a gRPC status error to send over the
// wire.
//
//   - A nil error yields nil.
//   - A storage.GroupVersionMismatchError (anywhere in the error chain) becomes
//     a codes.OutOfRange status carrying the mismatch as status details.
//   - An error that is, or wraps, one of the sentinels in errorToCode gets the
//     corresponding code.
//   - Anything else becomes codes.Internal.
//
// The status message is always err.Error().
func wrapError(err error) error {
	if err == nil {
		return nil
	}

	var gvm storage.GroupVersionMismatchError
	if errors.As(err, &gvm) {
		st, detailErr := status.New(codes.OutOfRange, err.Error()).
			WithDetails(&pbstorage.GroupVersionMismatchErrorDetails{
				RequestedType: gvm.RequestedType,
				Stored:        gvm.Stored,
			})
		if detailErr == nil {
			return st.Err()
		}
		// The details could not be attached; fall through and report the
		// error through the generic path below.
	}

	// Match sentinels with errors.Is rather than by map-key identity, so that
	// a sentinel wrapped with extra context (e.g. via %w) keeps its code
	// instead of degrading to Internal. This also avoids using an arbitrary
	// error value as a map key, which panics for unhashable dynamic types.
	// An error is not expected to match more than one sentinel.
	code := codes.Internal
	for sentinel, sentinelCode := range errorToCode {
		if errors.Is(err, sentinel) {
			code = sentinelCode
			break
		}
	}
	return status.Error(code, err.Error())
}
// unwrapError converts the given gRPC status error back to a storage package
// error.
//
// A status carrying GroupVersionMismatchErrorDetails becomes a
// storage.GroupVersionMismatchError; a status whose code appears in
// codeToError becomes the matching sentinel. Errors that status.FromError does
// not recognize, and statuses with any other code, are returned unchanged.
func unwrapError(err error) error {
	st, isStatus := status.FromError(err)
	if !isStatus {
		return err
	}

	for _, detail := range st.Details() {
		mismatch, found := detail.(*pbstorage.GroupVersionMismatchErrorDetails)
		if !found {
			continue
		}
		return storage.GroupVersionMismatchError{
			RequestedType: mismatch.RequestedType,
			Stored:        mismatch.Stored,
		}
	}

	if sentinel, known := codeToError[st.Code()]; known {
		return sentinel
	}
	return err
}