2020-11-22 15:55:57 +01:00

247 lines
5.5 KiB
Go

package sshd
import (
"errors"
"fmt"
"net"
"sync"
"time"
"github.com/shazow/ssh-chat/sshd/terminal"
"golang.org/x/crypto/ssh"
)
// Keepalive settings used by NewTerminal: the period between keepalive
// requests sent on the session channel, and the request type used for them.
var (
	keepaliveInterval = 30 * time.Second
	keepaliveRequest  = "keepalive@ssh-chat"
)

var (
	// ErrNoSessionChannel is returned when there is no session channel.
	ErrNoSessionChannel = errors.New("no session channel")

	// ErrNotSessionChannel is returned when a channel is not a session channel.
	ErrNotSessionChannel = errors.New("terminal requires session channel")
)
// Connection is the set of methods necessary to operate an sshd host; it is
// what a Terminal keeps as its Conn.
type Connection interface {
	// Name identifies the connecting client (for sshConn, the SSH user name).
	Name() string
	// PublicKey returns the client's public key, or nil when none is known.
	PublicKey() ssh.PublicKey
	// RemoteAddr returns the client's network address.
	RemoteAddr() net.Addr
	// ClientVersion returns the client's SSH version string.
	ClientVersion() []byte
	// Close closes the connection.
	Close() error
}
// sshConn adapts *ssh.ServerConn to the Connection interface. RemoteAddr,
// ClientVersion and Close are promoted from the embedded ServerConn.
type sshConn struct {
	*ssh.ServerConn
}

// PublicKey decodes the wire-format public key stored under the "pubkey"
// permissions extension. It returns nil when the connection has no
// permissions, the extension is absent, or the stored key fails to parse.
// NOTE(review): the extension is presumably populated by an auth callback
// elsewhere in this package — confirm.
func (c sshConn) PublicKey() ssh.PublicKey {
	perms := c.Permissions
	if perms == nil {
		return nil
	}
	encoded, present := perms.Extensions["pubkey"]
	if !present {
		return nil
	}
	if key, err := ssh.ParsePublicKey([]byte(encoded)); err == nil {
		return key
	}
	return nil
}

// Name returns the SSH user name the client connected with.
func (c sshConn) Name() string {
	return c.ServerConn.User()
}
// EnvVar is an environment variable key-value pair. The field order matters:
// it is decoded from SSH "env" request payloads with ssh.Unmarshal.
type EnvVar struct {
	Key   string
	Value string
}

// String formats the pair as KEY=VALUE.
func (v EnvVar) String() string {
	const separator = "="
	return v.Key + separator + v.Value
}

// Env is a wrapper type around []EnvVar with some helper methods. A key may
// appear more than once; later entries take precedence.
type Env []EnvVar

