From c8567b6d7d2ef475810fe17cd0a080fa38cd0d2c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= Date: Mon, 22 Apr 2019 11:26:59 +0200 Subject: Make gun_tls_proxy a gen_statem There is now a not_connected state that is used to postpone events that can't be processed when the proxy socket is not ready. --- src/gun_tls_proxy.erl | 190 +++++++++++++++++++++++++------------------------ test/rfc7231_SUITE.erl | 54 +++----------- 2 files changed, 107 insertions(+), 137 deletions(-) diff --git a/src/gun_tls_proxy.erl b/src/gun_tls_proxy.erl index 23737ab..8adb5b6 100644 --- a/src/gun_tls_proxy.erl +++ b/src/gun_tls_proxy.erl @@ -39,8 +39,7 @@ %% to the fake ssl socket and then to the outgoing socket). -module(gun_tls_proxy). - --behaviour(gen_server). +-behaviour(gen_statem). %% Gun-specific interface. -export([start_link/6]). @@ -60,12 +59,12 @@ -export([sockname/1]). -export([close/1]). -%% gen_server. +%% Internals. +-export([callback_mode/0]). -export([init/1]). -export([connect_proc/5]). --export([handle_call/3]). --export([handle_cast/2]). --export([handle_info/2]). +-export([not_connected/3]). +-export([connected/3]). -record(state, { %% The pid of the owner process. This is where we send active messages. @@ -103,7 +102,7 @@ start_link(Host, Port, Opts, Timeout, OutSocket, OutTransport) -> ?DEBUG_LOG("host ~0p port ~0p opts ~0p timeout ~0p out_socket ~0p out_transport ~0p", [Host, Port, Opts, Timeout, OutSocket, OutTransport]), - case gen_server:start_link(?MODULE, + case gen_statem:start_link(?MODULE, {self(), Host, Port, Opts, Timeout, OutSocket, OutTransport}, []) of {ok, Pid} when is_port(OutSocket) -> @@ -120,12 +119,12 @@ start_link(Host, Port, Opts, Timeout, OutSocket, OutTransport) -> cb_controlling_process(Pid, ControllingPid) -> ?DEBUG_LOG("pid ~0p controlling_pid ~0p", [Pid, ControllingPid]), - gen_server:cast(Pid, {?FUNCTION_NAME, ControllingPid}). + gen_statem:cast(Pid, {?FUNCTION_NAME, ControllingPid}). cb_send(Pid, Data) -> ?DEBUG_LOG("pid ~0p data ~0p", [Pid, Data]), try - gen_server:call(Pid, {?FUNCTION_NAME, Data}) + gen_statem:call(Pid, {?FUNCTION_NAME, Data}) catch exit:{noproc, _} -> {error, closed} @@ -134,7 +133,7 @@ cb_send(Pid, Data) -> cb_setopts(Pid, Opts) -> ?DEBUG_LOG("pid ~0p opts ~0p", [Pid, Opts]), try - gen_server:call(Pid, {?FUNCTION_NAME, Opts}) + gen_statem:call(Pid, {?FUNCTION_NAME, Opts}) catch exit:{noproc, _} -> {error, einval} @@ -157,31 +156,32 @@ connect(_, _, _, _) -> -spec send(pid(), iodata()) -> ok | {error, atom()}. send(Pid, Data) -> ?DEBUG_LOG("pid ~0p data ~0p", [Pid, Data]), - gen_server:call(Pid, {?FUNCTION_NAME, Data}). + gen_statem:call(Pid, {?FUNCTION_NAME, Data}). -spec setopts(pid(), list()) -> ok. setopts(Pid, Opts) -> ?DEBUG_LOG("pid ~0p opts ~0p", [Pid, Opts]), - gen_server:cast(Pid, {?FUNCTION_NAME, Opts}). + gen_statem:cast(Pid, {?FUNCTION_NAME, Opts}). -spec sockname(pid()) -> {ok, {inet:ip_address(), inet:port_number()}} | {error, atom()}. sockname(Pid) -> ?DEBUG_LOG("pid ~0p", [Pid]), - gen_server:call(Pid, ?FUNCTION_NAME). + gen_statem:call(Pid, ?FUNCTION_NAME). -spec close(pid()) -> ok. close(Pid) -> ?DEBUG_LOG("pid ~0p", [Pid]), - gen_server:call(Pid, ?FUNCTION_NAME). + gen_statem:call(Pid, ?FUNCTION_NAME). + +%% gen_statem. -%% gen_server. -%% @todo Probably need to gen_statem it to avoid trying to send stuff before being connected. +callback_mode() -> state_functions. init({OwnerPid, Host, Port, Opts, Timeout, OutSocket, OutTransport}) -> if is_pid(OutSocket) -> - gen_server:cast(OutSocket, {set_owner, self()}); + gen_statem:cast(OutSocket, {set_owner, self()}); true -> ok end, @@ -194,7 +194,7 @@ init({OwnerPid, Host, Port, Opts, Timeout, OutSocket, OutTransport}) -> ?DEBUG_LOG("owner_pid ~0p host ~0p port ~0p opts ~0p timeout ~0p" " out_socket ~0p out_transport ~0p proxy_pid ~0p", [OwnerPid, Host, Port, Opts, Timeout, OutSocket, OutTransport, ProxyPid]), - {ok, #state{owner_pid=OwnerPid, host=Host, port=Port, proxy_pid=ProxyPid, + {ok, not_connected, #state{owner_pid=OwnerPid, host=Host, port=Port, proxy_pid=ProxyPid, out_socket=OutSocket, out_transport=OutTransport, out_messages=Messages}}. connect_proc(ProxyPid, Host, Port, Opts, Timeout) -> @@ -208,110 +208,121 @@ connect_proc(ProxyPid, Host, Port, Opts, Timeout) -> {ok, Socket} -> ?DEBUG_LOG("socket ~0p", [Socket]), ssl:controlling_process(Socket, ProxyPid), - gen_server:cast(ProxyPid, {?FUNCTION_NAME, {ok, Socket}}); + gen_statem:cast(ProxyPid, {?FUNCTION_NAME, {ok, Socket}}); Error -> ?DEBUG_LOG("error ~0p", [Error]), - gen_server:cast(ProxyPid, {?FUNCTION_NAME, Error}) + gen_statem:cast(ProxyPid, {?FUNCTION_NAME, Error}) end, ok. -handle_call(Msg={cb_send, Data}, From, State=#state{ +%% Postpone events that require the proxy socket to be up. +not_connected({call, _}, Msg={send, _}, State) -> + ?DEBUG_LOG("postpone ~0p state ~0p", [Msg, State]), + {keep_state_and_data, postpone}; +not_connected(cast, Msg={setopts, _}, State) -> + ?DEBUG_LOG("postpone ~0p state ~0p", [Msg, State]), + {keep_state_and_data, postpone}; +not_connected(cast, Msg={connect_proc, {ok, Socket}}, State) -> + ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), + ok = ssl:setopts(Socket, [{active, true}]), + {next_state, connected, State#state{proxy_socket=Socket}}; +not_connected(cast, Msg={connect_proc, Error}, State) -> + ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), + {stop, Error, State}; +not_connected(Type, Event, State) -> + handle_common(Type, Event, State). + +%% Send data through the proxy socket. +connected({call, From}, Msg={send, Data}, State=#state{proxy_socket=Socket}) -> + ?DEBUG_LOG("msg ~0p from ~0p state ~0p", [Msg, From, State]), + Self = self(), + SpawnedPid = spawn(fun() -> + gen_statem:cast(Self, {send_result, From, ssl:send(Socket, Data)}) + end), + ?DEBUG_LOG("spawned ~0p", [SpawnedPid]), + keep_state_and_data; +%% Messages from the proxy socket. +connected(info, Msg={ssl, Socket, Data}, State=#state{owner_pid=OwnerPid, proxy_socket=Socket}) -> + ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), + OwnerPid ! {tls_proxy, self(), Data}, + keep_state_and_data; +connected(info, Msg={ssl_closed, Socket}, State=#state{owner_pid=OwnerPid, proxy_socket=Socket}) -> + ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), + OwnerPid ! {tls_proxy_closed, self()}, + keep_state_and_data; +connected(info, Msg={ssl_error, Socket, Reason}, State=#state{owner_pid=OwnerPid, proxy_socket=Socket}) -> + ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), + OwnerPid ! {tls_proxy_error, self(), Reason}, + keep_state_and_data; +connected(Type, Event, State) -> + handle_common(Type, Event, State). + +handle_common({call, From}, Msg={cb_send, Data}, State=#state{ out_socket=OutSocket, out_transport=OutTransport}) -> ?DEBUG_LOG("msg ~0p from ~0p state ~0p", [Msg, From, State]), Self = self(), SpawnedPid = spawn(fun() -> - gen_server:cast(Self, {send_result, From, OutTransport:send(OutSocket, Data)}) + gen_statem:cast(Self, {send_result, From, OutTransport:send(OutSocket, Data)}) end), ?DEBUG_LOG("spawned ~0p", [SpawnedPid]), - {noreply, State}; -handle_call(Msg={cb_setopts, Opts}, From, State=#state{ + keep_state_and_data; +handle_common({call, From}, Msg={cb_setopts, Opts}, State=#state{ out_socket=OutSocket, out_transport=OutTransport0}) -> ?DEBUG_LOG("msg ~0p from ~0p state ~0p", [Msg, From, State]), OutTransport = case OutTransport0 of gen_tcp -> inet; _ -> OutTransport0 end, - {reply, OutTransport:setopts(OutSocket, [{active, true}]), proxy_setopts(Opts, State)}; -%% @todo If Socket is undefined here we need to buffer input -%% and send it when we receive the {connect_proc, {ok, Socket}} message. -handle_call(Msg={send, Data}, From, State=#state{proxy_socket=Socket}) -> - ?DEBUG_LOG("msg ~0p from ~0p state ~0p", [Msg, From, State]), - Self = self(), - SpawnedPid = spawn(fun() -> - gen_server:cast(Self, {send_result, From, ssl:send(Socket, Data)}) - end), - ?DEBUG_LOG("spawned ~0p", [SpawnedPid]), - {noreply, State}; -handle_call(Msg=sockname, From, State=#state{ + {keep_state, proxy_setopts(Opts, State), + {reply, From, OutTransport:setopts(OutSocket, [{active, true}])}}; +handle_common({call, From}, Msg=sockname, State=#state{ out_socket=OutSocket, out_transport=OutTransport}) -> ?DEBUG_LOG("msg ~0p from ~0p state ~0p", [Msg, From, State]), - {reply, OutTransport:sockname(OutSocket), State}; -handle_call(Msg=close, From, State) -> + {keep_state, State, + {reply, From, OutTransport:sockname(OutSocket)}}; +handle_common({call, From}, Msg=close, State) -> ?DEBUG_LOG("msg ~0p from ~0p state ~0p", [Msg, From, State]), - {stop, {shutdown, close}, State}; -handle_call(Msg, From, State) -> - ?DEBUG_LOG("IGNORED msg ~0p from ~0p state ~0p", [Msg, From, State]), - {reply, {error, bad_call}, State}. - -handle_cast(Msg={set_owner, OwnerPid}, State) -> - ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - {noreply, State#state{owner_pid=OwnerPid}}; -handle_cast(Msg={connect_proc, {ok, Socket}}, State) -> + {stop_and_reply, {shutdown, close}, {reply, From, ok}}; +handle_common(cast, Msg={set_owner, OwnerPid}, State) -> ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - ok = ssl:setopts(Socket, [{active, true}]), - {noreply, State#state{proxy_socket=Socket}}; -handle_cast(Msg={connect_proc, Error}, State) -> - ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - {stop, Error, State}; -handle_cast(Msg={cb_controlling_process, ProxyPid}, State) -> + {keep_state, State#state{owner_pid=OwnerPid}}; +handle_common(cast, Msg={cb_controlling_process, ProxyPid}, State) -> ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), %% We link so that the ssl process terminates when we do. link(ProxyPid), - {noreply, State#state{proxy_pid=ProxyPid}}; -handle_cast(Msg={setopts, Opts}, State) -> + {keep_state, State#state{proxy_pid=ProxyPid}}; +handle_common(cast, Msg={setopts, Opts}, State) -> ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - {noreply, owner_setopts(Opts, State)}; -handle_cast(Msg={send_result, From, Result}, State) -> + {keep_state, owner_setopts(Opts, State)}; +handle_common(cast, Msg={send_result, From, Result}, State) -> ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - gen_server:reply(From, Result), - {noreply, State}; -handle_cast(Msg, State) -> - ?DEBUG_LOG("IGNORED msg ~0p state ~0p", [Msg, State]), - {noreply, State}. - + gen_statem:reply(From, Result), + keep_state_and_data; %% Messages from the real socket. -handle_info(Msg={OK, Socket, Data}, State=#state{proxy_pid=ProxyPid, +handle_common(info, Msg={OK, Socket, Data}, State=#state{proxy_pid=ProxyPid, out_socket=Socket, out_messages={OK, _, _}}) -> ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), ProxyPid ! {tls_proxy, self(), Data}, - {noreply, State}; -handle_info(Msg={Closed, Socket}, State=#state{proxy_pid=ProxyPid, + keep_state_and_data; +handle_common(info, Msg={Closed, Socket}, State=#state{proxy_pid=ProxyPid, out_socket=Socket, out_messages={_, Closed, _}}) -> ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), ProxyPid ! {tls_proxy_closed, self()}, - {noreply, State}; -handle_info(Msg={Error, Socket, Reason}, State=#state{proxy_pid=ProxyPid, + keep_state_and_data; +handle_common(info, Msg={Error, Socket, Reason}, State=#state{proxy_pid=ProxyPid, out_socket=Socket, out_messages={_, _, Error}}) -> ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), ProxyPid ! {tls_proxy_error, self(), Reason}, - {noreply, State}; -%% Messages from the proxy socket. -handle_info(Msg={ssl, Socket, Data}, State=#state{owner_pid=OwnerPid, proxy_socket=Socket}) -> - ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - OwnerPid ! {tls_proxy, self(), Data}, - {noreply, State}; -handle_info(Msg={ssl_closed, Socket}, State=#state{owner_pid=OwnerPid, proxy_socket=Socket}) -> - ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - OwnerPid ! {tls_proxy_closed, self()}, - {noreply, State}; -handle_info(Msg={ssl_error, Socket, Reason}, State=#state{owner_pid=OwnerPid, proxy_socket=Socket}) -> - ?DEBUG_LOG("msg ~0p state ~0p", [Msg, State]), - OwnerPid ! {tls_proxy_error, self(), Reason}, - {noreply, State}; + keep_state_and_data; %% Other messages. -handle_info(Msg, State) -> - ?DEBUG_LOG("IGNORED msg ~0p state ~0p", [Msg, State]), - {noreply, State}. +handle_common(Type, Msg, State) -> + ?DEBUG_LOG("IGNORED type ~0p msg ~0p state ~0p", [Type, Msg, State]), + case Type of + {call, From} -> + {keep_state, State, {reply, From, {error, bad_call}}}; + _ -> + keep_state_and_data + end. %% Internal. @@ -365,9 +376,7 @@ tcp_test() -> ssl:start(), {ok, Socket} = gen_tcp:connect("google.com", 443, [binary, {active, false}]), {ok, ProxyPid1} = start_link("google.com", 443, [], 5000, Socket, gen_tcp), - timer:sleep(500), send(ProxyPid1, <<"GET / HTTP/1.1\r\nHost: google.com\r\n\r\n">>), - timer:sleep(1000), receive {tls_proxy, ProxyPid1, <<"HTTP/1.1 ", _/bits>>} -> ok after 1000 -> error(timeout) end. ssl_test() -> @@ -375,11 +384,8 @@ ssl_test() -> _ = (catch ct_helper:make_certs_in_ets()), {ok, _, Port} = do_proxy_start("google.com", 443), {ok, Socket} = ssl:connect("localhost", Port, [binary, {active, false}]), - timer:sleep(500), {ok, ProxyPid1} = start_link("google.com", 443, [], 5000, Socket, ssl), - timer:sleep(500), send(ProxyPid1, <<"GET / HTTP/1.1\r\nHost: google.com\r\n\r\n">>), - timer:sleep(1000), receive {tls_proxy, ProxyPid1, <<"HTTP/1.1 ", _/bits>>} -> ok after 1000 -> error(timeout) end. ssl2_test() -> @@ -388,13 +394,9 @@ ssl2_test() -> {ok, _, Port1} = do_proxy_start("google.com", 443), {ok, _, Port2} = do_proxy_start("localhost", Port1), {ok, Socket} = ssl:connect("localhost", Port2, [binary, {active, false}]), - timer:sleep(500), {ok, ProxyPid1} = start_link("localhost", Port1, [], 5000, Socket, ssl), - timer:sleep(500), {ok, ProxyPid2} = start_link("google.com", 443, [], 5000, ProxyPid1, ?MODULE), - timer:sleep(500), send(ProxyPid2, <<"GET / HTTP/1.1\r\nHost: google.com\r\n\r\n">>), - timer:sleep(1000), receive {tls_proxy, ProxyPid2, <<"HTTP/1.1 ", _/bits>>} -> ok after 1000 -> error(timeout) end. do_proxy_start(Host, Port) -> diff --git a/test/rfc7231_SUITE.erl b/test/rfc7231_SUITE.erl index 66ada7f..cbace9b 100644 --- a/test/rfc7231_SUITE.erl +++ b/test/rfc7231_SUITE.erl @@ -125,74 +125,42 @@ do_proxy_loop(Transport, ClientSocket, OriginSocket) -> connect_http(_) -> doc("CONNECT can be used to establish a TCP connection " "to an HTTP/1.1 server via an HTTP proxy. (RFC7231 4.3.6)"), - do_connect_http(tcp). + do_connect_http(tcp, tcp). connect_https(_) -> doc("CONNECT can be used to establish a TLS connection " "to an HTTP/1.1 server via an HTTP proxy. (RFC7231 4.3.6)"), - do_connect_http(tls). - -do_connect_http(Transport) -> - {ok, OriginPid, OriginPort} = init_origin(Transport, http), - {ok, ProxyPid, ProxyPort} = do_proxy_start(tcp), - Authority = iolist_to_binary(["localhost:", integer_to_binary(OriginPort)]), - {ok, ConnPid} = gun:open("localhost", ProxyPort), - {ok, http} = gun:await_up(ConnPid), - StreamRef = gun:connect(ConnPid, #{ - host => "localhost", - port => OriginPort, - transport => Transport - }), - {request, <<"CONNECT">>, Authority, 'HTTP/1.1', _} = receive_from(ProxyPid), - {response, fin, 200, _} = gun:await(ConnPid, StreamRef), - _ = gun:get(ConnPid, "/proxied"), - Data = receive_from(OriginPid), - Lines = binary:split(Data, <<"\r\n">>, [global]), - [<<"host: ", Authority/bits>>] = [L || <<"host: ", _/bits>> = L <- Lines], - #{ - transport := Transport, - protocol := http, - origin_host := "localhost", - origin_port := OriginPort, - intermediaries := [#{ - type := connect, - host := "localhost", - port := ProxyPort, - transport := tcp, - protocol := http - }]} = gun:info(ConnPid), - gun:close(ConnPid). + do_connect_http(tls, tcp). connect_http_over_https_proxy(_) -> doc("CONNECT can be used to establish a TCP connection " "to an HTTP/1.1 server via an HTTPS proxy. (RFC7231 4.3.6)"), - do_connect_http_over_https_proxy(tcp). + do_connect_http(tcp, tls). connect_https_over_https_proxy(_) -> doc("CONNECT can be used to establish a TLS connection " "to an HTTP/1.1 server via an HTTPS proxy. (RFC7231 4.3.6)"), - do_connect_http_over_https_proxy(tls). + do_connect_http(tls, tls). -do_connect_http_over_https_proxy(Transport) -> - {ok, OriginPid, OriginPort} = init_origin(Transport, http), - {ok, ProxyPid, ProxyPort} = do_proxy_start(tls), +do_connect_http(OriginTransport, ProxyTransport) -> + {ok, OriginPid, OriginPort} = init_origin(OriginTransport, http), + {ok, ProxyPid, ProxyPort} = do_proxy_start(ProxyTransport), Authority = iolist_to_binary(["localhost:", integer_to_binary(OriginPort)]), - {ok, ConnPid} = gun:open("localhost", ProxyPort, #{transport => tls}), + {ok, ConnPid} = gun:open("localhost", ProxyPort, #{transport => ProxyTransport}), {ok, http} = gun:await_up(ConnPid), StreamRef = gun:connect(ConnPid, #{ host => "localhost", port => OriginPort, - transport => Transport + transport => OriginTransport }), {request, <<"CONNECT">>, Authority, 'HTTP/1.1', _} = receive_from(ProxyPid), {response, fin, 200, _} = gun:await(ConnPid, StreamRef), -% timer:sleep(2000), _ = gun:get(ConnPid, "/proxied"), Data = receive_from(OriginPid), Lines = binary:split(Data, <<"\r\n">>, [global]), [<<"host: ", Authority/bits>>] = [L || <<"host: ", _/bits>> = L <- Lines], #{ - transport := Transport, + transport := OriginTransport, protocol := http, origin_host := "localhost", origin_port := OriginPort, @@ -200,7 +168,7 @@ do_connect_http_over_https_proxy(Transport) -> type := connect, host := "localhost", port := ProxyPort, - transport := tls, + transport := ProxyTransport, protocol := http }]} = gun:info(ConnPid), gun:close(ConnPid). -- cgit v1.2.3