Skip to content

Commit

Permalink
fix #175, add argv[0] setting with rssh://test?argv=name parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
NHAS committed Oct 31, 2024
1 parent 3b9134f commit 6370728
Show file tree
Hide file tree
Showing 4 changed files with 44 additions and 19 deletions.
2 changes: 1 addition & 1 deletion internal/client/handlers/jumphost.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ func JumpHandler(sshPriv ssh.Signer, serverConn ssh.Conn) func(newChannel ssh.Ne
clientLog := logger.NewLog(serverConn.RemoteAddr().String())
clientLog.Info("New SSH connection, version %s", conn.ClientVersion())

session := connection.NewSession(conn)
session := connection.NewSession(serverConn)

go func(in <-chan *ssh.Request) {
for r := range in {
Expand Down
36 changes: 26 additions & 10 deletions internal/client/handlers/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo
return
}

if u, ok := isUrl(command); ok {
u, ok := isUrl(command)
if ok {
command, err = download(session.ServerConnection, u)
if err != nil {
fmt.Fprintf(connection, "%s", err.Error())
Expand All @@ -96,10 +97,10 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo
}

if session.Pty != nil {
runCommandWithPty(command, line.Chunks[1:], session.Pty, requests, log, connection)
runCommandWithPty(u.Query().Get("argv"), command, line.Chunks[1:], session.Pty, requests, log, connection)
return
}
runCommand(command, line.Chunks[1:], connection)
runCommand(u.Query().Get("argv"), command, line.Chunks[1:], connection)

return
case "shell":
Expand All @@ -118,15 +119,16 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo
parts := strings.Split(shellPath.Cmd, " ")
if len(parts) > 0 {
command := parts[0]
if u, ok := isUrl(parts[0]); ok {
u, ok := isUrl(parts[0])
if ok {
command, err = download(session.ServerConnection, u)
if err != nil {
fmt.Fprintf(connection, "%s", err.Error())
return
}
}

runCommandWithPty(command, parts[1:], session.Pty, requests, log, connection)
runCommandWithPty(u.Query().Get("argv"), command, parts[1:], session.Pty, requests, log, connection)
}
return
//Yes, this is here for a reason future me. Despite the RFC saying "Only one of shell,subsystem, exec can occur per channel" pty-req actually proceeds all of them
Expand All @@ -152,7 +154,7 @@ func Session(session *connection.Session) func(newChannel ssh.NewChannel, log lo
}
}

func runCommand(command string, args []string, connection ssh.Channel) {
func runCommand(argv string, command string, args []string, connection ssh.Channel) {
//Set a path if no path is set to search
if len(os.Getenv("PATH")) == 0 {
if runtime.GOOS != "windows" {
Expand All @@ -163,6 +165,9 @@ func runCommand(command string, args []string, connection ssh.Channel) {
}

cmd := exec.Command(command, args...)
if len(argv) != 0 {
cmd.Args[0] = argv
}

stdout, err := cmd.StdoutPipe()
if err != nil {
Expand Down Expand Up @@ -204,23 +209,34 @@ func isUrl(data string) (*url.URL, bool) {
}

func download(serverConnection ssh.Conn, fromUrl *url.URL) (result string, err error) {
if fromUrl == nil {
return "", errors.New("url was nil")
}

var (
reader io.ReadCloser
filename string
)

switch fromUrl.Scheme {
urlCopy := *fromUrl

query := urlCopy.Query()
query.Del("argv")

urlCopy.RawQuery = query.Encode()

switch urlCopy.Scheme {
case "http", "https":
resp, err := http.Get(fromUrl.String())

resp, err := http.Get(urlCopy.String())
if err != nil {
return "", err
}
defer resp.Body.Close()

reader = resp.Body

filename = path.Base(fromUrl.Path)
filename = path.Base(urlCopy.Path)
if filename == "." {
filename, err = internal.RandomString(16)
if err != nil {
Expand All @@ -229,7 +245,7 @@ func download(serverConnection ssh.Conn, fromUrl *url.URL) (result string, err e
}

case "rssh":
filename = path.Base(strings.TrimSuffix(fromUrl.String(), "rssh://"))
filename = path.Base(strings.TrimSuffix(urlCopy.String(), "rssh://"))

ch, reqs, err := serverConnection.OpenChannel("rssh-download", []byte(filename))
if err != nil {
Expand Down
10 changes: 7 additions & 3 deletions internal/client/handlers/shell.go
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ func init() {

}

func runCommandWithPty(command string, args []string, ptyReq *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) {
func runCommandWithPty(argv string, command string, args []string, ptyReq *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) {

if ptyReq == nil {
log.Error("Requested to run a command with a pty, but did not start a pty")
Expand All @@ -133,6 +133,10 @@ func runCommandWithPty(command string, args []string, ptyReq *internal.PtyReq, r

// Fire up a shell for this session
shell := exec.Command(command, args...)
if len(argv) != 0 {
shell.Args[0] = argv
}

shell.Env = os.Environ()

close := func() {
Expand Down Expand Up @@ -208,10 +212,10 @@ func shell(ptyReq *internal.PtyReq, connection ssh.Channel, requests <-chan *ssh
}

if ptyReq != nil {
runCommandWithPty(path, nil, ptyReq, requests, log, connection)
runCommandWithPty("", path, nil, ptyReq, requests, log, connection)
return
}

runCommand(path, nil, connection)
runCommand("", path, nil, connection)

}
15 changes: 10 additions & 5 deletions internal/client/handlers/shell_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,13 +35,13 @@ func shell(ptyReq *internal.PtyReq, connection ssh.Channel, requests <-chan *ssh
}
}

runCommandWithPty(path, nil, ptyReq, requests, log, connection)
runCommandWithPty("", path, nil, ptyReq, requests, log, connection)

connection.Close()

}

func runCommandWithPty(command string, args []string, pty *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) {
func runCommandWithPty(argv, command string, args []string, pty *internal.PtyReq, requests <-chan *ssh.Request, log logger.Logger, connection ssh.Channel) {

fullCommand := command + " " + strings.Join(args, " ")
vsn := windows.RtlGetVersion()
Expand All @@ -51,7 +51,7 @@ func runCommandWithPty(command string, args []string, pty *internal.PtyReq, requ
runWithWinPty(fullCommand, connection, requests, log, pty)

} else {
err := runWithConpty(fullCommand, connection, requests, log, pty)
err := runWithConpty(argv, fullCommand, connection, requests, log, pty)
if err != nil {
log.Error("unable to run with conpty, falling back to winpty: %v", err)
runWithWinPty(fullCommand, connection, requests, log, pty)
Expand Down Expand Up @@ -106,7 +106,7 @@ func runWithWinPty(command string, connection ssh.Channel, reqs <-chan *ssh.Requ
return nil
}

func runWithConpty(command string, connection ssh.Channel, reqs <-chan *ssh.Request, log logger.Logger, ptyReq *internal.PtyReq) error {
func runWithConpty(argv, command string, connection ssh.Channel, reqs <-chan *ssh.Request, log logger.Logger, ptyReq *internal.PtyReq) error {

cpty, err := conpty.New(int16(ptyReq.Columns), int16(ptyReq.Rows))
if err != nil {
Expand All @@ -118,10 +118,15 @@ func runWithConpty(command string, connection ssh.Channel, reqs <-chan *ssh.Requ
return err
}

argvParts := []string{}
if len(argv) != 0 {
argvParts = []string{argv}
}

// Spawn and catch new powershell process
pid, _, err := cpty.Spawn(
path,
[]string{},
argvParts,
&syscall.ProcAttr{
Env: os.Environ(),
},
Expand Down

0 comments on commit 6370728

Please sign in to comment.