go-libp2p-webrtc-direct/conn.go

433 lines
8.6 KiB
Go
Raw Normal View History

package libp2pwebrtcdirect
import (
"context"
"errors"
"fmt"
"io"
"io/ioutil"
"math"
"net"
"net/http"
"sync"
"time"
ic "github.com/libp2p/go-libp2p-core/crypto"
smux "github.com/libp2p/go-libp2p-core/mux"
peer "github.com/libp2p/go-libp2p-core/peer"
tpt "github.com/libp2p/go-libp2p-core/transport"
ma "github.com/multiformats/go-multiaddr"
2021-05-19 17:41:54 +00:00
manet "github.com/multiformats/go-multiaddr/net"
2019-04-06 15:47:28 +00:00
"github.com/pion/datachannel"
"github.com/pion/webrtc/v3"
)
type connConfig struct {
transport *Transport
maAddr ma.Multiaddr
addr net.Addr
isServer bool
2019-02-25 19:36:24 +00:00
remoteID peer.ID
}
// newConnConfig builds the configuration for a connection to or from maAddr.
// The WebRTC and HTTP components are stripped from maAddr and the remaining
// multiaddr is converted to the net.Addr used for HTTP signaling.
func newConnConfig(transport *Transport, maAddr ma.Multiaddr, isServer bool) (*connConfig, error) {
	transportMa := maAddr.Decapsulate(webrtcma).Decapsulate(httpma)

	netAddr, err := manet.ToNetAddr(transportMa)
	if err != nil {
		return nil, fmt.Errorf("failed to get net addr: %v", err)
	}

	cfg := &connConfig{
		transport: transport,
		maAddr:    maAddr,
		addr:      netAddr,
		isServer:  isServer,
	}
	return cfg, nil
}
// Conn is a stream-multiplexing connection to a remote peer.
type Conn struct {
config *connConfig
2019-02-21 14:19:05 +00:00
peerConnection *webrtc.PeerConnection
initChannel datachannel.ReadWriteCloser
lock sync.RWMutex
accept chan chan detachResult
isMuxed bool
muxedConn smux.MuxedConn
}
func newConn(config *connConfig, pc *webrtc.PeerConnection, initChannel datachannel.ReadWriteCloser) *Conn {
conn := &Conn{
config: config,
peerConnection: pc,
initChannel: initChannel,
accept: make(chan chan detachResult),
isMuxed: config.transport.muxer != nil,
}
2019-02-21 14:19:05 +00:00
pc.OnDataChannel(func(dc *webrtc.DataChannel) {
// We have to detach in OnDataChannel
detachRes := detachChannel(dc)
conn.accept <- detachRes
})
return conn
}
func dial(ctx context.Context, config *connConfig) (*Conn, error) {
api := config.transport.api
2019-02-21 14:19:05 +00:00
pc, err := api.NewPeerConnection(config.transport.webrtcOptions)
if err != nil {
return nil, err
}
dc, err := pc.CreateDataChannel("data", nil)
if err != nil {
return nil, err
}
detachRes := detachChannel(dc)
offer, err := pc.CreateOffer(nil)
if err != nil {
return nil, err
}
2021-03-21 13:23:34 +00:00
// Complete ICE Gathering for single-shot signaling.
gatherComplete := webrtc.GatheringCompletePromise(pc)
2019-02-21 14:19:05 +00:00
err = pc.SetLocalDescription(offer)
if err != nil {
return nil, err
}
2021-03-21 13:23:34 +00:00
<-gatherComplete
2019-02-21 14:19:05 +00:00
2021-03-21 13:23:34 +00:00
offerEnc, err := encodeSignal(*pc.LocalDescription())
if err != nil {
return nil, err
}
req, err := http.NewRequest("GET", "http://"+config.addr.String()+"/?signal="+offerEnc, nil)
if err != nil {
return nil, err
}
req = req.WithContext(ctx)
var client = &http.Client{}
resp, err := client.Do(req)
if err != nil {
return nil, err
}
defer resp.Body.Close()
answerEnc, err := ioutil.ReadAll(resp.Body)
if err != nil && err != io.EOF {
return nil, err
}
answer, err := decodeSignal(string(answerEnc))
if err != nil {
return nil, err
}
if err := pc.SetRemoteDescription(answer); err != nil {
return nil, err
}
select {
case res := <-detachRes:
if res.err != nil {
return nil, res.err
}
return newConn(config, pc, res.dc), nil
case <-ctx.Done():
return newConn(config, pc, nil), ctx.Err()
}
}
// detachResult is the outcome of detaching a WebRTC data channel once it has
// opened. It is delivered by the channel that detachChannel returns.
type detachResult struct {
	// dc is the detached data channel, valid when err is nil.
	dc datachannel.ReadWriteCloser
	// err is the error returned by Detach, if any.
	err error
}
2019-02-21 14:19:05 +00:00
func detachChannel(dc *webrtc.DataChannel) chan detachResult {
onOpenRes := make(chan detachResult)
dc.OnOpen(func() {
// Detach the data channel
raw, err := dc.Detach()
onOpenRes <- detachResult{raw, err}
})
return onOpenRes
}
// Close closes the underlying WebRTC PeerConnection. It also closes the
// accept queue, so that a goroutine waiting in awaitAccept for the next
// incoming data channel fails with "Conn closed". A stream muxer created on
// top of the connection is not closed explicitly here.
//
// Close is idempotent. Once the connection is closed (IsClosed reports
// true), further calls return nil instead of panicking on a second close of
// the accept channel.
func (c *Conn) Close() error {
	c.lock.Lock()
	defer c.lock.Unlock()

	if c.peerConnection == nil {
		// Already closed: c.accept has been closed before.
		return nil
	}

	err := c.peerConnection.Close()
	c.peerConnection = nil
	close(c.accept)
	return err
}
// IsClosed reports whether Close has released the PeerConnection, after which
// the connection can be garbage collected.
func (c *Conn) IsClosed() bool {
	c.lock.RLock()
	defer c.lock.RUnlock()
	return c.peerConnection == nil
}
// OpenStream opens a new outgoing stream.
//
// With a stream muxer configured, the call is delegated to the muxed
// connection. Otherwise each stream is its own data channel:
//   - the first call claims the initial channel handed to newConn, if any;
//   - later calls create a new "data" channel and wait for it to open.
//
// That wait honours ctx. On cancellation the half-open channel is closed and
// ctx.Err() is returned, instead of blocking indefinitely.
func (c *Conn) OpenStream(ctx context.Context) (smux.MuxedStream, error) {
	muxed, err := c.getMuxed()
	if err != nil {
		return nil, err
	}
	if muxed != nil {
		return muxed.OpenStream(ctx)
	}

	rawDC := c.checkInitChannel()
	if rawDC == nil {
		pc, err := c.getPC()
		if err != nil {
			return nil, err
		}
		dc, err := pc.CreateDataChannel("data", nil)
		if err != nil {
			return nil, err
		}
		// Register the OnOpen handler before waiting on it.
		detachRes := detachChannel(dc)
		select {
		case res := <-detachRes:
			if res.err != nil {
				return nil, res.err
			}
			rawDC = res.dc
		case <-ctx.Done():
			// Don't leave an unused channel behind on the PeerConnection.
			_ = dc.Close()
			return nil, ctx.Err()
		}
	}
	return newStream(rawDC), nil
}
2019-02-21 14:19:05 +00:00
func (c *Conn) getPC() (*webrtc.PeerConnection, error) {
c.lock.RLock()
pc := c.peerConnection
c.lock.RUnlock()
if pc == nil {
return nil, errors.New("Conn closed")
}
return pc, nil
}
// getMuxed returns the connection's stream muxer, creating it on first use.
//
// It returns (nil, nil) when the transport has no muxer configured. In that
// case callers fall back to one data channel per stream. Otherwise the muxer
// is built over the initial data channel if there is one, and over the
// first data channel the remote side opens (via awaitAccept) if there is not.
//
// NOTE(review): c.lock is held for the whole call, including the blocking
// wait in awaitAccept. Close must take the same lock before it can close
// c.accept, so a Close issued while this waits appears to block until the
// remote opens a data channel. Concurrent callers depend on the lock to end
// up sharing one muxer, so fixing this needs more than an early unlock.
func (c *Conn) getMuxed() (smux.MuxedConn, error) {
	c.lock.Lock()
	defer c.lock.Unlock()
	if !c.isMuxed {
		return nil, nil
	}
	// Already set up by an earlier call.
	if c.muxedConn != nil {
		return c.muxedConn, nil
	}
	rawDC := c.initChannel
	if rawDC == nil {
		var err error
		rawDC, err = c.awaitAccept()
		if err != nil {
			return nil, err
		}
	}
	// useMuxer takes a net.Conn; dcWrapper adapts the message-based data
	// channel into a byte stream with a dcWrapperBufSize read buffer.
	err := c.useMuxer(&dcWrapper{channel: rawDC, addr: c.config.addr, buf: make([]byte, dcWrapperBufSize)}, c.config.transport.muxer)
	if err != nil {
		return nil, err
	}
	return c.muxedConn, nil
}
// useMuxer runs conn through the given stream multiplexer, using the role
// from the connection config. On success it stores the result in
// c.muxedConn. The caller must hold c.lock.
func (c *Conn) useMuxer(conn net.Conn, muxer smux.Multiplexer) error {
	muxed, err := muxer.NewConn(conn, c.config.isServer)
	if err == nil {
		c.muxedConn = muxed
	}
	return err
}
// checkInitChannel hands out the initial data channel exactly once.
//
// A WebRTC offer can't be empty, so the offering side always has an initial
// data channel open. The first stream operation claims it here; every later
// call gets nil.
func (c *Conn) checkInitChannel() datachannel.ReadWriteCloser {
	c.lock.Lock()
	defer c.lock.Unlock()

	claimed := c.initChannel
	c.initChannel = nil
	return claimed
}
// AcceptStream accepts a stream opened by the other side.
func (c *Conn) AcceptStream() (smux.MuxedStream, error) {
muxed, err := c.getMuxed()
if err != nil {
return nil, err
}
if muxed != nil {
return muxed.AcceptStream()
}
rawDC := c.checkInitChannel()
if rawDC == nil {
rawDC, err = c.awaitAccept()
2021-05-19 17:41:54 +00:00
if err != nil {
return nil, err
}
}
return newStream(rawDC), nil
}
// awaitAccept blocks until the remote side opens a data channel, then waits
// for that channel's detach result. It fails with "Conn closed" once Close
// has closed the accept queue.
func (c *Conn) awaitAccept() (datachannel.ReadWriteCloser, error) {
	pending, open := <-c.accept
	if !open {
		return nil, errors.New("Conn closed")
	}
	result := <-pending
	return result.dc, result.err
}
// LocalPeer returns our peer ID
func (c *Conn) LocalPeer() peer.ID {
2019-02-25 19:36:24 +00:00
// TODO: Base on WebRTC security?
return c.config.transport.localID
}
// LocalPrivateKey returns our private key
func (c *Conn) LocalPrivateKey() ic.PrivKey {
2019-02-25 19:36:24 +00:00
// TODO: Base on WebRTC security?
return nil
}
// RemotePeer returns the peer ID of the remote peer.
func (c *Conn) RemotePeer() peer.ID {
2019-02-25 19:36:24 +00:00
// TODO: Base on WebRTC security?
return c.config.remoteID
}
// RemotePublicKey returns the public key of the remote peer.
func (c *Conn) RemotePublicKey() ic.PubKey {
2019-02-25 19:36:24 +00:00
// TODO: Base on WebRTC security?
return nil
}
// LocalMultiaddr returns the multiaddr this connection was configured with.
//
// NOTE(review): RemoteMultiaddr returns the same config.maAddr; for a dialed
// connection that is presumably the remote's address — confirm intent.
func (c *Conn) LocalMultiaddr() ma.Multiaddr {
	cfg := c.config
	return cfg.maAddr
}
// RemoteMultiaddr returns the multiaddr this connection was configured with.
//
// NOTE(review): LocalMultiaddr returns the same config.maAddr, so local and
// remote addresses are indistinguishable here — confirm intent.
func (c *Conn) RemoteMultiaddr() ma.Multiaddr {
	cfg := c.config
	return cfg.maAddr
}
// Transport returns the Transport that owns this connection.
func (c *Conn) Transport() tpt.Transport {
	owner := c.config.transport
	return owner
}
// dcWrapperBufSize caps the size of one data channel message. It is the size
// of the read buffer given to dcWrapper and the largest chunk dcWrapper.Write
// sends. The cap stays in place until there is a better packetizing strategy.
const dcWrapperBufSize = math.MaxUint16
// dcWrapper adapts a detached datachannel.ReadWriteCloser to the net.Conn
// interface, so that it can be handed to a stream muxer.
type dcWrapper struct {
	// channel is the detached data channel being wrapped.
	channel datachannel.ReadWriteCloser
	// addr is reported as both the local and the remote address.
	addr net.Addr

	// buf receives one message per underlying read.
	// buf[bufStart:bufEnd] is the part not yet returned by Read.
	buf      []byte
	bufStart int
	bufEnd   int
}
// Read presents the message-based data channel as a byte stream.
//
// When the internal buffer is empty, one message is read from the channel
// into it. Buffered bytes are then handed out across as many Read calls as
// needed. An error from the underlying read is returned together with
// whatever bytes are copied in the same call.
func (w *dcWrapper) Read(p []byte) (int, error) {
	var readErr error
	if w.bufEnd == 0 {
		w.bufEnd, readErr = w.channel.Read(w.buf)
	}

	pending := w.buf[w.bufStart:w.bufEnd]
	if len(pending) == 0 {
		return 0, readErr
	}

	copied := copy(p, pending)
	w.bufStart += copied
	if w.bufStart >= w.bufEnd {
		// Buffer drained; the next call reads a fresh message.
		w.bufStart, w.bufEnd = 0, 0
	}
	return copied, readErr
}
// Write sends p over the data channel as one or more messages of at most
// dcWrapperBufSize bytes, the size of the read buffer that Read fills.
//
// All of p is written, as the io.Writer contract requires: any short count
// comes with a non-nil error. Read serves the channel as a byte stream, so
// the extra message boundaries are invisible to the reading side. An empty
// p is still passed through as a single (empty) write.
func (w *dcWrapper) Write(p []byte) (n int, err error) {
	for {
		chunk := p[n:]
		if len(chunk) > dcWrapperBufSize {
			chunk = chunk[:dcWrapperBufSize]
		}

		var m int
		m, err = w.channel.Write(chunk)
		n += m
		if err != nil {
			return n, err
		}
		if n >= len(p) {
			return n, nil
		}
		if m == 0 {
			// No progress and no error: bail out rather than spin.
			return n, io.ErrShortWrite
		}
	}
}
// Close closes the wrapped data channel.
func (w *dcWrapper) Close() error {
	ch := w.channel
	return ch.Close()
}
// LocalAddr returns the wrapper's configured address. It is the same value
// RemoteAddr reports.
func (w *dcWrapper) LocalAddr() net.Addr {
	addr := w.addr
	return addr
}
// RemoteAddr returns the wrapper's configured address. It is the same value
// LocalAddr reports.
func (w *dcWrapper) RemoteAddr() net.Addr {
	addr := w.addr
	return addr
}
// SetDeadline is a no-op: deadlines are ignored and nil is always returned.
func (w *dcWrapper) SetDeadline(_ time.Time) error {
	return nil
}
// SetReadDeadline is a no-op: read deadlines are ignored and nil is always
// returned.
func (w *dcWrapper) SetReadDeadline(_ time.Time) error {
	return nil
}
// SetWriteDeadline is a no-op: write deadlines are ignored and nil is always
// returned.
func (w *dcWrapper) SetWriteDeadline(_ time.Time) error {
	return nil
}