Skip to content

Commit

Permalink
Add ssh module for running commands over ssh (#39)
Browse files Browse the repository at this point in the history
* Add ssh module for running commands over ssh

* Add comments and ability to configure host key checks

* Fix build

* Use newer base image with newer terraform version

* Rebase and rename gruntwork-cli

* Return sshAgent handle if there is an error parsing the privatekey

* Add support for decoding keys other than rsa
  • Loading branch information
yorinasub17 authored Mar 12, 2021
1 parent 96ff82e commit 177b60e
Show file tree
Hide file tree
Showing 13 changed files with 1,512 additions and 11 deletions.
2 changes: 1 addition & 1 deletion .circleci/config.yml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
defaults: &defaults
docker:
- image: 087285199408.dkr.ecr.us-east-1.amazonaws.com/circle-ci-test-image-base:go1.13
- image: 087285199408.dkr.ecr.us-east-1.amazonaws.com/circle-ci-test-image-base:tf14.4
version: 2
jobs:
test:
Expand Down
6 changes: 6 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ This repo contains the following packages:
* files
* logging
* shell
* ssh

Each of these packages is described below.

Expand Down Expand Up @@ -128,6 +129,11 @@ This package contains two types of helpers:
* `cmd.go`: This file contains helpers for running shell commands.
* `prompt.go`: This file contains helpers for prompting the user for input (e.g. yes/no).

### ssh

This package contains helper methods for initiating SSH connections and running commands over the connection.


## Running tests

```
Expand Down
3 changes: 2 additions & 1 deletion entrypoint/assertions.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,9 @@ package entrypoint

import (
"fmt"
"github.com/urfave/cli"
"os"

"github.com/urfave/cli"
)

type RequiredArgsError struct {
Expand Down
7 changes: 5 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,12 @@ go 1.13
require (
github.com/bgentry/speakeasy v0.1.0
github.com/fatih/color v1.9.0
github.com/go-errors/errors v1.0.1
github.com/mattn/go-zglob v0.0.1
github.com/go-errors/errors v1.0.2-0.20180813162953-d98b870cc4e0
github.com/gruntwork-io/terratest v0.32.9
github.com/hashicorp/go-multierror v1.1.0
github.com/mattn/go-zglob v0.0.2-0.20190814121620-e3c945676326
github.com/sirupsen/logrus v1.4.2
github.com/stretchr/testify v1.4.0
github.com/urfave/cli v1.22.2
golang.org/x/crypto v0.0.0-20210220033148-5ea612d1eb83
)
636 changes: 629 additions & 7 deletions go.sum

Large diffs are not rendered by default.

4 changes: 4 additions & 0 deletions ssh/.gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
.terraform
stages
*.tfstate
*.tfstate.backup
151 changes: 151 additions & 0 deletions ssh/agent.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
package ssh

import (
"crypto/x509"
"encoding/pem"
"io"
"io/ioutil"
"net"
"os"
"path/filepath"

multierror "github.com/hashicorp/go-multierror"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh/agent"
)

// SSHAgent represents an instance of the ssh-agent process.
type SSHAgent struct {
stop chan bool
stopped chan bool
socketDir string
socketFile string
agent agent.Agent
ln net.Listener

logger *logrus.Entry
}

// Create SSH agent, start it in background and returns control back to the main thread
// You should stop the agent to cleanup files afterwards by calling `defer s.Stop()`
func NewSSHAgent(logger *logrus.Entry, socketDir string, socketFile string) (*SSHAgent, error) {
var err error
s := &SSHAgent{
stop: make(chan bool),
stopped: make(chan bool),
socketDir: socketDir,
socketFile: socketFile,
agent: agent.NewKeyring(),
}
s.ln, err = net.Listen("unix", s.socketFile)
if err != nil {
return nil, err
}
go s.run()
return s, nil
}

// expose socketFile variable
func (s *SSHAgent) SocketFile() string {
return s.socketFile
}

// SSH Agent listener and handler
func (s *SSHAgent) run() {
defer close(s.stopped)
for {
select {
case <-s.stop:
return
default:
c, err := s.ln.Accept()
if err != nil {
select {
// When s.Stop() closes the listener, s.ln.Accept() returns an error that can be ignored
// since the agent is in stopping process
case <-s.stop:
return
// When s.ln.Accept() returns a legit error, we print it and continue accepting further requests
default:
if s.logger != nil {
s.logger.Errorf("could not accept connection to agent %v", err)
}
continue
}
} else {
defer c.Close()
go func(c io.ReadWriter) {
err := agent.ServeAgent(s.agent, c)
if err != nil && s.logger != nil {
s.logger.Errorf("could not serve ssh agent %v", err)
}
}(c)
}
}
}
}

// Stop and clean up SSH agent
func (s *SSHAgent) Stop() {
close(s.stop)
s.ln.Close()
<-s.stopped
os.RemoveAll(s.socketDir)
}

// Instantiates and returns an in-memory ssh agent with the given private key already added
// You should stop the agent to cleanup files afterwards by calling `defer sshAgent.Stop()`
func SSHAgentWithPrivateKey(logger *logrus.Entry, privateKey string) (*SSHAgent, error) {
sshAgent, err := SSHAgentWithPrivateKeys(logger, []string{privateKey})
return sshAgent, err
}

// Instantiates and returns an in-memory ssh agent with the given private key(s) already added
// You should stop the agent to cleanup files afterwards by calling `defer sshAgent.Stop()`
func SSHAgentWithPrivateKeys(logger *logrus.Entry, privateKeys []string) (*SSHAgent, error) {
// Instantiate a temporary SSH agent
socketDir, err := ioutil.TempDir("", "ssh-agent-")
if err != nil {
return nil, err
}
socketFile := filepath.Join(socketDir, "ssh_auth.sock")
sshAgent, err := NewSSHAgent(logger, socketDir, socketFile)
if err != nil {
return nil, err
}

// add given ssh keys to the newly created agent
var allErrs *multierror.Error
for _, privateKey := range privateKeys {
// Create SSH key for the agent using the given SSH key pair(s)
block, _ := pem.Decode([]byte(privateKey))
decodedPrivateKey, err := decodePrivateKey(block.Bytes)
if err != nil {
logger.Error("Error decoding private key for adding to ssh-agent")
allErrs = multierror.Append(allErrs, err)
} else {
key := agent.AddedKey{PrivateKey: decodedPrivateKey}
if err := sshAgent.agent.Add(key); err != nil {
logger.Error("Error adding private key ssh-agent")
allErrs = multierror.Append(allErrs, err)
}
}
}
return sshAgent, allErrs.ErrorOrNil()
}

// decodePrivateKey first attempts to decode the key as PKCS8, and then fallsback to PKCS1 if that fails.
// This function returns a *rsa.PrivateKey, a *ecdsa.PrivateKey, or a ed25519.PrivateKey.
func decodePrivateKey(keyBytes []byte) (interface{}, error) {
var allErrs *multierror.Error
decodedPrivateKey, err := x509.ParsePKCS8PrivateKey(keyBytes)
if err != nil {
allErrs = multierror.Append(allErrs, err)
decodedPrivateKey, err = x509.ParsePKCS1PrivateKey(keyBytes)
if err != nil {
allErrs = multierror.Append(allErrs, err)
return nil, allErrs.ErrorOrNil()
}
}
return decodedPrivateKey, nil
}
182 changes: 182 additions & 0 deletions ssh/options.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
package ssh

import (
"errors"
"fmt"
"io"
"net"
"os"
"reflect"
"strconv"

multierror "github.com/hashicorp/go-multierror"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"

"github.com/gruntwork-io/go-commons/collections"
)

// Host is a remote host.
type Host struct {
Hostname string // host name or ip address
SSHUserName string // user name
CustomPort int // port number to use to connect to the host (port 22 will be used if unset)

JumpHost *Host // Another host to use as a jump host to reach this host.

HostKeyCallback ssh.HostKeyCallback // Callback function for handling host key checks.

// set one or more authentication methods,
// the first valid method will be used
PrivateKey string // ssh private key to use as authentication method (disabled by default)
SSHAgent bool // enable authentication using your existing local SSH agent (disabled by default)
OverrideSSHAgent *SSHAgent // enable an in process `SSHAgent` for connections to this host (disabled by default)
Password string // plain text password (blank by default)
}

// getSSHConnectionOptions converts the host configuration into a set of options that can be used for managing the SSH
// connection.
func (host *Host) getSSHConnectionOptions() (*sshConnectionOptions, error) {
if host == nil {
return nil, nil
}

authMethods, err := host.createAuthMethods()
if err != nil {
return nil, err
}

hostOptions := sshConnectionOptions{
Username: host.SSHUserName,
Address: host.Hostname,
Port: host.getPort(),
HostKeyCallback: host.HostKeyCallback,
AuthMethods: authMethods,
}
return &hostOptions, nil
}

// getPort gets the port that should be used to communicate with the host
func (h Host) getPort() int {

//If a CustomPort is not set use standard ssh port
if h.CustomPort == 0 {
return 22
} else {
return h.CustomPort
}
}

// createAuthMethods returns an array of authentication methods
func (host Host) createAuthMethods() ([]ssh.AuthMethod, error) {
var methods []ssh.AuthMethod

// override local ssh agent with given sshAgent instance
if host.OverrideSSHAgent != nil {
conn, err := net.Dial("unix", host.OverrideSSHAgent.socketFile)
if err != nil {
fmt.Print("Failed to dial in memory ssh agent")
return methods, err
}
agentClient := agent.NewClient(conn)
methods = append(methods, []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}...)
}

// use existing ssh agent socket
// if agent authentication is enabled and no agent is set up, returns an error
if host.SSHAgent {
socket := os.Getenv("SSH_AUTH_SOCK")
conn, err := net.Dial("unix", socket)
if err != nil {
return methods, err
}
agentClient := agent.NewClient(conn)
methods = append(methods, []ssh.AuthMethod{ssh.PublicKeysCallback(agentClient.Signers)}...)
}

// use provided ssh key pair
if host.PrivateKey != "" {
signer, err := ssh.ParsePrivateKey([]byte(host.PrivateKey))
if err != nil {
return methods, err
}
methods = append(methods, []ssh.AuthMethod{ssh.PublicKeys(signer)}...)
}

// Use given password
if len(host.Password) > 0 {
methods = append(methods, []ssh.AuthMethod{ssh.Password(host.Password)}...)
}

// no valid authentication method was provided
if len(methods) < 1 {
return methods, errors.New("no authentication method defined")
}

return methods, nil
}

// sshConnectionOptions are the options for an SSH connection.
type sshConnectionOptions struct {
Username string
Address string
Port int
AuthMethods []ssh.AuthMethod
HostKeyCallback ssh.HostKeyCallback
Command string
JumpHostOptions *sshConnectionOptions
}

// ConnectionString returns the connection string for an SSH connection.
func (options *sshConnectionOptions) ConnectionString() string {
return net.JoinHostPort(options.Address, strconv.Itoa(options.Port))
}

// sshCloseStack is a LIFO (stack) data structure for tracking all the resources that need to be closed at the end of an
// SSH connection. This is useful for having a single defer call in a top-level method to clean up resources that are
// recursively created across multiple jump hosts.
type sshCloseStack struct {
stack []Closeable
}

// Push will push an item on the close stack by prepending the item to the top of the array.
func (this *sshCloseStack) Push(item Closeable) {
this.stack = append([]Closeable{item}, this.stack...)
}

// CloseAll iterates over all the closeable items and closes the connection one by one. This will attempt to close
// everything in the stack regardless of errors, and return a single multierror at the end that aggregates all
// encountered errors.
func (this *sshCloseStack) CloseAll() error {
allErrs := &multierror.Error{}
for _, closeable := range this.stack {
// Closing a connection may result in an EOF error if it's already closed (e.g. due to hitting CTRL + D), so
// don't report those errors, as there is nothing actually wrong in that case.
allErrs = multierror.Append(allErrs, Close(closeable, io.EOF.Error()))
}
return allErrs.ErrorOrNil()
}

// Closeable can be closed.
type Closeable interface {
Close() error
}

// Close closes a Closeable.
func Close(closeable Closeable, ignoreErrors ...string) error {
if interfaceIsNil(closeable) {
return nil
}

if err := closeable.Close(); err != nil && !collections.ListContainsElement(ignoreErrors, err.Error()) {
return err
}
return nil
}

// Checking an interface directly against nil does not work, and if you don't know the exact types the interface may be
// ahead of time, the only way to know if you're dealing with nil is to use reflection.
// http://stackoverflow.com/questions/13476349/check-for-nil-and-nil-interface-in-go
func interfaceIsNil(i interface{}) bool {
return i == nil || reflect.ValueOf(i).IsNil()
}
Loading

0 comments on commit 177b60e

Please sign in to comment.