Skip to content

Commit

Permalink
Adds default fallback ssh key and skip agent options (#110)
Browse files Browse the repository at this point in the history
* Adds default fallback ssh key and skip agent options

* Better auth method cycling

* Updates logic for auth items

* Update shell test to gen private keys
  • Loading branch information
bomoko authored Dec 6, 2023
1 parent 986ac0d commit 6815673
Show file tree
Hide file tree
Showing 6 changed files with 262 additions and 46 deletions.
4 changes: 4 additions & 0 deletions cmd/sync.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ var SSHHost string
var SSHPort string
var SSHKey string
var SSHVerbose bool
var SSHSkipAgent bool
var CmdSSHKey string
var noCliInteraction bool
var dryRun bool
Expand Down Expand Up @@ -154,12 +155,14 @@ func syncCommandRun(cmd *cobra.Command, args []string) {
if sshConfig.Verbose && !sshVerbose {
sshVerbose = sshConfig.Verbose
}

sshOptions := synchers.SSHOptions{
Host: sshHost,
PrivateKey: sshKey,
Port: sshPort,
Verbose: sshVerbose,
RsyncArgs: RsyncArguments,
SkipAgent: SSHSkipAgent,
}

// let's update the named transfer resource if it is set
Expand Down Expand Up @@ -224,6 +227,7 @@ func init() {
syncCmd.PersistentFlags().StringVarP(&SSHHost, "ssh-host", "H", "ssh.lagoon.amazeeio.cloud", "Specify your lagoon ssh host, defaults to 'ssh.lagoon.amazeeio.cloud'")
syncCmd.PersistentFlags().StringVarP(&SSHPort, "ssh-port", "P", "32222", "Specify your ssh port, defaults to '32222'")
syncCmd.PersistentFlags().StringVarP(&SSHKey, "ssh-key", "i", "", "Specify path to a specific SSH key to use for authentication")
syncCmd.PersistentFlags().BoolVar(&SSHSkipAgent, "ssh-skip-agent", false, "Do not attempt to use an ssh-agent for key management")
syncCmd.PersistentFlags().BoolVar(&SSHVerbose, "verbose", false, "Run ssh commands in verbose (useful for debugging)")
syncCmd.PersistentFlags().BoolVar(&noCliInteraction, "no-interaction", false, "Disallow interaction")
syncCmd.PersistentFlags().BoolVar(&dryRun, "dry-run", false, "Don't run the commands, just preview what will be run")
Expand Down
34 changes: 21 additions & 13 deletions synchers/prerequisiteSyncUtils.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,29 +26,38 @@ func RunPrerequisiteCommand(environment Environment, syncer Syncer, syncerType s
var execString string
var configRespSuccessful bool

command, commandErr := syncer.GetPrerequisiteCommand(environment, "config").GetCommand()
execString, commandErr := syncer.GetPrerequisiteCommand(environment, "config").GetCommand()
if commandErr != nil {
return environment, commandErr
}

if environment.EnvironmentName == LOCAL_ENVIRONMENT_NAME {
execString = command
} else {
execString = GenerateRemoteCommand(environment, command, sshOptions)
}

utils.LogExecutionStep("Running the following prerequisite command", execString)

err, configResponseJson, errstring := utils.Shellout(execString)
if err != nil {
fmt.Println(errstring)
var output string

if environment.EnvironmentName == LOCAL_ENVIRONMENT_NAME {
err, response, errstring := utils.Shellout(execString)
if err != nil {
log.Printf(errstring)
return environment, err
}
if response != "" && debug == false {
log.Println(response)
}
} else {
err, output := utils.RemoteShellout(execString, environment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent)
utils.LogDebugInfo(output, nil)
if err != nil {
utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil)
return environment, err
}
}

data := &prerequisite.PreRequisiteResponse{}
json.Unmarshal([]byte(configResponseJson), &data)
json.Unmarshal([]byte(output), &data)

if !data.IsPrerequisiteResponseEmpty() {
utils.LogDebugInfo("'lagoon-sync config' response", configResponseJson)
utils.LogDebugInfo("'lagoon-sync config' response", output)
configRespSuccessful = true
} else {
utils.LogWarning("'lagoon-sync' is not available on", environment.EnvironmentName)
Expand Down Expand Up @@ -84,7 +93,6 @@ func RunPrerequisiteCommand(environment Environment, syncer Syncer, syncerType s
// Add rsync to env
rsyncPath, err := createRsync(environment, syncer, lagoonSyncVersion, sshOptions)
if err != nil {
fmt.Println(errstring)
return environment, err
}

Expand Down
1 change: 1 addition & 0 deletions synchers/syncdefs.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ type SSHOptions struct {
Port string `yaml:"port,omitempty" json:"port,omitempty"`
Verbose bool `yaml:"verbose,omitempty" json:"verbose,omitempty"`
PrivateKey string `yaml:"privateKey,omitempty" json:"privateKey,omitempty"`
SkipAgent bool
RsyncArgs string `yaml:"rsyncArgs,omitempty" json:"rsyncArgs,omitempty"`
}

Expand Down
11 changes: 6 additions & 5 deletions synchers/syncutils.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,8 @@ func RunSyncProcess(args RunSyncProcessFunctionTypeArguments) error {

//TODO: this can come out.
args.SourceEnvironment, err = RunPrerequisiteCommand(args.SourceEnvironment, args.LagoonSyncer, args.SyncerType, args.DryRun, args.SshOptions)
sourceRsyncPath := args.SourceEnvironment.RsyncPath
sourceRsyncPath := "rsync" //args.SourceEnvironment.RsyncPath
args.SourceEnvironment.RsyncPath = "rsync"
if err != nil {
_ = PrerequisiteCleanUp(args.SourceEnvironment, sourceRsyncPath, args.DryRun, args.SshOptions)
return err
Expand Down Expand Up @@ -137,7 +138,7 @@ func SyncRunSourceCommand(remoteEnvironment Environment, syncer Syncer, dryRun b
log.Println(response)
}
} else {
err, output := utils.RemoteShellout(execString, remoteEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey)
err, output := utils.RemoteShellout(execString, remoteEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent)
utils.LogDebugInfo(output, nil)
if err != nil {
utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil)
Expand Down Expand Up @@ -237,7 +238,7 @@ func SyncRunTransfer(sourceEnvironment Environment, targetEnvironment Environmen
if !dryRun {

if executeRsyncRemotelyOnTarget {
err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey)
err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent)
utils.LogDebugInfo(output, nil)
if err != nil {
utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil)
Expand Down Expand Up @@ -282,7 +283,7 @@ func SyncRunTargetCommand(targetEnvironment Environment, syncer Syncer, dryRun b
return err
}
} else {
err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey)
err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent)
utils.LogDebugInfo(output, nil)
if err != nil {
utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil)
Expand Down Expand Up @@ -312,7 +313,7 @@ func SyncCleanUp(environment Environment, syncer Syncer, dryRun bool, sshOptions
utils.LogExecutionStep("Running the following", execString)
if !dryRun {
if environment.EnvironmentName != LOCAL_ENVIRONMENT_NAME {
err, output := utils.RemoteShellout(execString, environment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey)
err, output := utils.RemoteShellout(execString, environment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent)
utils.LogDebugInfo(output, nil)
if err != nil {
utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil)
Expand Down
169 changes: 141 additions & 28 deletions utils/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,20 @@ package utils

import (
"bytes"
"errors"
"fmt"
"golang.org/x/crypto/ssh"
"golang.org/x/crypto/ssh/agent"
"net"
"os"
"os/exec"
"path/filepath"
)

const ShellToUse = "sh"

var validAuthMethod *ssh.AuthMethod

func Shellout(command string) (error, string, string) {
var stdout bytes.Buffer
var stderr bytes.Buffer
Expand All @@ -21,42 +26,83 @@ func Shellout(command string) (error, string, string) {
return err, stdout.String(), stderr.String()
}

func RemoteShellout(command string, remoteUser string, remoteHost string, remotePort string, privateKeyfile string) (error, string) {
// Read the private key file
func getAuthMethodFromPrivateKey(filename string) (ssh.AuthMethod, error) {
privateKeyBytes, err := os.ReadFile(filename)

if err != nil {
return nil, err
}

if len(privateKeyBytes) > 0 {
// Parse the private key
signer, err := ssh.ParsePrivateKey(privateKeyBytes)
if err != nil {
return nil, err
}

// SSH client configuration
authKeys := ssh.PublicKeys(signer)
return authKeys, nil

skipAgent := false
}
return nil, errors.New(fmt.Sprint("No data in privateKey: ", filename))
}

func getSSHAuthMethodsFromDirectory(directory string) ([]ssh.AuthMethod, error) {
var authMethods []ssh.AuthMethod
err := filepath.Walk(directory, func(path string, info os.FileInfo, err error) error {
if err != nil {
return err
}

if skipAgent != true {
// Connect to SSH agent to ask for unencrypted private keys
if sshAgentConn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil {
sshAgent := agent.NewClient(sshAgentConn)
keys, _ := sshAgent.List()
if len(keys) > 0 {
agentAuthmethods := ssh.PublicKeysCallback(sshAgent.Signers)
authMethods = append(authMethods, agentAuthmethods)
if !info.IsDir() && filepath.Ext(path) != ".pub" {

// let's test this is a valid ssh key
am, err := getAuthMethodFromPrivateKey(path)
if err != nil {
switch {
case isPassphraseMissingError(err):
LogDebugInfo(fmt.Sprintf("Found a passphrase based ssh key: %v", err.Error()), os.Stdout)
default:
LogWarning(err.Error(), os.Stdout)
}
} else {
LogDebugInfo(fmt.Sprintf("Found a valid key at %v - will try auth", path), os.Stdout)
authMethods = append(authMethods, am)
}
}
return nil
})
if err != nil {
return nil, err
}
return authMethods, nil
}

privateKeyBytes, err := os.ReadFile(privateKeyfile)
func isPassphraseMissingError(err error) bool {
_, ok := err.(*ssh.PassphraseMissingError)
return ok
}

// if there are authMethods already, let's keep going
if err != nil && len(authMethods) == 0 {
return err, ""
}
func RemoteShellout(command string, remoteUser string, remoteHost string, remotePort string, privateKeyfile string, skipSshAgent bool) (error, string) {

if len(privateKeyBytes) > 0 {
// Parse the private key
signer, err := ssh.ParsePrivateKey(privateKeyBytes)
if err != nil {
return err, ""
sshAuthSock, present := os.LookupEnv("SSH_AUTH_SOCK")
skipAgent := !present || skipSshAgent

var authMethods []ssh.AuthMethod

if validAuthMethod == nil { // This makes it so that in subsequent calls, we don't have to recheck all auth methods
LogDebugInfo("First time running, no cached valid auth methods", os.Stdout)
authMethods = getAuthmethods(skipAgent, privateKeyfile, sshAuthSock, authMethods)
} else {
LogDebugInfo("Found existing auth method", os.Stdout)
authMethods = []ssh.AuthMethod{
*validAuthMethod,
}
}

// SSH client configuration
authKeys := ssh.PublicKeys(signer)
authMethods = append(authMethods, authKeys)
if len(authMethods) == 0 && validAuthMethod == nil {
return errors.New("No valid authentication methods provided"), ""
}

config := &ssh.ClientConfig{
Expand All @@ -65,10 +111,30 @@ func RemoteShellout(command string, remoteUser string, remoteHost string, remote
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
}

// Connect to the remote server
client, err := ssh.Dial("tcp", remoteHost+":"+remotePort, config)
if err != nil {
return err, ""
var client *ssh.Client
var err error

//we need to iterate over the auth methods till we find one that works
// for subsequent runs, this will only run once, since only whatever is in
// validAuthMethod will be attemtped.
for _, am := range authMethods {
config.Auth = []ssh.AuthMethod{
am,
}
client, err = ssh.Dial("tcp", remoteHost+":"+remotePort, config)
if err != nil {
continue
}

if validAuthMethod == nil {
LogDebugInfo("Dial success - caching auth method for subsequent runs", os.Stdout)
validAuthMethod = &am // set the valid auth method so that future calls won't need to retry
}
break
}

if validAuthMethod == nil {
return errors.New("unable to find valid auth method for ssh"), ""
}

defer client.Close()
Expand Down Expand Up @@ -103,3 +169,50 @@ func RemoteShellout(command string, remoteUser string, remoteHost string, remote

return nil, outputBuffer.String()
}

func getAuthmethods(skipAgent bool, privateKeyfile string, sshAuthSock string, authMethods []ssh.AuthMethod) []ssh.AuthMethod {
if skipAgent != true && privateKeyfile == "" {
// Connect to SSH agent to ask for unencrypted private keys
if sshAgentConn, err := net.Dial("unix", sshAuthSock); err == nil {
sshAgent := agent.NewClient(sshAgentConn)
keys, _ := sshAgent.List()
if len(keys) > 0 {
agentAuthmethods := ssh.PublicKeysCallback(sshAgent.Signers)
authMethods = append(authMethods, agentAuthmethods)
}
}
} else {
LogDebugInfo("Skipping ssh agent", os.Stdout)
}

if privateKeyfile == "" { //let's try guess it from the OS
userPath, err := os.UserHomeDir()
if err != nil {
LogWarning("No ssh key given and no home directory available", os.Stdout)
}

userPath = filepath.Join(userPath, ".ssh")

if _, err := os.Stat(userPath); err == nil {
sshAm, err := getSSHAuthMethodsFromDirectory(userPath)
if err != nil {
LogWarning(err.Error(), os.Stdout)
}
authMethods = append(authMethods, sshAm...)
} else {
LogWarning("Unable to find .ssh directory in user home", os.Stdout)
}
} else {
privateKeyFiles := []string{
privateKeyfile,
}

for _, kf := range privateKeyFiles {
am, err := getAuthMethodFromPrivateKey(kf)
if err == nil {
authMethods = append(authMethods, am)
}
}
}
return authMethods
}
Loading

0 comments on commit 6815673

Please sign in to comment.