diff options
Diffstat (limited to 'lib/ssl/src/tls_sender.erl')
-rw-r--r-- | lib/ssl/src/tls_sender.erl | 149 |
1 files changed, 117 insertions, 32 deletions
diff --git a/lib/ssl/src/tls_sender.erl b/lib/ssl/src/tls_sender.erl index db67d7ddff..11fcc6def0 100644 --- a/lib/ssl/src/tls_sender.erl +++ b/lib/ssl/src/tls_sender.erl @@ -28,7 +28,8 @@ -include("ssl_api.hrl"). %% API --export([start/0, start/1, initialize/2, send_data/2, send_alert/2, renegotiate/1, +-export([start/0, start/1, initialize/2, send_data/2, send_alert/2, + send_and_ack_alert/2, setopts/2, renegotiate/1, peer_renegotiate/1, downgrade/2, update_connection_state/3, dist_tls_socket/1, dist_handshake_complete/3]). %% gen_statem callbacks @@ -80,7 +81,7 @@ initialize(Pid, InitMsg) -> gen_statem:call(Pid, {self(), InitMsg}). %%-------------------------------------------------------------------- --spec send_data(pid(), iodata()) -> ok. +-spec send_data(pid(), iodata()) -> ok | {error, term()}. %% Description: Send application data %%-------------------------------------------------------------------- send_data(Pid, AppData) -> @@ -89,13 +90,27 @@ send_data(Pid, AppData) -> %%-------------------------------------------------------------------- -spec send_alert(pid(), #alert{}) -> _. -%% Description: TLS connection process wants to end an Alert +%% Description: TLS connection process wants to send an Alert %% in the connection state. %%-------------------------------------------------------------------- send_alert(Pid, Alert) -> gen_statem:cast(Pid, Alert). %%-------------------------------------------------------------------- +-spec send_and_ack_alert(pid(), #alert{}) -> _. +%% Description: TLS connection process wants to send an Alert +%% in the connection state and recive an ack. +%%-------------------------------------------------------------------- +send_and_ack_alert(Pid, Alert) -> + gen_statem:call(Pid, {ack_alert, Alert}, ?DEFAULT_TIMEOUT). +%%-------------------------------------------------------------------- +-spec setopts(pid(), [{packet, integer() | atom()}]) -> ok | {error, term()}. +%% Description: Send application data +%%-------------------------------------------------------------------- +setopts(Pid, Opts) -> + call(Pid, {set_opts, Opts}). + +%%-------------------------------------------------------------------- -spec renegotiate(pid()) -> {ok, WriteState::map()} | {error, closed}. %% Description: So TLS connection process can synchronize the %% encryption state to be used when handshaking. @@ -103,6 +118,15 @@ send_alert(Pid, Alert) -> renegotiate(Pid) -> %% Needs error handling for external API call(Pid, renegotiate). + +%%-------------------------------------------------------------------- +-spec peer_renegotiate(pid()) -> {ok, WriteState::map()} | {error, term()}. +%% Description: So TLS connection process can synchronize the +%% encryption state to be used when handshaking. +%%-------------------------------------------------------------------- +peer_renegotiate(Pid) -> + gen_statem:call(Pid, renegotiate, ?DEFAULT_TIMEOUT). + %%-------------------------------------------------------------------- -spec update_connection_state(pid(), WriteState::map(), tls_record:tls_version()) -> ok. %% Description: So TLS connection process can synchronize the @@ -110,6 +134,21 @@ renegotiate(Pid) -> %%-------------------------------------------------------------------- update_connection_state(Pid, NewState, Version) -> gen_statem:cast(Pid, {new_write, NewState, Version}). + +%%-------------------------------------------------------------------- +-spec downgrade(pid(), integer()) -> {ok, ssl_record:connection_state()} + | {error, timeout}. +%% Description: So TLS connection process can synchronize the +%% encryption state to be used when sending application data. +%%-------------------------------------------------------------------- +downgrade(Pid, Timeout) -> + try gen_statem:call(Pid, downgrade, Timeout) of + Result -> + Result + catch + _:_ -> + {error, timeout} + end. %%-------------------------------------------------------------------- -spec dist_handshake_complete(pid(), node(), term()) -> ok. %% Description: Erlang distribution callback @@ -185,13 +224,16 @@ connection({call, From}, renegotiate, #data{connection_states = #{current_write := Write}} = StateData) -> {next_state, handshake, StateData, [{reply, From, {ok, Write}}]}; connection({call, From}, {application_data, AppData}, - #data{socket_options = SockOpts} = StateData) -> - case encode_packet(AppData, SockOpts) of + #data{socket_options = #socket_options{packet = Packet}} = + StateData) -> + case encode_packet(Packet, AppData) of {error, _} = Error -> {next_state, ?FUNCTION_NAME, StateData, [{reply, From, Error}]}; Data -> send_application_data(Data, From, ?FUNCTION_NAME, StateData) end; +connection({call, From}, {set_opts, _} = Call, StateData) -> + handle_call(From, Call, ?FUNCTION_NAME, StateData); connection({call, From}, dist_get_tls_socket, #data{protocol_cb = Connection, transport_cb = Transport, @@ -200,13 +242,33 @@ connection({call, From}, dist_get_tls_socket, tracker = Tracker} = StateData) -> TLSSocket = Connection:socket([Pid, self()], Transport, Socket, Connection, Tracker), {next_state, ?FUNCTION_NAME, StateData, [{reply, From, {ok, TLSSocket}}]}; -connection({call, From}, {dist_handshake_complete, _Node, DHandle}, #data{connection_pid = Pid} = StateData) -> +connection({call, From}, {dist_handshake_complete, _Node, DHandle}, + #data{connection_pid = Pid, + socket_options = #socket_options{packet = Packet}} = + StateData) -> ok = erlang:dist_ctrl_input_handler(DHandle, Pid), ok = ssl_connection:dist_handshake_complete(Pid, DHandle), %% From now on we execute on normal priority process_flag(priority, normal), - Events = dist_data_events(DHandle, []), - {next_state, ?FUNCTION_NAME, StateData#data{dist_handle = DHandle}, [{reply, From, ok} | Events]}; + {next_state, ?FUNCTION_NAME, StateData#data{dist_handle = DHandle}, + [{reply, From, ok} + | case dist_data(DHandle, Packet) of + [] -> + []; + Data -> + [{next_event, internal, + {application_packets,{self(),undefined},Data}}] + end]}; +connection({call, From}, {ack_alert, #alert{} = Alert}, StateData0) -> + StateData = send_tls_alert(Alert, StateData0), + {next_state, ?FUNCTION_NAME, StateData, + [{reply,From,ok}]}; +connection({call, From}, downgrade, #data{connection_states = + #{current_write := Write}} = StateData) -> + {next_state, death_row, StateData, [{reply,From, {ok, Write}}]}; +connection(internal, {application_packets, From, Data}, StateData) -> + send_application_data(Data, From, ?FUNCTION_NAME, StateData); +%% connection(cast, #alert{} = Alert, StateData0) -> StateData = send_tls_alert(Alert, StateData0), {next_state, ?FUNCTION_NAME, StateData}; @@ -216,9 +278,19 @@ connection(cast, {new_write, WritesState, Version}, StateData#data{connection_states = ConnectionStates0#{current_write => WritesState}, negotiated_version = Version}}; -connection(info, dist_data, #data{dist_handle = DHandle} = StateData) -> - Events = dist_data_events(DHandle, []), - {next_state, ?FUNCTION_NAME, StateData, Events}; +%% +connection(info, dist_data, + #data{dist_handle = DHandle, + socket_options = #socket_options{packet = Packet}} = + StateData) -> + {next_state, ?FUNCTION_NAME, StateData, + case dist_data(DHandle, Packet) of + [] -> + []; + Data -> + [{next_event, internal, + {application_packets,{self(),undefined},Data}}] + end}; connection(info, tick, StateData) -> consume_ticks(), {next_state, ?FUNCTION_NAME, StateData, @@ -241,6 +313,8 @@ connection(info, Msg, StateData) -> StateData :: term()) -> gen_statem:event_handler_result(atom()). %%-------------------------------------------------------------------- +handshake({call, From}, {set_opts, _} = Call, StateData) -> + handle_call(From, Call, ?FUNCTION_NAME, StateData); handshake({call, _}, _, _) -> {keep_state_and_data, [postpone]}; handshake(cast, {new_write, WritesState, Version}, @@ -249,6 +323,8 @@ handshake(cast, {new_write, WritesState, Version}, StateData#data{connection_states = ConnectionStates0#{current_write => WritesState}, negotiated_version = Version}}; +handshake(internal, {application_packets,_,_}, _) -> + {keep_state_and_data, [postpone]}; handshake(info, Msg, StateData) -> handle_info(Msg, ?FUNCTION_NAME, StateData). @@ -285,6 +361,9 @@ code_change(_OldVsn, State, Data, _Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== +handle_call(From, {set_opts, Opts}, StateName, #data{socket_options = SockOpts} = StateData) -> + {next_state, StateName, StateData#data{socket_options = set_opts(SockOpts, Opts)}, [{reply, From, ok}]}. + handle_info({'DOWN', Monitor, _, _, Reason}, _, #data{connection_monitor = Monitor, dist_handle = Handle} = StateData) when Handle =/= undefined-> @@ -293,7 +372,7 @@ handle_info({'DOWN', Monitor, _, _, _}, _, #data{connection_monitor = Monitor} = StateData) -> {stop, normal, StateData}; handle_info(_,_,_) -> - {keep_state_and_data}. + keep_state_and_data. send_tls_alert(Alert, #data{negotiated_version = Version, socket = Socket, @@ -316,12 +395,13 @@ send_application_data(Data, From, StateName, renegotiate_at = RenegotiateAt} = StateData0) -> case time_to_renegotiate(Data, ConnectionStates0, RenegotiateAt) of true -> - ssl_connection:internal_renegotiation(Pid, ConnectionStates0), + ssl_connection:internal_renegotiation(Pid, ConnectionStates0), {next_state, handshake, StateData0, - [{next_event, {call, From}, {application_data, Data}}]}; + [{next_event, internal, {application_packets, From, Data}}]}; false -> {Msgs, ConnectionStates} = - Connection:encode_data(Data, Version, ConnectionStates0), + Connection:encode_data( + iolist_to_binary(Data), Version, ConnectionStates0), StateData = StateData0#data{connection_states = ConnectionStates}, case Connection:send(Transport, Socket, Msgs) of ok when DistHandle =/= undefined -> @@ -335,22 +415,23 @@ send_application_data(Data, From, StateName, end end. -encode_packet(Data, #socket_options{packet=Packet}) -> +-compile({inline, encode_packet/2}). +encode_packet(Packet, Data) -> + Len = iolist_size(Data), case Packet of - 1 -> encode_size_packet(Data, 8, (1 bsl 8) - 1); - 2 -> encode_size_packet(Data, 16, (1 bsl 16) - 1); - 4 -> encode_size_packet(Data, 32, (1 bsl 32) - 1); - _ -> Data + 1 when Len < (1 bsl 8) -> [<<Len:8>>,Data]; + 2 when Len < (1 bsl 16) -> [<<Len:16>>,Data]; + 4 when Len < (1 bsl 32) -> [<<Len:32>>,Data]; + N when N =:= 1; N =:= 2; N =:= 4 -> + {error, + {badarg, {packet_to_large, Len, (1 bsl (Packet bsl 3)) - 1}}}; + _ -> + Data end. -encode_size_packet(Bin, Size, Max) -> - Len = erlang:byte_size(Bin), - case Len > Max of - true -> - {error, {badarg, {packet_to_large, Len, Max}}}; - false -> - <<Len:Size, Bin/binary>> - end. +set_opts(SocketOptions, [{packet, N}]) -> + SocketOptions#socket_options{packet = N}. + time_to_renegotiate(_Data, #{current_write := #{sequence_number := Num}}, RenegotiateAt) -> @@ -379,14 +460,18 @@ call(FsmPid, Event) -> %%---------------Erlang distribution -------------------------------------- -dist_data_events(DHandle, Events) -> +dist_data(DHandle, Packet) -> case erlang:dist_ctrl_get_data(DHandle) of none -> erlang:dist_ctrl_get_data_notification(DHandle), - lists:reverse(Events); + []; Data -> - Event = {next_event, {call, {self(), undefined}}, {application_data, Data}}, - dist_data_events(DHandle, [Event | Events]) + %% This is encode_packet(4, Data) without Len check + %% since the emulator will always deliver a Data + %% smaller than 4 GB, and the distribution will + %% therefore always have to use {packet,4} + Len = iolist_size(Data), + [<<Len:32>>,Data|dist_data(DHandle, Packet)] end. consume_ticks() -> |