diff options
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 75 |
1 files changed, 67 insertions, 8 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index 53b46542e7..2a0b2b317d 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -36,7 +36,7 @@ %% Internal application API %% Setup --export([start_fsm/8, start_link/7, init/1]). +-export([start_fsm/8, start_link/7, init/1, pids/1]). %% State transition handling -export([next_record/1, next_event/3, next_event/4, handle_common_event/4]). @@ -44,10 +44,10 @@ %% Handshake handling -export([renegotiate/2, send_handshake/2, queue_handshake/2, queue_change_cipher/2, - reinit_handshake_data/1, select_sni_extension/1, empty_connection_state/2]). + reinit/1, reinit_handshake_data/1, select_sni_extension/1, empty_connection_state/2]). %% Alert and close handling --export([encode_alert/3,send_alert/2, close/5, protocol_name/0]). +-export([encode_alert/3, send_alert/2, send_alert_in_connection/2, close/5, protocol_name/0]). %% Data handling -export([encode_data/3, passive_receive/2, next_record_if_active/1, @@ -72,7 +72,7 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_, Tracker} try {ok, Pid} = dtls_connection_sup:start_child([Role, Host, Port, Socket, Opts, User, CbInfo]), - {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker), + {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, [Pid], CbModule, Tracker), ssl_connection:handshake(SslSocket, Timeout) catch error:{badmatch, {error, _} = Error} -> @@ -91,14 +91,19 @@ start_link(Role, Host, Port, Socket, Options, User, CbInfo) -> init([Role, Host, Port, Socket, Options, User, CbInfo]) -> process_flag(trap_exit, true), - State0 = initial_state(Role, Host, Port, Socket, Options, User, CbInfo), + State0 = #state{protocol_specific = Map} = initial_state(Role, Host, Port, Socket, Options, User, CbInfo), try State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0), gen_statem:enter_loop(?MODULE, [], init, State) catch throw:Error -> - gen_statem:enter_loop(?MODULE, [], error, {Error,State0}) + EState = State0#state{protocol_specific = Map#{error => Error}}, + gen_statem:enter_loop(?MODULE, [], error, EState) end. + +pids(_) -> + [self()]. + %%==================================================================== %% State transition handling %%==================================================================== @@ -327,10 +332,14 @@ queue_change_cipher(ChangeCipher, #state{flight_buffer = Flight, dtls_record:next_epoch(ConnectionStates0, write), State#state{flight_buffer = Flight#{change_cipher_spec => ChangeCipher}, connection_states = ConnectionStates}. + +reinit(State) -> + %% To be API compatible with TLS NOOP here + reinit_handshake_data(State). reinit_handshake_data(#state{protocol_buffers = Buffers} = State) -> State#state{premaster_secret = undefined, public_key_info = undefined, - tls_handshake_history = ssl_handshake:init_handshake_history(), + tls_handshake_history = ssl_handshake:init_handshake_history(), flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}, flight_buffer = new_flight(), protocol_buffers = @@ -364,6 +373,10 @@ send_alert(Alert, #state{negotiated_version = Version, send(Transport, Socket, BinMsg), State0#state{connection_states = ConnectionStates}. +send_alert_in_connection(Alert, State) -> + _ = send_alert(Alert, State), + ok. + close(downgrade, _,_,_,_) -> ok; %% Other @@ -470,7 +483,8 @@ init(Type, Event, State) -> %%-------------------------------------------------------------------- error(enter, _, State) -> {keep_state, State}; -error({call, From}, {start, _Timeout}, {Error, State}) -> +error({call, From}, {start, _Timeout}, + #state{protocol_specific = #{error := Error}} = State) -> ssl_connection:stop_and_reply( normal, {reply, From, {error, Error}}, State); error({call, _} = Call, Msg, State) -> @@ -708,6 +722,12 @@ connection(internal, #client_hello{}, #state{role = server, allow_renegotiate = State1 = send_alert(Alert, State0), {Record, State} = ssl_connection:prepare_connection(State1, ?MODULE), next_event(?FUNCTION_NAME, Record, State); +connection({call, From}, {application_data, Data}, State) -> + try + send_application_data(Data, From, ?FUNCTION_NAME, State) + catch throw:Error -> + ssl_connection:hibernate_after(?FUNCTION_NAME, State, [{reply, From, Error}]) + end; connection(Type, Event, State) -> ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE). @@ -1129,3 +1149,42 @@ log_ignore_alert(true, StateName, Alert, Role) -> [Role, StateName, Txt]); log_ignore_alert(false, _, _,_) -> ok. + +send_application_data(Data, From, _StateName, + #state{socket = Socket, + negotiated_version = Version, + protocol_cb = Connection, + transport_cb = Transport, + connection_states = ConnectionStates0, + ssl_options = #ssl_options{renegotiate_at = RenegotiateAt}} = State0) -> + + case time_to_renegotiate(Data, ConnectionStates0, RenegotiateAt) of + true -> + renegotiate(State0#state{renegotiation = {true, internal}}, + [{next_event, {call, From}, {application_data, Data}}]); + false -> + {Msgs, ConnectionStates} = + Connection:encode_data(Data, Version, ConnectionStates0), + State = State0#state{connection_states = ConnectionStates}, + case Connection:send(Transport, Socket, Msgs) of + ok -> + ssl_connection:hibernate_after(connection, State, [{reply, From, ok}]); + Result -> + ssl_connection:hibernate_after(connection, State, [{reply, From, Result}]) + end + end. + +time_to_renegotiate(_Data, + #{current_write := #{sequence_number := Num}}, + RenegotiateAt) -> + + %% We could do test: + %% is_time_to_renegotiate((erlang:byte_size(_Data) div + %% ?MAX_PLAIN_TEXT_LENGTH) + 1, RenegotiateAt), but we chose to + %% have a some what lower renegotiateAt and a much cheaper test + is_time_to_renegotiate(Num, RenegotiateAt). + +is_time_to_renegotiate(N, M) when N < M-> + false; +is_time_to_renegotiate(_,_) -> + true. |