diff --git a/internal/client/handlers/jumphost.go b/internal/client/handlers/jumphost.go index a88e98c..7887783 100644 --- a/internal/client/handlers/jumphost.go +++ b/internal/client/handlers/jumphost.go @@ -52,7 +52,7 @@ func JumpHandler(sshPriv ssh.Signer, serverConn ssh.Conn) func(newChannel ssh.Ne clientLog := logger.NewLog(serverConn.RemoteAddr().String()) clientLog.Info("New SSH connection, version %s", conn.ClientVersion()) - session := connection.NewSession(conn) + session := connection.NewSession(serverConn) go func(in <-chan *ssh.Request) { for r := range in { diff --git a/internal/client/handlers/session.go b/internal/client/handlers/session.go index 86c36a8..7a51d9f 100644 --- a/internal/client/handlers/session.go +++ b/internal/client/handlers/session.go @@ -87,7 +87,8 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo return } - if u, ok := isUrl(command); ok { + u, ok := isUrl(command) + if ok { command, err = download(session.ServerConnection, u) if err != nil { fmt.Fprintf(connection, "%s", err.Error()) @@ -96,10 +97,10 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo } if session.Pty != nil { - runCommandWithPty(command, line.Chunks[1:], session.Pty, requests, log, connection) + runCommandWithPty(u.Query().Get("argv"), command, line.Chunks[1:], session.Pty, requests, log, connection) return } - runCommand(command, line.Chunks[1:], connection) + runCommand(u.Query().Get("argv"), command, line.Chunks[1:], connection) return case "shell": @@ -118,7 +119,8 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo parts := strings.Split(shellPath.Cmd, " ") if len(parts) > 0 { command := parts[0] - if u, ok := isUrl(parts[0]); ok { + u, ok := isUrl(parts[0]) + if ok { command, err = download(session.ServerConnection, u) if err != nil { fmt.Fprintf(connection, "%s", err.Error()) @@ -126,7 +128,7 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo } } - runCommandWithPty(command, parts[1:], session.Pty, requests, log, connection) + runCommandWithPty(u.Query().Get("argv"), command, parts[1:], session.Pty, requests, log, connection) } return //Yes, this is here for a reason future me. Despite the RFC saying "Only one of shell,subsystem, exec can occur per channel" pty-req actually proceeds all of them @@ -152,7 +154,7 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo } } -func runCommand(command string, args []string, connection ssh.Channel) { +func runCommand(argv string, command string, args []string, connection ssh.Channel) { //Set a path if no path is set to search if len(os.Getenv("PATH")) == 0 { if runtime.GOOS != "windows" { @@ -163,6 +165,9 @@ func runCommand(command string, args []string, connection ssh.Channel) { } cmd := exec.Command(command, args...) + if len(argv) != 0 { + cmd.Args[0] = argv + } stdout, err := cmd.StdoutPipe() if err != nil { @@ -204,15 +209,26 @@ func isUrl(data string) (*url.URL, bool) { } func download(serverConnection ssh.Conn, fromUrl *url.URL) (result string, err error) { + if fromUrl == nil { + return "", errors.New("url was nil") + } var ( reader io.ReadCloser filename string ) - switch fromUrl.Scheme { + urlCopy := *fromUrl + + query := urlCopy.Query() + query.Del("argv") + + urlCopy.RawQuery = query.Encode() + + switch urlCopy.Scheme { case "http", "https": - resp, err := http.Get(fromUrl.String()) + + resp, err := http.Get(urlCopy.String()) if err != nil { return "", err } @@ -220,7 +236,7 @@ func download(serverConnection ssh.Conn, fromUrl *url.URL) (result string, err e reader = resp.Body - filename = path.Base(fromUrl.Path) + filename = path.Base(urlCopy.Path) if filename == "." { filename, err = internal.RandomString(16) if err != nil { @@ -229,7 +245,7 @@ func download(serverConnection ssh.Conn, fromUrl *url.URL) (result string, err e } case "rssh": - filename = path.Base(strings.TrimSuffix(fromUrl.String(), "rssh://")) + filename = path.Base(strings.TrimSuffix(urlCopy.String(), "rssh://")) ch, reqs, err := serverConnection.OpenChannel("rssh-download", []byte(filename)) if err != nil { diff --git a/internal/client/handlers/shell.go b/internal/client/handlers/shell.go index 976574a..c2f3463 100644 --- a/internal/client/handlers/shell.go +++ b/internal/client/handlers/shell.go @@ -124,7 +124,7 @@ func init() { } -func runCommandWithPty(command string, args []string, ptyReq *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) { +func runCommandWithPty(argv string, command string, args []string, ptyReq *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) { if ptyReq == nil { log.Error("Requested to run a command with a pty, but did not start a pty") @@ -133,6 +133,10 @@ func runCommandWithPty(command string, args []string, ptyReq *internal.PtyReq, r // Fire up a shell for this session shell := exec.Command(command, args...) + if len(argv) != 0 { + shell.Args[0] = argv + } + shell.Env = os.Environ() close := func() { @@ -208,10 +212,10 @@ func shell(ptyReq *internal.PtyReq, connection ssh.Channel, requests <-chan *ssh } if ptyReq != nil { - runCommandWithPty(path, nil, ptyReq, requests, log, connection) + runCommandWithPty("", path, nil, ptyReq, requests, log, connection) return } - runCommand(path, nil, connection) + runCommand("", path, nil, connection) } diff --git a/internal/client/handlers/shell_windows.go b/internal/client/handlers/shell_windows.go index 3dbe98e..b88e854 100644 --- a/internal/client/handlers/shell_windows.go +++ b/internal/client/handlers/shell_windows.go @@ -35,13 +35,13 @@ func shell(ptyReq *internal.PtyReq, connection ssh.Channel, requests <-chan *ssh } } - runCommandWithPty(path, nil, ptyReq, requests, log, connection) + runCommandWithPty("", path, nil, ptyReq, requests, log, connection) connection.Close() } -func runCommandWithPty(command string, args []string, pty *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) { +func runCommandWithPty(argv, command string, args []string, pty *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) { fullCommand := command + " " + strings.Join(args, " ") vsn := windows.RtlGetVersion() @@ -51,7 +51,7 @@ func runCommandWithPty(command string, args []string, pty *internal.PtyReq, requ runWithWinPty(fullCommand, connection, requests, log, pty) } else { - err := runWithConpty(fullCommand, connection, requests, log, pty) + err := runWithConpty(argv, fullCommand, connection, requests, log, pty) if err != nil { log.Error("unable to run with conpty, falling back to winpty: %v", err) runWithWinPty(fullCommand, connection, requests, log, pty) @@ -106,7 +106,7 @@ func runWithWinPty(command string, connection ssh.Channel, reqs <-chan *ssh.Requ return nil } -func runWithConpty(command string, connection ssh.Channel, reqs <-chan *ssh.Request, log logger.Logger, ptyReq *internal.PtyReq) error { +func runWithConpty(argv, command string, connection ssh.Channel, reqs <-chan *ssh.Request, log logger.Logger, ptyReq *internal.PtyReq) error { cpty, err := conpty.New(int16(ptyReq.Columns), int16(ptyReq.Rows)) if err != nil { @@ -118,10 +118,15 @@ func runWithConpty(command string, connection ssh.Channel, reqs <-chan *ssh.Requ return err } + argvParts := []string{} + if len(argv) != 0 { + argvParts = []string{argv} + } + // Spawn and catch new powershell process pid, _, err := cpty.Spawn( path, - []string{}, + argvParts, &syscall.ProcAttr{ Env: os.Environ(), },