diff --git a/src/conduit-lwt-unix/conduit_lwt_tls.dummy.ml b/src/conduit-lwt-unix/conduit_lwt_tls.dummy.ml index 40955fae..23c77a07 100644 --- a/src/conduit-lwt-unix/conduit_lwt_tls.dummy.ml +++ b/src/conduit-lwt-unix/conduit_lwt_tls.dummy.ml @@ -9,6 +9,9 @@ end module Client = struct let connect ?src:_ ?certificates:_ ~authenticator:_ _host _sa = failwith "Tls not available" + + let tunnel ?certificates:_ ~authenticator:_ _host _ioc = + failwith "Tls not available" end module Server = struct diff --git a/src/conduit-lwt-unix/conduit_lwt_tls.dummy.mli b/src/conduit-lwt-unix/conduit_lwt_tls.dummy.mli index 27a3e4d0..6afef332 100644 --- a/src/conduit-lwt-unix/conduit_lwt_tls.dummy.mli +++ b/src/conduit-lwt-unix/conduit_lwt_tls.dummy.mli @@ -33,6 +33,13 @@ module Client : sig [ `host ] Domain_name.t -> Lwt_unix.sockaddr -> (Lwt_unix.file_descr * Lwt_io.input_channel * Lwt_io.output_channel) Lwt.t + + val tunnel : + ?certificates:'a -> + authenticator:X509.authenticator -> + [ `host ] Domain_name.t -> + Lwt_io.input_channel * Lwt_io.output_channel -> + (Lwt_io.input_channel * Lwt_io.output_channel) Lwt.t end module Server : sig diff --git a/src/conduit-lwt-unix/conduit_lwt_tls.real.ml b/src/conduit-lwt-unix/conduit_lwt_tls.real.ml index eb0be641..ce0cd467 100644 --- a/src/conduit-lwt-unix/conduit_lwt_tls.real.ml +++ b/src/conduit-lwt-unix/conduit_lwt_tls.real.ml @@ -30,19 +30,27 @@ module X509 = struct end module Client = struct + let config ?certificates authenticator = + match Tls.Config.client ~authenticator ?certificates () with + | Error (`Msg msg) -> failwith ("tls configuration problem: " ^ msg) + | Ok config -> config + let connect ?src ?certificates ~authenticator host sa = Conduit_lwt_server.with_socket sa (fun fd -> (match src with | None -> Lwt.return_unit | Some src_sa -> Lwt_unix.bind fd src_sa) >>= fun () -> - match Tls.Config.client ~authenticator ?certificates () with - | Error (`Msg msg) -> failwith ("tls configuration problem: " ^ msg) - | Ok config -> - Lwt_unix.connect fd sa >>= fun () -> - Tls_lwt.Unix.client_of_fd config ~host fd >|= fun t -> - let ic, oc = Tls_lwt.of_t t in - (fd, ic, oc)) + let config = config ?certificates authenticator in + Lwt_unix.connect fd sa >>= fun () -> + Tls_lwt.Unix.client_of_fd config ~host fd >|= fun t -> + let ic, oc = Tls_lwt.of_t t in + (fd, ic, oc)) + + let tunnel ?certificates ~authenticator host channels = + let config = config ?certificates authenticator in + Tls_lwt.Unix.client_of_channels config ~host channels >|= fun t -> + Tls_lwt.of_t t end module Server = struct diff --git a/src/conduit-lwt-unix/conduit_lwt_tls.real.mli b/src/conduit-lwt-unix/conduit_lwt_tls.real.mli index 1ad51ed2..8909179b 100644 --- a/src/conduit-lwt-unix/conduit_lwt_tls.real.mli +++ b/src/conduit-lwt-unix/conduit_lwt_tls.real.mli @@ -34,6 +34,13 @@ module Client : sig [ `host ] Domain_name.t -> Lwt_unix.sockaddr -> (Lwt_unix.file_descr * Lwt_io.input_channel * Lwt_io.output_channel) Lwt.t + + val tunnel : + ?certificates:Tls.Config.own_cert -> + authenticator:X509.authenticator -> + [ `host ] Domain_name.t -> + Lwt_io.input_channel * Lwt_io.output_channel -> + (Lwt_io.input_channel * Lwt_io.output_channel) Lwt.t end module Server : sig diff --git a/src/conduit-lwt-unix/conduit_lwt_unix.ml b/src/conduit-lwt-unix/conduit_lwt_unix.ml index fa3c4cdb..755a3a07 100644 --- a/src/conduit-lwt-unix/conduit_lwt_unix.ml +++ b/src/conduit-lwt-unix/conduit_lwt_unix.ml @@ -52,8 +52,8 @@ let () = (Sexplib0.Sexp.to_string (sexp_of_tls_lib !tls_library)) type +'a io = 'a Lwt.t -type ic = Lwt_io.input_channel -type oc = Lwt_io.output_channel +type ic = (Lwt_io.input_channel[@sexp.opaque]) [@@deriving sexp] +type oc = (Lwt_io.output_channel[@sexp.opaque]) [@@deriving sexp] type client_tls_config = [ `Hostname of string ] * [ `IP of Ipaddr_sexp.t ] * [ `Port of int ] @@ -61,6 +61,7 @@ type client_tls_config = type client = [ `TLS of client_tls_config + | `TLS_tunnel of [ `Hostname of string ] * ic * oc | `TLS_native of client_tls_config | `OpenSSL of client_tls_config | `TCP of [ `IP of Ipaddr_sexp.t ] * [ `Port of int ] @@ -140,6 +141,7 @@ type vchan_flow = { domid : int; port : string } [@@deriving sexp] type flow = | TCP of tcp_flow + | Tunnel of string * ic * oc | Domain_socket of domain_flow | Vchan of vchan_flow [@@deriving sexp] @@ -262,31 +264,42 @@ let set_max_active maxactive = Conduit_lwt_server.set_max_active maxactive (** TLS client connection functions *) -let connect_with_tls_native ~ctx (`Hostname hostname, `IP ip, `Port port) = - let sa = Unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in - (match ctx.tls_own_key with +let certificates ~ctx = + match ctx.tls_own_key with | `None -> Lwt.return_none | `TLS (_, _, `Password _) -> failwith "OCaml-TLS cannot handle encrypted pem files" | `TLS (`Crt_file_path cert, `Key_file_path priv_key, `No_password) -> Conduit_lwt_tls.X509.private_of_pems ~cert ~priv_key - >|= fun certificate -> Some (`Single certificate)) - >>= fun certificates -> - let hostname = - try Domain_name.(host_exn (of_string_exn hostname)) - with Invalid_argument msg -> - let s = - Printf.sprintf "couldn't convert %s to a [`host] Domain_name.t: %s" - hostname msg - in - invalid_arg s - in + >|= fun certificate -> Some (`Single certificate) + +let domain_name hostname = + try Domain_name.(host_exn (of_string_exn hostname)) + with Invalid_argument msg -> + let trace = Printexc.get_raw_backtrace () in + let msg = + Printf.sprintf "couldn't convert %s to a [`host] Domain_name.t: %s" + hostname msg + in + Printexc.raise_with_backtrace (Invalid_argument msg) trace + +let connect_with_tls_native ~ctx (`Hostname hostname, `IP ip, `Port port) = + let sa = Unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in + certificates ~ctx >>= fun certificates -> + let hostname = domain_name hostname in Conduit_lwt_tls.Client.connect ?src:ctx.src ?certificates ~authenticator:ctx.tls_authenticator hostname sa >|= fun (fd, ic, oc) -> let flow = TCP { fd; ip; port } in (flow, ic, oc) +let connect_with_tls_tunnel ~ctx (`Hostname hostname, ic, oc) = + certificates ~ctx >>= fun certificates -> + let host = domain_name hostname in + Conduit_lwt_tls.Client.tunnel ?certificates + ~authenticator:ctx.tls_authenticator host (ic, oc) + >|= fun (ic', oc') -> (Tunnel (hostname, ic, oc), ic', oc') + let connect_with_openssl ~ctx (`Hostname host_addr, `IP ip, `Port port) = let sa = Unix.ADDR_INET (Ipaddr_unix.to_inet_addr ip, port) in let ctx_ssl = @@ -331,6 +344,7 @@ let connect ~ctx (mode : client) = let flow = Domain_socket { fd; path } in Lwt.return (flow, ic, oc) | `TLS c -> connect_with_default_tls ~ctx c + | `TLS_tunnel c -> connect_with_tls_tunnel ~ctx c | `OpenSSL c -> connect_with_openssl ~ctx c | `TLS_native c -> connect_with_tls_native ~ctx c | `Vchan_direct _ -> failwith "Vchan_direct not available on unix" @@ -414,14 +428,17 @@ let serve ?backlog ?timeout ?stop ~on_exn ~(ctx : ctx) ~(mode : server) callback let fn s = Sockaddr_server.init ~on:(`Socket s) ?timeout ?stop callback in Conduit_lwt_launchd.activate fn name +type endp = [ Conduit.endp | `TLS_tunnel of string * ic * oc ] [@@deriving sexp] + let endp_of_flow = function | TCP { ip; port; _ } -> `TCP (ip, port) + | Tunnel (hostname, ic, oc) -> `TLS_tunnel (hostname, ic, oc) | Domain_socket { path; _ } -> `Unix_domain_socket path | Vchan { domid; port } -> `Vchan_direct (domid, port) (** Use the configuration of the server to interpret how to handle a particular endpoint from the resolver into a concrete implementation of type [client] *) -let endp_to_client ~ctx:_ (endp : Conduit.endp) : client Lwt.t = +let endp_to_client ~ctx:_ (endp : [< endp ]) : client Lwt.t = match endp with | `TCP (ip, port) -> Lwt.return (`TCP (`IP ip, `Port port)) | `Unix_domain_socket file -> Lwt.return (`Unix_domain_socket (`File file)) @@ -435,6 +452,8 @@ let endp_to_client ~ctx:_ (endp : Conduit.endp) : client Lwt.t = Printf.ksprintf failwith "TLS to non-TCP currently unsupported: host=%s endp=%s" host (Sexplib0.Sexp.to_string_hum (Conduit.sexp_of_endp endp)) + | `TLS_tunnel (host, ic, oc) -> + Lwt.return (`TLS_tunnel (`Hostname host, ic, oc)) | `Unknown err -> failwith ("resolution failed: " ^ err) let endp_to_server ~ctx (endp : Conduit.endp) = diff --git a/src/conduit-lwt-unix/conduit_lwt_unix.mli b/src/conduit-lwt-unix/conduit_lwt_unix.mli index 80285250..2bf539f7 100644 --- a/src/conduit-lwt-unix/conduit_lwt_unix.mli +++ b/src/conduit-lwt-unix/conduit_lwt_unix.mli @@ -26,8 +26,13 @@ type client_tls_config = [@@deriving sexp] (** Configuration fragment for a TLS client connecting to a remote endpoint *) +type 'a io = 'a Lwt.t +type ic = (Lwt_io.input_channel[@sexp.opaque]) [@@deriving sexp] +type oc = (Lwt_io.output_channel[@sexp.opaque]) [@@deriving sexp] + type client = [ `TLS of client_tls_config + | `TLS_tunnel of [ `Hostname of string ] * ic * oc | `TLS_native of client_tls_config (** Force use of native OCaml TLS stack to connect.*) | `OpenSSL of client_tls_config @@ -103,10 +108,6 @@ type server = the {{:http://mirage.github.io/ocaml-launchd/launchd/} ocaml-launchd} documentation for more. *) -type 'a io = 'a Lwt.t -type ic = Lwt_io.input_channel -type oc = Lwt_io.output_channel - type tcp_flow = private { fd : Lwt_unix.file_descr; [@sexp.opaque] ip : Ipaddr.t; @@ -129,6 +130,7 @@ type vchan_flow = private { domid : int; port : string } [@@deriving sexp_of] transport method. *) type flow = private | TCP of tcp_flow + | Tunnel of string * ic * oc | Domain_socket of domain_flow | Vchan of vchan_flow [@@deriving sexp_of] @@ -204,11 +206,18 @@ val set_max_active : int -> unit accepted. When the limit is hit accept blocks until another server connection is closed. *) -val endp_of_flow : flow -> Conduit.endp -(** [endp_of_flow flow] retrieves the original {!Conduit.endp} from the - established [flow] *) +type endp = + [ Conduit.endp + | `TLS_tunnel of string * ic * oc + (** Wrap in a TLS channel over an existing [Lwt_io.channel] connection, + [hostname,input_channel,output_channel] *) ] +[@@deriving sexp] + +val endp_of_flow : flow -> endp +(** [endp_of_flow flow] retrieves the original {!endp} from the established + [flow] *) -val endp_to_client : ctx:ctx -> Conduit.endp -> client io +val endp_to_client : ctx:ctx -> [< endp ] -> client io (** [endp_to_client ~ctx endp] converts an [endp] into a a concrete connection mechanism of type [client] *) diff --git a/tests/conduit-lwt-unix/dune b/tests/conduit-lwt-unix/dune index c28486d0..4c2fd3ef 100644 --- a/tests/conduit-lwt-unix/dune +++ b/tests/conduit-lwt-unix/dune @@ -1,6 +1,6 @@ (executables (libraries lwt_ssl ssl conduit-lwt-unix lwt_log) - (names cdtest cdtest_tls exit_test)) + (names cdtest cdtest_tls exit_test tls_over_tls)) (rule (alias runtest) diff --git a/tests/conduit-lwt-unix/tls_over_tls.ml b/tests/conduit-lwt-unix/tls_over_tls.ml new file mode 100644 index 00000000..d34d1264 --- /dev/null +++ b/tests/conduit-lwt-unix/tls_over_tls.ml @@ -0,0 +1,61 @@ +open Lwt.Infix + +let hostname = "mirage.io" + +(* To test TLS-over-TLS, the `squid` proxy can be installed locally and configured to support HTTPS: + + - Generate a certificate for localhost: https://gist.github.com/cecilemuller/9492b848eb8fe46d462abeb26656c4f8 + + $ openssl req -x509 -nodes -new -sha256 -days 1024 -newkey rsa:2048 -keyout RootCA.key -out RootCA.pem -subj "/C=US/CN=Example-Root-CA" + $ openssl x509 -outform pem -in RootCA.pem -out RootCA.crt + $ cat > domains.ext + authorityKeyIdentifier=keyid,issuer + basicConstraints=CA:FALSE + keyUsage = digitalSignature, nonRepudiation, keyEncipherment, dataEncipherment + subjectAltName = @alt_names + [alt_names] + DNS.1 = localhost + $ openssl req -new -nodes -newkey rsa:2048 -keyout localhost.key -out localhost.csr -subj "/C=US/ST=YourState/L=YourCity/O=Example-Certificates/CN=localhost.local" + $ openssl x509 -req -sha256 -days 1024 -in localhost.csr -CA RootCA.pem -CAkey RootCA.key -CAcreateserial -extfile domains.ext -out localhost.crt + + - Configure squid by adding HTTPS support on port 3129 in /etc/squid/squid.conf : + + https_port 3129 tls-cert=/path/to/localhost.crt tls-key=/path/to/localhost.key +*) + +let proxy = + `TLS + (`Hostname "localhost", `IP (Ipaddr.of_string_exn "127.0.0.1"), `Port 3129) + +let string_prefix ~prefix msg = + let len = String.length prefix in + String.length msg >= len && String.sub msg 0 len = prefix + +let main () = + let ctx = Lazy.force Conduit_lwt_unix.default_ctx in + Conduit_lwt_unix.connect ~ctx proxy >>= fun (_flow, ic, oc) -> + let req = + String.concat "\r\n" + [ "CONNECT " ^ hostname ^ ":443 HTTP/1.1"; "Host: " ^ hostname; ""; "" ] + in + Lwt_io.write oc req >>= fun () -> + let rec try_read () = + Lwt_io.read ic ~count:1024 >>= fun msg -> + if msg = "" then try_read () else Lwt.return msg + in + try_read () >>= fun msg -> + assert (string_prefix ~prefix:"HTTP/1.1 200 " msg); + + (* We are now connected to mirage.io:443 through the proxy *) + let client = `TLS_tunnel (`Hostname hostname, ic, oc) in + Conduit_lwt_unix.connect ~ctx client >>= fun (_flow, ic, oc) -> + let req = + String.concat "\r\n" [ "GET / HTTP/1.1"; "Host: " ^ hostname; ""; "" ] + in + Lwt_io.write oc req >>= fun () -> + Lwt_io.read ic ~count:4096 >>= fun msg -> + Lwt_io.print msg >>= fun () -> + Lwt_io.read ic ~count:4096 >>= fun msg -> + Lwt_io.print msg >>= fun () -> Lwt_io.print "\n" + +let () = Lwt_main.run (main ())