diff --git a/cmd/sync.go b/cmd/sync.go index 1d37b69..8916d2a 100644 --- a/cmd/sync.go +++ b/cmd/sync.go @@ -25,6 +25,7 @@ var SSHHost string var SSHPort string var SSHKey string var SSHVerbose bool +var SSHSkipAgent bool var CmdSSHKey string var noCliInteraction bool var dryRun bool @@ -154,12 +155,14 @@ func syncCommandRun(cmd *cobra.Command, args []string) { if sshConfig.Verbose && !sshVerbose { sshVerbose = sshConfig.Verbose } + sshOptions := synchers.SSHOptions{ Host: sshHost, PrivateKey: sshKey, Port: sshPort, Verbose: sshVerbose, RsyncArgs: RsyncArguments, + SkipAgent: SSHSkipAgent, } // let's update the named transfer resource if it is set @@ -224,6 +227,7 @@ func init() { syncCmd.PersistentFlags().StringVarP(&SSHHost, "ssh-host", "H", "ssh.lagoon.amazeeio.cloud", "Specify your lagoon ssh host, defaults to 'ssh.lagoon.amazeeio.cloud'") syncCmd.PersistentFlags().StringVarP(&SSHPort, "ssh-port", "P", "32222", "Specify your ssh port, defaults to '32222'") syncCmd.PersistentFlags().StringVarP(&SSHKey, "ssh-key", "i", "", "Specify path to a specific SSH key to use for authentication") + syncCmd.PersistentFlags().BoolVar(&SSHSkipAgent, "ssh-skip-agent", false, "Do not attempt to use an ssh-agent for key management") syncCmd.PersistentFlags().BoolVar(&SSHVerbose, "verbose", false, "Run ssh commands in verbose (useful for debugging)") syncCmd.PersistentFlags().BoolVar(&noCliInteraction, "no-interaction", false, "Disallow interaction") syncCmd.PersistentFlags().BoolVar(&dryRun, "dry-run", false, "Don't run the commands, just preview what will be run") diff --git a/synchers/prerequisiteSyncUtils.go b/synchers/prerequisiteSyncUtils.go index 742c1ec..f37eeca 100644 --- a/synchers/prerequisiteSyncUtils.go +++ b/synchers/prerequisiteSyncUtils.go @@ -26,29 +26,38 @@ func RunPrerequisiteCommand(environment Environment, syncer Syncer, syncerType s var execString string var configRespSuccessful bool - command, commandErr := syncer.GetPrerequisiteCommand(environment, "config").GetCommand() + execString, commandErr := syncer.GetPrerequisiteCommand(environment, "config").GetCommand() if commandErr != nil { return environment, commandErr } - if environment.EnvironmentName == LOCAL_ENVIRONMENT_NAME { - execString = command - } else { - execString = GenerateRemoteCommand(environment, command, sshOptions) - } - utils.LogExecutionStep("Running the following prerequisite command", execString) - err, configResponseJson, errstring := utils.Shellout(execString) - if err != nil { - fmt.Println(errstring) + var output string + + if environment.EnvironmentName == LOCAL_ENVIRONMENT_NAME { + err, response, errstring := utils.Shellout(execString) + if err != nil { + log.Printf(errstring) + return environment, err + } + if response != "" && debug == false { + log.Println(response) + } + } else { + err, output := utils.RemoteShellout(execString, environment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent) + utils.LogDebugInfo(output, nil) + if err != nil { + utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil) + return environment, err + } } data := &prerequisite.PreRequisiteResponse{} - json.Unmarshal([]byte(configResponseJson), &data) + json.Unmarshal([]byte(output), &data) if !data.IsPrerequisiteResponseEmpty() { - utils.LogDebugInfo("'lagoon-sync config' response", configResponseJson) + utils.LogDebugInfo("'lagoon-sync config' response", output) configRespSuccessful = true } else { utils.LogWarning("'lagoon-sync' is not available on", environment.EnvironmentName) @@ -84,7 +93,6 @@ func RunPrerequisiteCommand(environment Environment, syncer Syncer, syncerType s // Add rsync to env rsyncPath, err := createRsync(environment, syncer, lagoonSyncVersion, sshOptions) if err != nil { - fmt.Println(errstring) return environment, err } diff --git a/synchers/syncdefs.go b/synchers/syncdefs.go index 9aa8eca..1de7e41 100644 --- a/synchers/syncdefs.go +++ b/synchers/syncdefs.go @@ -69,6 +69,7 @@ type SSHOptions struct { Port string `yaml:"port,omitempty" json:"port,omitempty"` Verbose bool `yaml:"verbose,omitempty" json:"verbose,omitempty"` PrivateKey string `yaml:"privateKey,omitempty" json:"privateKey,omitempty"` + SkipAgent bool RsyncArgs string `yaml:"rsyncArgs,omitempty" json:"rsyncArgs,omitempty"` } diff --git a/synchers/syncutils.go b/synchers/syncutils.go index 426af74..a7b34a4 100644 --- a/synchers/syncutils.go +++ b/synchers/syncutils.go @@ -51,7 +51,8 @@ func RunSyncProcess(args RunSyncProcessFunctionTypeArguments) error { //TODO: this can come out. args.SourceEnvironment, err = RunPrerequisiteCommand(args.SourceEnvironment, args.LagoonSyncer, args.SyncerType, args.DryRun, args.SshOptions) - sourceRsyncPath := args.SourceEnvironment.RsyncPath + sourceRsyncPath := "rsync" //args.SourceEnvironment.RsyncPath + args.SourceEnvironment.RsyncPath = "rsync" if err != nil { _ = PrerequisiteCleanUp(args.SourceEnvironment, sourceRsyncPath, args.DryRun, args.SshOptions) return err @@ -137,7 +138,7 @@ func SyncRunSourceCommand(remoteEnvironment Environment, syncer Syncer, dryRun b log.Println(response) } } else { - err, output := utils.RemoteShellout(execString, remoteEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey) + err, output := utils.RemoteShellout(execString, remoteEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent) utils.LogDebugInfo(output, nil) if err != nil { utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil) @@ -237,7 +238,7 @@ func SyncRunTransfer(sourceEnvironment Environment, targetEnvironment Environmen if !dryRun { if executeRsyncRemotelyOnTarget { - err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey) + err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent) utils.LogDebugInfo(output, nil) if err != nil { utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil) @@ -282,7 +283,7 @@ func SyncRunTargetCommand(targetEnvironment Environment, syncer Syncer, dryRun b return err } } else { - err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey) + err, output := utils.RemoteShellout(execString, targetEnvironment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent) utils.LogDebugInfo(output, nil) if err != nil { utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil) @@ -312,7 +313,7 @@ func SyncCleanUp(environment Environment, syncer Syncer, dryRun bool, sshOptions utils.LogExecutionStep("Running the following", execString) if !dryRun { if environment.EnvironmentName != LOCAL_ENVIRONMENT_NAME { - err, output := utils.RemoteShellout(execString, environment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey) + err, output := utils.RemoteShellout(execString, environment.GetOpenshiftProjectName(), sshOptions.Host, sshOptions.Port, sshOptions.PrivateKey, sshOptions.SkipAgent) utils.LogDebugInfo(output, nil) if err != nil { utils.LogFatalError("Unable to exec remote command: "+err.Error(), nil) diff --git a/utils/shell.go b/utils/shell.go index a8691bc..da2fb7d 100644 --- a/utils/shell.go +++ b/utils/shell.go @@ -2,15 +2,20 @@ package utils import ( "bytes" + "errors" + "fmt" "golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh/agent" "net" "os" "os/exec" + "path/filepath" ) const ShellToUse = "sh" +var validAuthMethod *ssh.AuthMethod + func Shellout(command string) (error, string, string) { var stdout bytes.Buffer var stderr bytes.Buffer @@ -21,42 +26,83 @@ func Shellout(command string) (error, string, string) { return err, stdout.String(), stderr.String() } -func RemoteShellout(command string, remoteUser string, remoteHost string, remotePort string, privateKeyfile string) (error, string) { - // Read the private key file +func getAuthMethodFromPrivateKey(filename string) (ssh.AuthMethod, error) { + privateKeyBytes, err := os.ReadFile(filename) + + if err != nil { + return nil, err + } + + if len(privateKeyBytes) > 0 { + // Parse the private key + signer, err := ssh.ParsePrivateKey(privateKeyBytes) + if err != nil { + return nil, err + } + + // SSH client configuration + authKeys := ssh.PublicKeys(signer) + return authKeys, nil - skipAgent := false + } + return nil, errors.New(fmt.Sprint("No data in privateKey: ", filename)) +} +func getSSHAuthMethodsFromDirectory(directory string) ([]ssh.AuthMethod, error) { var authMethods []ssh.AuthMethod + err := filepath.Walk(directory, func(path string, info os.FileInfo, err error) error { + if err != nil { + return err + } - if skipAgent != true { - // Connect to SSH agent to ask for unencrypted private keys - if sshAgentConn, err := net.Dial("unix", os.Getenv("SSH_AUTH_SOCK")); err == nil { - sshAgent := agent.NewClient(sshAgentConn) - keys, _ := sshAgent.List() - if len(keys) > 0 { - agentAuthmethods := ssh.PublicKeysCallback(sshAgent.Signers) - authMethods = append(authMethods, agentAuthmethods) + if !info.IsDir() && filepath.Ext(path) != ".pub" { + + // let's test this is a valid ssh key + am, err := getAuthMethodFromPrivateKey(path) + if err != nil { + switch { + case isPassphraseMissingError(err): + LogDebugInfo(fmt.Sprintf("Found a passphrase based ssh key: %v", err.Error()), os.Stdout) + default: + LogWarning(err.Error(), os.Stdout) + } + } else { + LogDebugInfo(fmt.Sprintf("Found a valid key at %v - will try auth", path), os.Stdout) + authMethods = append(authMethods, am) } } + return nil + }) + if err != nil { + return nil, err } + return authMethods, nil +} - privateKeyBytes, err := os.ReadFile(privateKeyfile) +func isPassphraseMissingError(err error) bool { + _, ok := err.(*ssh.PassphraseMissingError) + return ok +} - // if there are authMethods already, let's keep going - if err != nil && len(authMethods) == 0 { - return err, "" - } +func RemoteShellout(command string, remoteUser string, remoteHost string, remotePort string, privateKeyfile string, skipSshAgent bool) (error, string) { - if len(privateKeyBytes) > 0 { - // Parse the private key - signer, err := ssh.ParsePrivateKey(privateKeyBytes) - if err != nil { - return err, "" + sshAuthSock, present := os.LookupEnv("SSH_AUTH_SOCK") + skipAgent := !present || skipSshAgent + + var authMethods []ssh.AuthMethod + + if validAuthMethod == nil { // This makes it so that in subsequent calls, we don't have to recheck all auth methods + LogDebugInfo("First time running, no cached valid auth methods", os.Stdout) + authMethods = getAuthmethods(skipAgent, privateKeyfile, sshAuthSock, authMethods) + } else { + LogDebugInfo("Found existing auth method", os.Stdout) + authMethods = []ssh.AuthMethod{ + *validAuthMethod, } + } - // SSH client configuration - authKeys := ssh.PublicKeys(signer) - authMethods = append(authMethods, authKeys) + if len(authMethods) == 0 && validAuthMethod == nil { + return errors.New("No valid authentication methods provided"), "" } config := &ssh.ClientConfig{ @@ -65,10 +111,30 @@ func RemoteShellout(command string, remoteUser string, remoteHost string, remote HostKeyCallback: ssh.InsecureIgnoreHostKey(), } - // Connect to the remote server - client, err := ssh.Dial("tcp", remoteHost+":"+remotePort, config) - if err != nil { - return err, "" + var client *ssh.Client + var err error + + //we need to iterate over the auth methods till we find one that works + // for subsequent runs, this will only run once, since only whatever is in + // validAuthMethod will be attemtped. + for _, am := range authMethods { + config.Auth = []ssh.AuthMethod{ + am, + } + client, err = ssh.Dial("tcp", remoteHost+":"+remotePort, config) + if err != nil { + continue + } + + if validAuthMethod == nil { + LogDebugInfo("Dial success - caching auth method for subsequent runs", os.Stdout) + validAuthMethod = &am // set the valid auth method so that future calls won't need to retry + } + break + } + + if validAuthMethod == nil { + return errors.New("unable to find valid auth method for ssh"), "" } defer client.Close() @@ -103,3 +169,50 @@ func RemoteShellout(command string, remoteUser string, remoteHost string, remote return nil, outputBuffer.String() } + +func getAuthmethods(skipAgent bool, privateKeyfile string, sshAuthSock string, authMethods []ssh.AuthMethod) []ssh.AuthMethod { + if skipAgent != true && privateKeyfile == "" { + // Connect to SSH agent to ask for unencrypted private keys + if sshAgentConn, err := net.Dial("unix", sshAuthSock); err == nil { + sshAgent := agent.NewClient(sshAgentConn) + keys, _ := sshAgent.List() + if len(keys) > 0 { + agentAuthmethods := ssh.PublicKeysCallback(sshAgent.Signers) + authMethods = append(authMethods, agentAuthmethods) + } + } + } else { + LogDebugInfo("Skipping ssh agent", os.Stdout) + } + + if privateKeyfile == "" { //let's try guess it from the OS + userPath, err := os.UserHomeDir() + if err != nil { + LogWarning("No ssh key given and no home directory available", os.Stdout) + } + + userPath = filepath.Join(userPath, ".ssh") + + if _, err := os.Stat(userPath); err == nil { + sshAm, err := getSSHAuthMethodsFromDirectory(userPath) + if err != nil { + LogWarning(err.Error(), os.Stdout) + } + authMethods = append(authMethods, sshAm...) + } else { + LogWarning("Unable to find .ssh directory in user home", os.Stdout) + } + } else { + privateKeyFiles := []string{ + privateKeyfile, + } + + for _, kf := range privateKeyFiles { + am, err := getAuthMethodFromPrivateKey(kf) + if err == nil { + authMethods = append(authMethods, am) + } + } + } + return authMethods +} diff --git a/utils/shell_test.go b/utils/shell_test.go new file mode 100644 index 0000000..21fa318 --- /dev/null +++ b/utils/shell_test.go @@ -0,0 +1,89 @@ +package utils + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "fmt" + "os" + "path/filepath" + "testing" + "time" +) + +/** +* generatePrivateKey is used to generate a random private key - we're using this in our tests + */ +func generatePrivateKey(outputDir string) (string, error) { + // Generate a new private key + privateKey, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return "", err + } + + // Encode private key to PEM format + privateKeyBytes := x509.MarshalPKCS1PrivateKey(privateKey) + privateKeyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: privateKeyBytes}) + + // Generate a random filename + randomFilename := fmt.Sprintf("private_key_%d.pem", time.Now().UnixNano()) + + // Save private key to a file in the specified directory with the random filename + privateKeyPath := filepath.Join(outputDir, randomFilename) + err = os.WriteFile(privateKeyPath, privateKeyPEM, 0600) + if err != nil { + return "", err + } + + return privateKeyPath, nil +} + +const test_findSSHKeyFilesNumber = 3 + +func Test_findSSHKeyFiles(t *testing.T) { + tests := []struct { + name string + want int + wantErr bool + }{ + { + name: "Run on test directory", + want: test_findSSHKeyFilesNumber, + wantErr: false, + }, + } + + // Let's generate the files + tmpDir, err := os.MkdirTemp("", "keypair_test") + if err != nil { + t.Fatal("Unable to create temporary directory: ", err.Error()) + } + + defer func() { + os.RemoveAll(tmpDir) + }() + + privateKeys := []string{} + + for i := 0; i < test_findSSHKeyFilesNumber; i++ { + key, err := generatePrivateKey(tmpDir) + if err != nil { + t.Fatal("Unable to create private key: ", err.Error()) + } + privateKeys = append(privateKeys, key) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := getSSHAuthMethodsFromDirectory(tmpDir) + if (err != nil) != tt.wantErr { + t.Errorf("getSSHAuthMethodsFromDirectory() error = %v, wantErr %v", err, tt.wantErr) + return + } + if len(got) != tt.want { + t.Errorf("getSSHAuthMethodsFromDirectory() got = %v, want %v", got, tt.want) + } + }) + } +}