diff options
Diffstat (limited to 'lib/ssl/src/tls_connection.erl')
-rw-r--r-- | lib/ssl/src/tls_connection.erl | 218 |
1 files changed, 151 insertions, 67 deletions
diff --git a/lib/ssl/src/tls_connection.erl b/lib/ssl/src/tls_connection.erl index a3002830d1..6c7511d2b3 100644 --- a/lib/ssl/src/tls_connection.erl +++ b/lib/ssl/src/tls_connection.erl @@ -43,30 +43,35 @@ %% Internal application API %% Setup --export([start_fsm/8, start_link/7, init/1]). +-export([start_fsm/8, start_link/8, init/1, pids/1]). %% State transition handling --export([next_record/1, next_event/3, next_event/4, handle_common_event/4]). +-export([next_record/1, next_event/3, next_event/4, + handle_common_event/4]). %% Handshake handling --export([renegotiate/2, send_handshake/2, +-export([renegotiation/2, renegotiate/2, send_handshake/2, queue_handshake/2, queue_change_cipher/2, - reinit_handshake_data/1, select_sni_extension/1, empty_connection_state/2]). + reinit/1, reinit_handshake_data/1, select_sni_extension/1, + empty_connection_state/2]). %% Alert and close handling --export([encode_alert/3, send_alert/2, close/5, protocol_name/0]). +-export([send_alert/2, send_alert_in_connection/2, encode_alert/3, close/5, protocol_name/0]). %% Data handling --export([encode_data/3, passive_receive/2, next_record_if_active/1, send/3, - socket/5, setopts/3, getopts/3]). +-export([encode_data/3, passive_receive/2, next_record_if_active/1, + send/3, socket/5, setopts/3, getopts/3]). %% gen_statem state functions -export([init/3, error/3, downgrade/3, %% Initiation and take down states hello/3, user_hello/3, certify/3, cipher/3, abbreviated/3, %% Handshake states - connection/3, death_row/3]). + connection/3]). %% gen_statem callbacks -export([callback_mode/0, terminate/3, code_change/4, format_status/2]). + +-define(DIST_CNTRL_SPAWN_OPTS, [{priority, max}]). + %%==================================================================== %% Internal application API %%==================================================================== @@ -77,9 +82,10 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_, Tracker} User, {CbModule, _,_, _} = CbInfo, Timeout) -> try - {ok, Pid} = tls_connection_sup:start_child([Role, Host, Port, Socket, + {ok, Sender} = tls_sender:start(), + {ok, Pid} = tls_connection_sup:start_child([Role, Sender, Host, Port, Socket, Opts, User, CbInfo]), - {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker), + {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, [Pid, Sender], CbModule, Tracker), ssl_connection:handshake(SslSocket, Timeout) catch error:{badmatch, {error, _} = Error} -> @@ -90,9 +96,10 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} = User, {CbModule, _,_, _} = CbInfo, Timeout) -> try - {ok, Pid} = tls_connection_sup:start_child_dist([Role, Host, Port, Socket, + {ok, Sender} = tls_sender:start([{spawn_opt, ?DIST_CNTRL_SPAWN_OPTS}]), + {ok, Pid} = tls_connection_sup:start_child_dist([Role, Sender, Host, Port, Socket, Opts, User, CbInfo]), - {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker), + {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, [Pid, Sender], CbModule, Tracker), ssl_connection:handshake(SslSocket, Timeout) catch error:{badmatch, {error, _} = Error} -> @@ -100,24 +107,37 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} = end. %%-------------------------------------------------------------------- --spec start_link(atom(), host(), inet:port_number(), port(), list(), pid(), tuple()) -> +-spec start_link(atom(), pid(), host(), inet:port_number(), port(), list(), pid(), tuple()) -> {ok, pid()} | ignore | {error, reason()}. %% %% Description: Creates a gen_statem process which calls Module:init/1 to %% initialize. %%-------------------------------------------------------------------- -start_link(Role, Host, Port, Socket, Options, User, CbInfo) -> - {ok, proc_lib:spawn_link(?MODULE, init, [[Role, Host, Port, Socket, Options, User, CbInfo]])}. +start_link(Role, Sender, Host, Port, Socket, Options, User, CbInfo) -> + {ok, proc_lib:spawn_link(?MODULE, init, [[Role, Sender, Host, Port, Socket, Options, User, CbInfo]])}. -init([Role, Host, Port, Socket, Options, User, CbInfo]) -> +init([Role, Sender, Host, Port, Socket, {SslOpts, _, _} = Options, User, CbInfo]) -> process_flag(trap_exit, true), - State0 = initial_state(Role, Host, Port, Socket, Options, User, CbInfo), + case SslOpts#ssl_options.erl_dist of + true -> + process_flag(priority, max); + _ -> + ok + end, + State0 = #state{protocol_specific = Map} = initial_state(Role, Sender, + Host, Port, Socket, Options, User, CbInfo), try State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0), - gen_statem:enter_loop(?MODULE, [], init, State) + initialize_tls_sender(State), + gen_statem:enter_loop(?MODULE, [], init, State) catch throw:Error -> - gen_statem:enter_loop(?MODULE, [], error, {Error, State0}) + EState = State0#state{protocol_specific = Map#{error => Error}}, + gen_statem:enter_loop(?MODULE, [], error, EState) end. + +pids(#state{protocol_specific = #{sender := Sender}}) -> + [self(), Sender]. + %%==================================================================== %% State transition handling %%==================================================================== @@ -234,13 +254,15 @@ handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) -> %%==================================================================== %% Handshake handling %%==================================================================== +renegotiation(Pid, WriteState) -> + gen_statem:call(Pid, {user_renegotiate, WriteState}). + renegotiate(#state{role = client} = State, Actions) -> %% Handle same way as if server requested %% the renegotiation Hs0 = ssl_handshake:init_handshake_history(), {next_state, connection, State#state{tls_handshake_history = Hs0}, [{next_event, internal, #hello_request{}} | Actions]}; - renegotiate(#state{role = server, socket = Socket, transport_cb = Transport, @@ -285,6 +307,12 @@ queue_change_cipher(Msg, #state{negotiated_version = Version, State0#state{connection_states = ConnectionStates, flight_buffer = Flight0 ++ [BinChangeCipher]}. +reinit(#state{protocol_specific = #{sender := Sender}, + negotiated_version = Version, + connection_states = #{current_write := Write}} = State) -> + tls_sender:update_connection_state(Sender, Write, Version), + reinit_handshake_data(State). + reinit_handshake_data(State) -> %% premaster_secret, public_key_info and tls_handshake_info %% are only needed during the handshake phase. @@ -306,15 +334,6 @@ empty_connection_state(ConnectionEnd, BeastMitigation) -> %%==================================================================== %% Alert and close handling %%==================================================================== -send_alert(Alert, #state{negotiated_version = Version, - socket = Socket, - transport_cb = Transport, - connection_states = ConnectionStates0} = State0) -> - {BinMsg, ConnectionStates} = - encode_alert(Alert, Version, ConnectionStates0), - send(Transport, Socket, BinMsg), - State0#state{connection_states = ConnectionStates}. - %%-------------------------------------------------------------------- -spec encode_alert(#alert{}, ssl_record:ssl_version(), ssl_record:connection_states()) -> {iolist(), ssl_record:connection_states()}. @@ -323,6 +342,20 @@ send_alert(Alert, #state{negotiated_version = Version, %%-------------------------------------------------------------------- encode_alert(#alert{} = Alert, Version, ConnectionStates) -> tls_record:encode_alert_record(Alert, Version, ConnectionStates). + +send_alert(Alert, #state{negotiated_version = Version, + socket = Socket, + protocol_cb = Connection, + transport_cb = Transport, + connection_states = ConnectionStates0} = StateData0) -> + {BinMsg, ConnectionStates} = + Connection:encode_alert(Alert, Version, ConnectionStates0), + Connection:send(Transport, Socket, BinMsg), + StateData0#state{connection_states = ConnectionStates}. + +send_alert_in_connection(Alert, #state{protocol_specific = #{sender := Sender}}) -> + tls_sender:send_alert(Sender, Alert). + %% User closes or recursive call! close({close, Timeout}, Socket, Transport = gen_tcp, _,_) -> tls_socket:setopts(Transport, Socket, [{active, false}]), @@ -377,8 +410,8 @@ next_record_if_active(State) -> send(Transport, Socket, Data) -> tls_socket:send(Transport, Socket, Data). -socket(Pid, Transport, Socket, Connection, Tracker) -> - tls_socket:socket(Pid, Transport, Socket, Connection, Tracker). +socket(Pids, Transport, Socket, Connection, Tracker) -> + tls_socket:socket(Pids, Transport, Socket, Connection, Tracker). setopts(Transport, Socket, Other) -> tls_socket:setopts(Transport, Socket, Other). @@ -432,17 +465,12 @@ init(Type, Event, State) -> {start, timeout()} | term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- - -error({call, From}, {start, _Timeout}, {Error, State}) -> - ssl_connection:stop_and_reply( - normal, {reply, From, {error, Error}}, State); error({call, From}, {start, _Timeout}, #state{protocol_specific = #{error := Error}} = State) -> ssl_connection:stop_and_reply( normal, {reply, From, {error, Error}}, State); -error({call, _} = Call, Msg, {Error, #state{protocol_specific = Map} = State}) -> - gen_handshake(?FUNCTION_NAME, Call, Msg, - State#state{protocol_specific = Map#{error => Error}}); +error({call, _} = Call, Msg, State) -> + gen_handshake(?FUNCTION_NAME, Call, Msg, State); error(_, _, _) -> {keep_state_and_data, [postpone]}. @@ -452,15 +480,17 @@ error(_, _, _) -> #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- -hello(internal, #client_hello{extensions = Extensions} = Hello, #state{ssl_options = #ssl_options{handshake = hello}, - start_or_recv_from = From} = State) -> - {next_state, user_hello, State#state{start_or_recv_from = undefined, +hello(internal, #client_hello{extensions = Extensions} = Hello, + #state{ssl_options = #ssl_options{handshake = hello}, + start_or_recv_from = From} = State) -> + {next_state, user_hello, State#state{start_or_recv_from = undefined, hello = Hello}, [{reply, From, {ok, ssl_connection:map_extensions(Extensions)}}]}; -hello(internal, #server_hello{extensions = Extensions} = Hello, #state{ssl_options = #ssl_options{handshake = hello}, - start_or_recv_from = From} = State) -> +hello(internal, #server_hello{extensions = Extensions} = Hello, + #state{ssl_options = #ssl_options{handshake = hello}, + start_or_recv_from = From} = State) -> {next_state, user_hello, State#state{start_or_recv_from = undefined, - hello = Hello}, + hello = Hello}, [{reply, From, {ok, ssl_connection:map_extensions(Extensions)}}]}; hello(internal, #client_hello{client_version = ClientVersion} = Hello, #state{connection_states = ConnectionStates0, @@ -544,14 +574,19 @@ cipher(Type, Event, State) -> %%-------------------------------------------------------------------- connection(info, Event, State) -> gen_info(Event, ?FUNCTION_NAME, State); +connection({call, From}, {user_renegotiate, WriteState}, + #state{connection_states = ConnectionStates} = State) -> + {next_state, ?FUNCTION_NAME, State#state{connection_states = ConnectionStates#{current_write => WriteState}}, + [{next_event,{call, From}, renegotiate}]}; connection(internal, #hello_request{}, - #state{role = client, host = Host, port = Port, + #state{role = client, + renegotiation = {Renegotiation, _}, + host = Host, port = Port, session = #session{own_certificate = Cert} = Session0, session_cache = Cache, session_cache_cb = CacheCb, - ssl_options = SslOpts, - connection_states = ConnectionStates0, - renegotiation = {Renegotiation, _}} = State0) -> - Hello = tls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, + ssl_options = SslOpts, + connection_states = ConnectionStates} = State0) -> + Hello = tls_handshake:client_hello(Host, Port, ConnectionStates, SslOpts, Cache, CacheCb, Renegotiation, Cert), {State1, Actions} = send_handshake(Hello, State0), {Record, State} = @@ -560,7 +595,10 @@ connection(internal, #hello_request{}, = Hello#client_hello.session_id}}), next_event(hello, Record, State, Actions); connection(internal, #client_hello{} = Hello, - #state{role = server, allow_renegotiate = true} = State0) -> + #state{role = server, allow_renegotiate = true, connection_states = CS, + %%protocol_cb = Connection, + protocol_specific = #{sender := Sender} + } = State0) -> %% Mitigate Computational DoS attack %% http://www.educatedguesswork.org/2011/10/ssltls_and_computational_dos.html %% http://www.thc.org/thc-ssl-dos/ Rather than disabling client @@ -569,24 +607,21 @@ connection(internal, #client_hello{} = Hello, erlang:send_after(?WAIT_TO_ALLOW_RENEGOTIATION, self(), allow_renegotiate), {Record, State} = next_record(State0#state{allow_renegotiate = false, renegotiation = {true, peer}}), - next_event(hello, Record, State, [{next_event, internal, Hello}]); + {ok, Write} = tls_sender:renegotiate(Sender), + next_event(hello, Record, State#state{connection_states = CS#{current_write => Write}}, + [{next_event, internal, Hello}]); connection(internal, #client_hello{}, - #state{role = server, allow_renegotiate = false} = State0) -> + #state{role = server, allow_renegotiate = false, + protocol_cb = Connection} = State0) -> Alert = ?ALERT_REC(?WARNING, ?NO_RENEGOTIATION), - State1 = send_alert(Alert, State0), - {Record, State} = ssl_connection:prepare_connection(State1, ?MODULE), + send_alert_in_connection(Alert, State0), + State1 = Connection:reinit_handshake_data(State0), + {Record, State} = next_record(State1), next_event(?FUNCTION_NAME, Record, State); connection(Type, Event, State) -> ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE). %%-------------------------------------------------------------------- --spec death_row(gen_statem:event_type(), term(), #state{}) -> - gen_statem:state_function_result(). -%%-------------------------------------------------------------------- -death_row(Type, Event, State) -> - ssl_connection:death_row(Type, Event, State, ?MODULE). - -%%-------------------------------------------------------------------- -spec downgrade(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- @@ -600,6 +635,7 @@ callback_mode() -> state_functions. terminate(Reason, StateName, State) -> + ensure_sender_terminate(Reason, State), catch ssl_connection:terminate(Reason, StateName, State). format_status(Type, Data) -> @@ -611,11 +647,13 @@ code_change(_OldVsn, StateName, State, _) -> %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- -initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, Tracker}, User, +initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Tracker}, User, {CbModule, DataTag, CloseTag, ErrorTag}) -> - #ssl_options{beast_mitigation = BeastMitigation} = SSLOptions, + #ssl_options{beast_mitigation = BeastMitigation, + erl_dist = IsErlDist} = SSLOptions, ConnectionStates = tls_record:init_connection_states(Role, BeastMitigation), + ErlDistData = erl_dist_data(IsErlDist), SessionCacheCb = case application:get_env(ssl, session_cb) of {ok, Cb} when is_atom(Cb) -> Cb; @@ -623,7 +661,7 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, Tracker}, Us ssl_session_cache end, - Monitor = erlang:monitor(process, User), + UserMonitor = erlang:monitor(process, User), #state{socket_options = SocketOptions, ssl_options = SSLOptions, @@ -636,9 +674,10 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, Tracker}, Us host = Host, port = Port, socket = Socket, + erl_dist_data = ErlDistData, connection_states = ConnectionStates, protocol_buffers = #protocol_buffers{}, - user_application = {Monitor, User}, + user_application = {UserMonitor, User}, user_data_buffer = <<>>, session_cache_cb = SessionCacheCb, renegotiation = {false, first}, @@ -646,9 +685,37 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, Tracker}, Us start_or_recv_from = undefined, protocol_cb = ?MODULE, tracker = Tracker, - flight_buffer = [] + flight_buffer = [], + protocol_specific = #{sender => Sender} }. +erl_dist_data(true) -> + #{dist_handle => undefined, + dist_buffer => <<>>}; +erl_dist_data(false) -> + #{}. + +initialize_tls_sender(#state{role = Role, + socket = Socket, + socket_options = SockOpts, + tracker = Tracker, + protocol_cb = Connection, + transport_cb = Transport, + negotiated_version = Version, + ssl_options = #ssl_options{renegotiate_at = RenegotiateAt}, + connection_states = #{current_write := ConnectionWriteState}, + protocol_specific = #{sender := Sender}}) -> + Init = #{current_write => ConnectionWriteState, + role => Role, + socket => Socket, + socket_options => SockOpts, + tracker => Tracker, + protocol_cb => Connection, + transport_cb => Transport, + negotiated_version => Version, + renegotiate_at => RenegotiateAt}, + tls_sender:initialize(Sender, Init). + next_tls_record(Data, StateName, #state{protocol_buffers = #protocol_buffers{tls_record_buffer = Buf0, tls_cipher_texts = CT0} = Buffers} @@ -720,6 +787,9 @@ handle_info({CloseTag, Socket}, StateName, %% and then receive the final message. next_event(StateName, no_record, State) end; +handle_info({'EXIT', Pid, Reason}, _, + #state{protocol_specific = Pid} = State) -> + {stop, {shutdown, sender_died, Reason}, State}; handle_info(Msg, StateName, State) -> ssl_connection:StateName(info, Msg, State, ?MODULE). @@ -788,7 +858,8 @@ unprocessed_events(Events) -> erlang:length(Events)-1. -assert_buffer_sanity(<<?BYTE(_Type), ?UINT24(Length), Rest/binary>>, #ssl_options{max_handshake_size = Max}) when +assert_buffer_sanity(<<?BYTE(_Type), ?UINT24(Length), Rest/binary>>, + #ssl_options{max_handshake_size = Max}) when Length =< Max -> case size(Rest) of N when N < Length -> @@ -808,3 +879,16 @@ assert_buffer_sanity(Bin, _) -> throw(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data)) end. + +ensure_sender_terminate(downgrade, _) -> + ok; %% Do not terminate sender during downgrade phase +ensure_sender_terminate(_, #state{protocol_specific = #{sender := Sender}}) -> + %% Make sure TLS sender dies when connection process is terminated normally + %% This is needed if the tls_sender is blocked in prim_inet:send + Kill = fun() -> + receive + after 5000 -> + catch (exit(Sender, kill)) + end + end, + spawn(Kill). |