Skip to content

Commit

Permalink
Merge pull request #10 from JuliaWeb/sockets
Browse files Browse the repository at this point in the history
Add support for forwarding directly to a TCPSocket
  • Loading branch information
JamesWrigley authored Mar 13, 2024
2 parents aa32a26 + ff03161 commit 0655dbc
Show file tree
Hide file tree
Showing 6 changed files with 160 additions and 81 deletions.
7 changes: 7 additions & 0 deletions docs/src/changelog.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,13 @@ CurrentModule = LibSSH
This documents notable changes in LibSSH.jl. The format is based on [Keep a
Changelog](https://keepachangelog.com).

## Unreleased

### Added

- A new [`Forwarder(::Session, ::String, ::Int)`](@ref) constructor to allow for
forwarding a port to an internal socket instead of to a port ([#10]).

## [v0.4.0] - 2024-03-12

### Added
Expand Down
1 change: 1 addition & 0 deletions docs/src/sessions_and_channels.md
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ Base.success(::Cmd, ::Session)
```@docs
Forwarder
Forwarder(::Session, ::Int, ::String, ::Int)
Forwarder(::Session, ::String, ::Int)
Forwarder(::Function)
Base.close(::Forwarder)
```
1 change: 1 addition & 0 deletions src/LibSSH.jl
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,7 @@ function _safe_poll_fd(args...; kwargs...)
return result
end

include("utils.jl")
include("gssapi.jl")
include("pki.jl")
include("callbacks.jl")
Expand Down
200 changes: 121 additions & 79 deletions src/channel.jl
Original file line number Diff line number Diff line change
Expand Up @@ -703,6 +703,46 @@ mutable struct _ForwardingClient
sshchan::SshChannel
callbacks::Callbacks.ChannelCallbacks
client_task::Union{Task, Nothing}

function _ForwardingClient(forwarder, socket::TCPSocket)
remotehost = forwarder.remotehost
remoteport = forwarder.remoteport

# Open a forwarding channel
local_ip = string(getaddrinfo(gethostname()))
sshchan = SshChannel(forwarder._session)
ret = _session_trywait(forwarder._session) do
GC.@preserve remotehost local_ip begin
lib.ssh_channel_open_forward(sshchan.ptr,
Base.unsafe_convert(Ptr{Cchar}, remotehost), remoteport,
Base.unsafe_convert(Ptr{Cchar}, local_ip), forwarder.localport)
end
end
if ret != SSH_OK
throw(LibSSHException("Could not open a forwarding channel: $(get_error(forwarder._session))"))
end

# Set callbacks for the channel
callbacks = Callbacks.ChannelCallbacks(nothing;
on_data=_on_client_channel_data,
on_eof=_on_client_channel_eof,
on_close=_on_client_channel_close)
set_channel_callbacks(sshchan, callbacks)

# Create a client and set the callbacks userdata to the new client object
self = new(forwarder._next_client_id, forwarder.verbose, socket,
sshchan, callbacks, nothing)
callbacks.userdata = self

# Start a listener on the new socket to forward data to the server
self.client_task = Threads.@spawn try
_handle_forwarding_client(self)
catch ex
@error "Error when handling SSH port forward client $(self.id)!" exception=(ex, catch_backtrace())
end

return self
end
end

# Helper function to log messages from a forwarding client
Expand Down Expand Up @@ -742,69 +782,98 @@ end
$(TYPEDEF)
$(TYPEDFIELDS)
This object manages a direct forwarding channel between `localport` and `remotehost:remoteport`.
This object manages a direct forwarding channel between `localport` and
`remotehost:remoteport`. Fields beginning with an underscore `_` are private and
should not be used.
"""
mutable struct Forwarder
@kwdef mutable struct Forwarder
remotehost::String
remoteport::Int
localinterface::Sockets.IPAddr
localport::Int
localinterface::Sockets.IPAddr = Sockets.localhost
localport::Int = -1

out::Union{TCPSocket, Nothing} = nothing

_listen_server::TCPServer
_listener_task::Union{Task, Nothing}
_clients::Vector{_ForwardingClient}
_listen_server::TCPServer = TCPServer()
_listener_task::Union{Task, Nothing} = nothing
_clients::Vector{_ForwardingClient} = _ForwardingClient[]
_next_client_id::Int = 1

_session::Session
verbose::Bool
end

@doc """
$(TYPEDSIGNATURES)
"""
$(TYPEDSIGNATURES)
Create a `Forwarder` object to forward data from `localport` to
`remotehost:remoteport`. This will handle an internal [`SshChannel`](@ref)
for forwarding.
# Arguments
- `session`: The session to create a forwarding channel over.
- `localport`: The local port to bind to.
- `remotehost`: The remote host.
- `remoteport`: The remote port to bind to.
- `verbose`: Print logging messages on callbacks etc (not equivalent to
setting `log_verbosity` on a [`Session`](@ref)).
- `localinterface=IPv4(0)`: The interface to bind `localport` on.
"""
function Forwarder(session::Session, localport::Int, remotehost::String, remoteport::Int;
verbose=false, localinterface::Sockets.IPAddr=IPv4(0))
listen_server = Sockets.listen(localinterface, localport)
Create a `Forwarder` object that will forward its data to a single
`TCPSocket`. This is useful if there is only one client and binding to a port
available to other processes is not desirable. The socket will be stored in the
`Forwarder.out` property, and it will be closed when the `Forwarder` is closed.
self = new(remotehost, remoteport, localinterface, localport,
listen_server, nothing, _ForwardingClient[],
session, verbose)
All arguments mean the same as in [`Forwarder(::Session, ::Int, ::String,
::Int)`](@ref).
"""
function Forwarder(session::Session, remotehost::String, remoteport::Int;
verbose=false)
sock1, sock2 = _socketpair()
self = Forwarder(; remotehost, remoteport, out=sock2, _session=session, verbose)
push!(self._clients, _ForwardingClient(self, sock1))

# Start the listener
self._listener_task = Threads.@spawn try
_fwd_listen(self)
catch ex
@error "Error in listen loop for Forwarder!" exception=(ex, catch_backtrace())
end
return self
end

finalizer(close, self)
"""
$(TYPEDSIGNATURES)
Create a `Forwarder` object to forward data from `localport` to
`remotehost:remoteport`. This will handle an internal [`SshChannel`](@ref)
for forwarding.
# Arguments
- `session`: The session to create a forwarding channel over.
- `localport`: The local port to bind to.
- `remotehost`: The remote host.
- `remoteport`: The remote port to bind to.
- `verbose`: Print logging messages on callbacks etc (not equivalent to
setting `log_verbosity` on a [`Session`](@ref)).
- `localinterface=IPv4(0)`: The interface to bind `localport` on.
"""
function Forwarder(session::Session, localport::Int, remotehost::String, remoteport::Int;
verbose=false, localinterface::Sockets.IPAddr=IPv4(0))
_listen_server = Sockets.listen(localinterface, localport)

self = Forwarder(; remotehost, remoteport, localinterface, localport,
_listen_server, _session=session, verbose)

# Start the listener
self._listener_task = Threads.@spawn try
_fwd_listen(self)
catch ex
@error "Error in listen loop for Forwarder!" exception=(ex, catch_backtrace())
end

finalizer(close, self)
end


function Base.show(io::IO, f::Forwarder)
if !isopen(f)
print(io, Forwarder, "()")
else
print(io, Forwarder, "($(f.localinterface):$(f.localport)$(f.remotehost):$(f.remoteport))")
if isnothing(forwarder.out)
print(io, Forwarder, "($(f.localinterface):$(f.localport)$(f.remotehost):$(f.remoteport))")
else
print(io, Forwarder, "($(f.out)$(f.remotehost):$(f.remoteport))")
end
end
end

"""
$(TYPEDSIGNATURES)
Do-constructor for a `Forwarder`. All arguments are forwarded to the
[`Forwarder(::Session, ::Int, ::String, ::Int)`](@ref) constructor.
Do-constructor for a `Forwarder`. All arguments are forwarded to the other
constructors.
"""
function Forwarder(f::Function, args...; kwargs...)
forwarder = Forwarder(args...; kwargs...)
Expand All @@ -825,23 +894,29 @@ socket.
function Base.close(forwarder::Forwarder)
# Stop accepting new clients
close(forwarder._listen_server)
wait(forwarder._listener_task)
if !isnothing(forwarder._listener_task)
wait(forwarder._listener_task)
end

# Close existing clients
for client in forwarder._clients
close(client)
end
end

Base.isopen(forwarder::Forwarder) = isopen(forwarder._listen_server)
function Base.isopen(forwarder::Forwarder)
# If we're forwarding to a bound port then check if the TCPServer is
# running, otherwise check if the single client socket is still open.
if isnothing(forwarder.out)
isopen(forwarder._listen_server)
else
isopen(forwarder.out)
end
end

# This function accepts connections on the local port and sets up
# _ForwardingClient's for them.
function _fwd_listen(forwarder::Forwarder)
next_client_id = 1
remotehost = forwarder.remotehost
remoteport = forwarder.remoteport

while isopen(forwarder._listen_server)
local sock
try
Expand All @@ -854,40 +929,7 @@ function _fwd_listen(forwarder::Forwarder)
end
end

# Open a forwarding channel
local_ip = string(getaddrinfo(gethostname()))
sshchan = SshChannel(forwarder._session)
ret = _session_trywait(forwarder._session) do
GC.@preserve remotehost local_ip begin
lib.ssh_channel_open_forward(sshchan.ptr,
Base.unsafe_convert(Ptr{Cchar}, remotehost), remoteport,
Base.unsafe_convert(Ptr{Cchar}, local_ip), forwarder.localport)
end
end
if ret != SSH_OK
throw(LibSSHException("Could not open a forwarding channel: $(get_error(forwarder._session))"))
end

# Set callbacks for the channel
callbacks = Callbacks.ChannelCallbacks(nothing;
on_data=_on_client_channel_data,
on_eof=_on_client_channel_eof,
on_close=_on_client_channel_close)
set_channel_callbacks(sshchan, callbacks)

# Create a client and set the callbacks userdata to the new client object
client = _ForwardingClient(next_client_id, forwarder.verbose, sock,
sshchan, callbacks, nothing)
callbacks.userdata = client

# Start a listener on the new socket to forward data to the server
client.client_task = Threads.@spawn try
_handle_forwarding_client(client)
catch ex
@error "Error when handling SSH port forward client $(client.id)!" exception=(ex, catch_backtrace())
end

push!(forwarder._clients, client)
next_client_id += 1
push!(forwarder._clients, _ForwardingClient(forwarder, sock))
forwarder._next_client_id += 1
end
end
14 changes: 14 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import Sockets

# High-level, portable implementation of socketpair(2)
function _socketpair()
port, server = Sockets.listenany(Sockets.localhost, 2048)
acceptor = Threads.@spawn Sockets.accept(server)

sock1 = Sockets.connect(Sockets.localhost, port)
sock2 = fetch(acceptor)

close(server)

return sock1, sock2
end
18 changes: 16 additions & 2 deletions test/LibSSHTests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import LibSSH.Demo: DemoServer

username() = Sys.iswindows() ? ENV["USERNAME"] : ENV["USER"]

const HTTP_200 = "HTTP/1.1 200 OK\r\n\r\n"

# Dummy HTTP server that only responds 200 to requests
function http_server(f::Function, port)
start_event = Base.Event()
Expand All @@ -38,7 +40,7 @@ function http_server(f::Function, port)
# Wait for any request, doesn't matter what
data = readavailable(sock)
if !isempty(data)
write(sock, "HTTP/1.1 200 OK\r\n\r\n")
write(sock, HTTP_200)
end

closewrite(sock)
Expand Down Expand Up @@ -428,12 +430,13 @@ end
end

@testset "Direct port forwarding" begin
# Test port forwarding
# Smoke test
demo_server_with_session(2222) do session
forwarder = ssh.Forwarder(session, 8080, "localhost", 9090)
close(forwarder)
end

# Test forwarding to a port
demo_server_with_session(2222) do session
ssh.Forwarder(session, 8080, "localhost", 9090) do forwarder
http_server(9090) do
Expand All @@ -448,6 +451,17 @@ end
end
end
end

# Test forwarding to a socket
demo_server_with_session(2222) do session
ssh.Forwarder(session, "localhost", 9090) do forwarder
http_server(9090) do
socket = forwarder.out
write(socket, "foo")
@test read(socket, String) == HTTP_200
end
end
end
end
end

Expand Down

0 comments on commit 0655dbc

Please sign in to comment.