// Get returns the value of the last entry whose Key equals key, or the empty
// string when no entry matches.
func (e Env) Get(key string) string {
	value := ""
	for _, kv := range e {
		if kv.Key == key {
			// Keep overwriting so the final match wins.
			value = kv.Value
		}
	}
	return value
}
// Terminal extends ssh/terminal to include a close method.
//
// It wraps a terminal.Terminal running over an accepted SSH "session"
// channel, and additionally exposes the owning connection, the environment
// and TERM value negotiated by the client, and a Close method that tears
// down both the channel and the connection.
type Terminal struct {
	terminal.Terminal
	// Conn is the connection the session channel belongs to.
	Conn Connection
	// Channel is the accepted session channel the embedded terminal uses
	// for its I/O.
	Channel ssh.Channel
	// done is closed exactly once, by Close, to signal shutdown.
	done chan struct{}
	// closeOnce makes Close idempotent.
	closeOnce sync.Once
	// mu guards env and term, which are written by the request-handling
	// goroutine (listen) and read by Env and Term.
	mu sync.Mutex
	// env holds accepted "env" requests in arrival order.
	env []EnvVar
	// term is the terminal name from the most recent "pty-req" request.
	term string
}
// NewTerminal makes a new Terminal from a session channel.
//
// It returns ErrNotSessionChannel if ch is not a "session" channel, or the
// error from accepting it. After accepting, it starts two goroutines: one that
// services the channel's requests (see listen), and one that sends a
// keepaliveRequest every keepaliveInterval and closes the Terminal if a
// keepalive fails.
//
// NewTerminal then blocks until the client requests a shell, which gives the
// session a chance to populate env vars and other state. It returns an error
// and no Terminal if the Terminal is closed first or if no shell is requested
// within one minute. In the timeout case the Terminal is closed before
// returning.
//
// TODO: For v2, make a separate `Serve(ctx context.Context) error` method to activate the Terminal
func NewTerminal(conn *ssh.ServerConn, ch ssh.NewChannel) (*Terminal, error) {
	if ch.ChannelType() != "session" {
		return nil, ErrNotSessionChannel
	}
	channel, requests, err := ch.Accept()
	if err != nil {
		return nil, err
	}
	term := &Terminal{
		Terminal: *terminal.NewTerminal(channel, ""),
		Conn:     sshConn{conn},
		Channel:  channel,
		done:     make(chan struct{}),
	}

	ready := make(chan struct{})
	go term.listen(requests, ready)

	go func() {
		// Keep-alive ticker. Stop it when this goroutine exits; an unstopped
		// time.Tick ticker would outlive the Terminal.
		ticker := time.NewTicker(keepaliveInterval)
		defer ticker.Stop()
		for {
			select {
			case <-ticker.C:
				if _, err := channel.SendRequest(keepaliveRequest, true, nil); err != nil {
					// Connection is gone
					logger.Printf("[%s] Keepalive failed, closing terminal: %s", term.Conn.RemoteAddr(), err)
					term.Close()
					return
				}
			case <-term.done:
				return
			}
		}
	}()

	// Wait for the shell request before returning.
	// TODO: Make the timeout configurable
	// TODO: Use context.Context for abort/timeout in the future, will need to change the API.
	timeout := time.NewTimer(time.Minute)
	defer timeout.Stop()
	select {
	case <-ready: // shell acquired
		return term, nil
	case <-term.done:
		return nil, errors.New("terminal aborted")
	case <-timeout.C:
		// The caller never receives term on this path, so it cannot Close it.
		// Close it here so the accepted channel is released and the keepalive
		// goroutine exits via term.done instead of pinging forever. The error
		// is ignored because we are already reporting a failure.
		_ = term.Close()
		return nil, errors.New("timed out starting terminal")
	}
}
// NewSession finds the first "session" channel among channels and makes a
// Terminal from it with NewTerminal. Channels of any other type that arrive
// before it are logged and rejected with ssh.UnknownChannelType; channels
// arriving after the session channel are not read here. If channels is closed
// before a session channel shows up, it returns ErrNoSessionChannel.
func NewSession(conn *ssh.ServerConn, channels <-chan ssh.NewChannel) (*Terminal, error) {
	for newCh := range channels {
		chType := newCh.ChannelType()
		if chType == "session" {
			return NewTerminal(conn, newCh)
		}
		logger.Printf("[%s] Ignored channel type: %s", conn.RemoteAddr(), chType)
		// Best effort: a failed reject changes nothing about what we do next.
		_ = newCh.Reject(ssh.UnknownChannelType, fmt.Sprintf("unknown channel type: %s", chType))
	}
	return nil, ErrNoSessionChannel
}
// Close shuts down the terminal and its ssh connection. It closes done, which
// stops the keepalive goroutine and unblocks a waiting NewTerminal, then
// closes the session channel and the connection. Only the first call does any
// work and returns the connection's Close error; later calls return nil.
func (t *Terminal) Close() error {
	var connErr error
	t.closeOnce.Do(func() {
		close(t.done)
		// The channel's close error is ignored; the connection is closed
		// right after and its error is the one reported.
		_ = t.Channel.Close()
		connErr = t.Conn.Close()
	})
	return connErr
}
// listen negotiates the terminal type and state by servicing the session
// channel's requests until the requests channel is closed.
//
// ready is closed when the first "shell" request arrives; later "shell"
// requests are declined. A parseable "pty-req" or "window-change" is accepted
// when resizing succeeds ("pty-req" also records the terminal name). A
// decodable "env" request is appended to the environment and accepted.
// Any other request type is declined. Requests that want a reply get one.
func (t *Terminal) listen(requests <-chan *ssh.Request, ready chan<- struct{}) {
	shellStarted := false
	for req := range requests {
		accepted := false
		switch req.Type {
		case "shell":
			// Only the first shell request counts; closing ready twice would panic.
			if !shellStarted {
				accepted = true
				shellStarted = true
				close(ready)
			}
		case "pty-req":
			if termName, w, h, parsed := parsePtyRequest(req.Payload); parsed {
				// TODO: Hardcode width to 100000?
				accepted = t.SetSize(w, h) == nil
				// The terminal name is recorded even if resizing failed.
				t.mu.Lock()
				t.term = termName
				t.mu.Unlock()
			}
		case "window-change":
			if w, h, parsed := parseWinchRequest(req.Payload); parsed {
				// TODO: Hardcode width to 100000?
				accepted = t.SetSize(w, h) == nil
			}
		case "env":
			var kv EnvVar
			if err := ssh.Unmarshal(req.Payload, &kv); err == nil {
				t.mu.Lock()
				t.env = append(t.env, kv)
				t.mu.Unlock()
				accepted = true
			}
		}
		if req.WantReply {
			req.Reply(accepted, nil)
		}
	}
}
// Env returns a list of environment key-values that have been set. They are
// returned in the order that they have been set, there is no deduplication or
// other pre-processing applied.
//
// The result is a copy, so callers may modify or append to it without
// affecting the Terminal. It is nil if no env requests have been accepted.
func (t *Terminal) Env() Env {
	t.mu.Lock()
	defer t.mu.Unlock()
	// Returning t.env directly would alias the slice that listen appends to
	// under mu. Caller writes or appends would then race with listen and Term.
	return append(Env(nil), t.env...)
}
// Term returns the terminal type reported by the client. The name from the
// most recent "pty-req" request wins. Without one, it falls back to the last
// TERM value received through "env" requests, or "" if there is none.
func (t *Terminal) Term() string {
	t.mu.Lock()
	defer t.mu.Unlock()
	if t.term == "" {
		return Env(t.env).Get("TERM")
	}
	return t.term
}