consul/agent/consul/raft_rpc.go
hashicorp-copywrite[bot] 5fb9df1640
[COMPLIANCE] License changes (#18443)
* Adding explicit MPL license for sub-package

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Adding explicit MPL license for sub-package

This directory and its subdirectories (packages) contain files licensed with the MPLv2 `LICENSE` file in this directory and are intentionally licensed separately from the BSL `LICENSE` file at the root of this repository.

* Updating the license from MPL to Business Source License

Going forward, this project will be licensed under the Business Source License v1.1. Please see our blog post for more details at <Blog URL>, FAQ at www.hashicorp.com/licensing-faq, and details of the license at www.hashicorp.com/bsl.

* add missing license headers

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

* Update copyright file headers to BUSL-1.1

---------

Co-authored-by: hashicorp-copywrite[bot] <110428419+hashicorp-copywrite[bot]@users.noreply.github.com>
2023-08-11 09:12:13 -04:00

128 lines
2.9 KiB
Go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1
package consul
import (
"fmt"
"net"
"sync"
"time"
"github.com/hashicorp/raft"
"github.com/hashicorp/consul/agent/pool"
"github.com/hashicorp/consul/tlsutil"
)
// RaftLayer implements the raft.StreamLayer interface,
// so that we can use a single RPC layer for Raft and Consul.
//
// RaftLayer does not read from a listening socket itself. Inbound connections
// are delivered to it through Handoff and returned to Raft by Accept. Outbound
// connections are created by Dial and are optionally upgraded to TLS.
type RaftLayer struct {
	// src is the local address that outgoing connections are dialed from
	// (used as net.Dialer.LocalAddr in Dial).
	src net.Addr

	// addr is the listener address reported by Addr.
	addr net.Addr

	// connCh carries connections from Handoff to Accept.
	connCh chan net.Conn

	// tlsWrap upgrades an outgoing connection to a TLS client. It may be nil,
	// in which case Dial never switches a connection into TLS mode.
	tlsWrap tlsutil.Wrapper

	// closed records whether Close has already run. It is only read and
	// written while holding closeLock. closeCh is closed exactly once, by
	// Close, to unblock any pending Handoff or Accept call.
	closed    bool
	closeCh   chan struct{}
	closeLock sync.Mutex

	// tlsFunc is a callback that determines whether to use TLS when
	// connecting to a given Raft server. Dial calls it unconditionally, so it
	// must be non-nil.
	tlsFunc func(raft.ServerAddress) bool
}
// NewRaftLayer builds a RaftLayer suitable for use as Raft's StreamLayer.
//
// src is the local address that outgoing connections are dialed from, and
// addr is the address reported by Addr. An outgoing connection is upgraded to
// TLS only when tlsWrap is non-nil and tlsFunc returns true for the target
// server address. tlsFunc is always consulted by Dial and must be non-nil.
func NewRaftLayer(src, addr net.Addr, tlsWrap tlsutil.Wrapper, tlsFunc func(raft.ServerAddress) bool) *RaftLayer {
	return &RaftLayer{
		src:     src,
		addr:    addr,
		tlsWrap: tlsWrap,
		tlsFunc: tlsFunc,
		// Unbuffered: a Handoff completes only once an Accept receives it.
		connCh:  make(chan net.Conn),
		closeCh: make(chan struct{}),
	}
}
// Handoff delivers an inbound connection to the RaftLayer so that a pending or
// future Accept call can return it.
//
// It blocks until Accept receives the connection or the layer is closed. In
// the closed case it returns an error and does not close c; the caller still
// owns the connection.
func (l *RaftLayer) Handoff(c net.Conn) error {
	select {
	case <-l.closeCh:
		return fmt.Errorf("Raft RPC layer closed")
	case l.connCh <- c:
	}
	return nil
}
// Accept waits for the next connection delivered through Handoff and returns
// it for use by Raft.
//
// Once the layer has been closed, Accept returns a nil connection and an
// error.
func (l *RaftLayer) Accept() (net.Conn, error) {
	var next net.Conn
	select {
	case <-l.closeCh:
		return nil, fmt.Errorf("Raft RPC layer closed")
	case next = <-l.connCh:
	}
	return next, nil
}
// Close stops the layer from handing out Raft connections. Any Handoff or
// Accept blocked on the layer is released with an error.
//
// Close is idempotent: repeated calls are no-ops and always return nil.
func (l *RaftLayer) Close() error {
	l.closeLock.Lock()
	defer l.closeLock.Unlock()

	// closeCh must be closed at most once; closing it twice would panic.
	if l.closed {
		return nil
	}
	l.closed = true
	close(l.closeCh)
	return nil
}
// Addr reports the listener address supplied to NewRaftLayer, satisfying the
// net.Listener part of the raft.StreamLayer interface.
func (l *RaftLayer) Addr() net.Addr {
	listenAddr := l.addr
	return listenAddr
}
// Dial opens a new outgoing Raft connection to address.
//
// The TCP connection is dialed from the layer's source address. timeout
// bounds only that TCP dial; this function sets no deadline on the subsequent
// TLS wrap or mode-byte writes.
//
// The connection is switched into TLS only when tlsFunc reports TLS for
// address and a TLS wrapper is configured. In that case the RPCTLS mode byte
// is written and the connection is wrapped in a TLS client. Finally, the
// RPCRaft mode byte is written to select the Raft mode on the remote end.
//
// If any step after a successful TCP dial fails, the underlying connection is
// closed before the error is returned, so no socket is leaked.
func (l *RaftLayer) Dial(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) {
	d := &net.Dialer{LocalAddr: l.src, Timeout: timeout}
	conn, err := d.Dial("tcp", string(address))
	if err != nil {
		return nil, err
	}

	// Check for tls mode
	if l.tlsFunc(address) && l.tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(pool.RPCTLS)}); err != nil {
			conn.Close()
			return nil, err
		}

		// Wrap the connection in a TLS client. Keep the raw connection in
		// conn until the wrap succeeds: on failure the wrapper may return a
		// nil conn, and the underlying socket must still be released.
		tlsConn, err := l.tlsWrap(conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}

	// Write the Raft byte to set the mode
	if _, err := conn.Write([]byte{byte(pool.RPCRaft)}); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}