aboutsummaryrefslogtreecommitdiffstats
path: root/lib/ssl/src/dtls_connection.erl
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r--lib/ssl/src/dtls_connection.erl215
1 files changed, 119 insertions, 96 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index 070a90d481..440607e99d 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -39,7 +39,7 @@
-export([start_fsm/8, start_link/7, init/1]).
%% State transition handling
--export([next_record/1, next_event/3]).
+-export([next_record/1, next_event/3, next_event/4]).
%% Handshake handling
-export([renegotiate/2,
@@ -53,7 +53,7 @@
%% Data handling
-export([encode_data/3, passive_receive/2, next_record_if_active/1, handle_common_event/4,
- send/3]).
+ send/3, socket/5]).
%% gen_statem state functions
-export([init/3, error/3, downgrade/3, %% Initiation and take down states
@@ -77,20 +77,6 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_, Tracker}
catch
error:{badmatch, {error, _} = Error} ->
Error
- end;
-
-start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} = Opts,
- User, {CbModule, _,_, _} = CbInfo,
- Timeout) ->
- try
- {ok, Pid} = dtls_connection_sup:start_child_dist([Role, Host, Port, Socket,
- Opts, User, CbInfo]),
- {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker),
- ok = ssl_connection:handshake(SslSocket, Timeout),
- {ok, SslSocket}
- catch
- error:{badmatch, {error, _} = Error} ->
- Error
end.
send_handshake(Handshake, #state{connection_states = ConnectionStates} = States) ->
@@ -201,6 +187,7 @@ reinit_handshake_data(#state{protocol_buffers = Buffers} = State) ->
State#state{premaster_secret = undefined,
public_key_info = undefined,
tls_handshake_history = ssl_handshake:init_handshake_history(),
+ flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT},
protocol_buffers =
Buffers#protocol_buffers{
dtls_handshake_next_seq = 0,
@@ -213,6 +200,9 @@ select_sni_extension(#client_hello{extensions = HelloExtensions}) ->
select_sni_extension(_) ->
undefined.
+socket(Pid, Transport, Socket, Connection, _) ->
+ dtls_socket:socket(Pid, Transport, Socket, Connection).
+
%%====================================================================
%% tls_connection_sup API
%%====================================================================
@@ -243,7 +233,7 @@ callback_mode() ->
state_functions.
%%--------------------------------------------------------------------
-%% State functionsconnection/2
+%% State functions
%%--------------------------------------------------------------------
init({call, From}, {start, Timeout},
@@ -262,18 +252,25 @@ init({call, From}, {start, Timeout},
Version = Hello#client_hello.client_version,
HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions),
State1 = prepare_flight(State0#state{negotiated_version = Version}),
- State2 = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),
+ {State2, Actions} = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),
State3 = State2#state{negotiated_version = Version, %% Requested version
session =
Session0#session{session_id = Hello#client_hello.session_id},
start_or_recv_from = From,
- timer = Timer},
+ timer = Timer,
+ flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}
+ },
{Record, State} = next_record(State3),
- next_event(hello, Record, State);
+ next_event(hello, Record, State, Actions);
init({call, _} = Type, Event, #state{role = server, transport_cb = gen_udp} = State) ->
- ssl_connection:init(Type, Event,
- State#state{flight_state = {waiting, undefined, ?INITIAL_RETRANSMIT_TIMEOUT}},
- ?MODULE);
+ Result = ssl_connection:init(Type, Event,
+ State#state{flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT},
+ protocol_specific = #{current_cookie_secret => dtls_v1:cookie_secret(),
+ previous_cookie_secret => <<>>}},
+ ?MODULE),
+ erlang:send_after(dtls_v1:cookie_timeout(), self(), new_cookie_secret),
+ Result;
+
init({call, _} = Type, Event, #state{role = server} = State) ->
%% I.E. DTLS over sctp
ssl_connection:init(Type, Event, State#state{flight_state = reliable}, ?MODULE);
@@ -296,26 +293,32 @@ error(_, _, _) ->
hello(internal, #client_hello{cookie = <<>>,
client_version = Version} = Hello, #state{role = server,
transport_cb = Transport,
- socket = Socket} = State0) ->
- %% TODO: not hard code key
+ socket = Socket,
+ protocol_specific = #{current_cookie_secret := Secret}} = State0) ->
{ok, {IP, Port}} = dtls_socket:peername(Transport, Socket),
- Cookie = dtls_handshake:cookie(<<"secret">>, IP, Port, Hello),
+ Cookie = dtls_handshake:cookie(Secret, IP, Port, Hello),
VerifyRequest = dtls_handshake:hello_verify_request(Cookie, Version),
State1 = prepare_flight(State0#state{negotiated_version = Version}),
- State2 = send_handshake(VerifyRequest, State1),
+ {State2, Actions} = send_handshake(VerifyRequest, State1),
{Record, State} = next_record(State2),
- next_event(hello, Record, State#state{tls_handshake_history = ssl_handshake:init_handshake_history()});
+ next_event(hello, Record, State#state{tls_handshake_history = ssl_handshake:init_handshake_history()}, Actions);
hello(internal, #client_hello{cookie = Cookie} = Hello, #state{role = server,
transport_cb = Transport,
- socket = Socket} = State0) ->
+ socket = Socket,
+ protocol_specific = #{current_cookie_secret := Secret,
+ previous_cookie_secret := PSecret}} = State0) ->
{ok, {IP, Port}} = dtls_socket:peername(Transport, Socket),
- %% TODO: not hard code key
- case dtls_handshake:cookie(<<"secret">>, IP, Port, Hello) of
+ case dtls_handshake:cookie(Secret, IP, Port, Hello) of
Cookie ->
handle_client_hello(Hello, State0);
_ ->
- %% Handle bad cookie as new cookie request RFC 6347 4.1.2
- hello(internal, Hello#client_hello{cookie = <<>>}, State0)
+ case dtls_handshake:cookie(PSecret, IP, Port, Hello) of
+ Cookie ->
+ handle_client_hello(Hello, State0);
+ _ ->
+ %% Handle bad cookie as new cookie request RFC 6347 4.1.2
+ hello(internal, Hello#client_hello{cookie = <<>>}, State0)
+ end
end;
hello(internal, #hello_verify_request{cookie = Cookie}, #state{role = client,
host = Host, port = Port,
@@ -333,13 +336,13 @@ hello(internal, #hello_verify_request{cookie = Cookie}, #state{role = client,
Cache, CacheCb, Renegotiation, OwnCert),
Version = Hello#client_hello.client_version,
HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions),
- State2 = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),
+ {State2, Actions} = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),
State3 = State2#state{negotiated_version = Version, %% Requested version
session =
Session0#session{session_id =
Hello#client_hello.session_id}},
{Record, State} = next_record(State3),
- next_event(hello, Record, State);
+ next_event(hello, Record, State, Actions);
hello(internal, #server_hello{} = Hello,
#state{connection_states = ConnectionStates0,
negotiated_version = ReqVersion,
@@ -356,13 +359,13 @@ hello(internal, #server_hello{} = Hello,
hello(internal, {handshake, {#client_hello{cookie = <<>>} = Handshake, _}}, State) ->
%% Initial hello should not be in handshake history
{next_state, hello, State, [{next_event, internal, Handshake}]};
-
hello(internal, {handshake, {#hello_verify_request{} = Handshake, _}}, State) ->
%% hello_verify should not be in handshake history
{next_state, hello, State, [{next_event, internal, Handshake}]};
-
hello(info, Event, State) ->
handle_info(Event, hello, State);
+hello(state_timeout, Event, State) ->
+ handle_state_timeout(Event, hello, State);
hello(Type, Event, State) ->
ssl_connection:hello(Type, Event, State, ?MODULE).
@@ -375,7 +378,11 @@ abbreviated(internal = Type,
ConnectionStates = dtls_record:next_epoch(ConnectionStates1, read),
ssl_connection:abbreviated(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE);
abbreviated(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) ->
- ssl_connection:cipher(Type, Event, prepare_flight(State#state{connection_states = ConnectionStates}), ?MODULE);
+ ssl_connection:abbreviated(Type, Event,
+ prepare_flight(State#state{connection_states = ConnectionStates,
+ flight_state = connection}), ?MODULE);
+abbreviated(state_timeout, Event, State) ->
+ handle_state_timeout(Event, abbreviated, State);
abbreviated(Type, Event, State) ->
ssl_connection:abbreviated(Type, Event, State, ?MODULE).
@@ -383,6 +390,8 @@ certify(info, Event, State) ->
handle_info(Event, certify, State);
certify(internal = Type, #server_hello_done{} = Event, State) ->
ssl_connection:certify(Type, Event, prepare_flight(State), ?MODULE);
+certify(state_timeout, Event, State) ->
+ handle_state_timeout(Event, certify, State);
certify(Type, Event, State) ->
ssl_connection:certify(Type, Event, State, ?MODULE).
@@ -395,7 +404,11 @@ cipher(internal = Type, #change_cipher_spec{type = <<1>>} = Event,
ssl_connection:cipher(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE);
cipher(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) ->
ssl_connection:cipher(Type, Event,
- prepare_flight(State#state{connection_states = ConnectionStates}), ?MODULE);
+ prepare_flight(State#state{connection_states = ConnectionStates,
+ flight_state = connection}),
+ ?MODULE);
+cipher(state_timeout, Event, State) ->
+ handle_state_timeout(Event, cipher, State);
cipher(Type, Event, State) ->
ssl_connection:cipher(Type, Event, State, ?MODULE).
@@ -409,12 +422,12 @@ connection(internal, #hello_request{}, #state{host = Host, port = Port,
renegotiation = {Renegotiation, _}} = State0) ->
Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts,
Cache, CacheCb, Renegotiation, Cert),
- State1 = send_handshake(Hello, State0),
+ {State1, Actions} = send_handshake(Hello, State0),
{Record, State} =
next_record(
State1#state{session = Session0#session{session_id
= Hello#client_hello.session_id}}),
- next_event(hello, Record, State);
+ next_event(hello, Record, State, Actions);
connection(internal, #client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) ->
%% Mitigate Computational DoS attack
%% http://www.educatedguesswork.org/2011/10/ssltls_and_computational_dos.html
@@ -434,7 +447,6 @@ connection(Type, Event, State) ->
downgrade(Type, Event, State) ->
ssl_connection:downgrade(Type, Event, State, ?MODULE).
-
%%--------------------------------------------------------------------
%% Description: This function is called by a gen_fsm when it receives any
%% other message than a synchronous or asynchronous event
@@ -442,16 +454,6 @@ downgrade(Type, Event, State) ->
%%--------------------------------------------------------------------
%% raw data from socket, unpack records
-handle_info({_,flight_retransmission_timeout}, connection, _) ->
- {next_state, keep_state_and_data};
-handle_info({Ref, flight_retransmission_timeout}, StateName,
- #state{flight_state = {waiting, Ref, NextTimeout}} = State0) ->
- State1 = send_handshake_flight(State0#state{flight_state = {retransmit_timer, NextTimeout}},
- retransmit_epoch(StateName, State0)),
- {Record, State} = next_record(State1),
- next_event(StateName, Record, State);
-handle_info({_, flight_retransmission_timeout}, _, _) ->
- {next_state, keep_state_and_data};
handle_info({Protocol, _, _, _, Data}, StateName,
#state{data_tag = Protocol} = State0) ->
case next_dtls_record(Data, State0) of
@@ -462,34 +464,52 @@ handle_info({Protocol, _, _, _, Data}, StateName,
{stop, {shutdown, own_alert}}
end;
handle_info({CloseTag, Socket}, StateName,
- #state{socket = Socket, close_tag = CloseTag,
+ #state{socket = Socket,
+ socket_options = #socket_options{active = Active},
+ protocol_buffers = #protocol_buffers{dtls_cipher_texts = CTs},
+ close_tag = CloseTag,
negotiated_version = Version} = State) ->
%% Note that as of DTLS 1.2 (TLS 1.1),
%% failure to properly close a connection no longer requires that a
%% session not be resumed. This is a change from DTLS 1.0 to conform
%% with widespread implementation practice.
- case Version of
- {254, N} when N =< 253 ->
- ok;
- _ ->
- %% As invalidate_sessions here causes performance issues,
- %% we will conform to the widespread implementation
- %% practice and go aginst the spec
- %%invalidate_session(Role, Host, Port, Session)
- ok
- end,
- ssl_connection:handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State),
- {stop, {shutdown, transport_closed}};
+ case (Active == false) andalso (CTs =/= []) of
+ false ->
+ case Version of
+ {254, N} when N =< 253 ->
+ ok;
+ _ ->
+ %% As invalidate_sessions here causes performance issues,
+ %% we will conform to the widespread implementation
+ %% practice and go aginst the spec
+ %%invalidate_session(Role, Host, Port, Session)
+ ok
+ end,
+ ssl_connection:handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State),
+ {stop, {shutdown, transport_closed}};
+ true ->
+ %% Fixes non-delivery of final DTLS record in {active, once}.
+ %% Basically allows the application the opportunity to set {active, once} again
+ %% and then receive the final message.
+ next_event(StateName, no_record, State)
+ end;
+
+handle_info(new_cookie_secret, StateName,
+ #state{protocol_specific = #{current_cookie_secret := Secret} = CookieInfo} = State) ->
+ erlang:send_after(dtls_v1:cookie_timeout(), self(), new_cookie_secret),
+ {next_state, StateName, State#state{protocol_specific =
+ CookieInfo#{current_cookie_secret => dtls_v1:cookie_secret(),
+ previous_cookie_secret => Secret}}};
handle_info(Msg, StateName, State) ->
ssl_connection:handle_info(Msg, StateName, State).
+
handle_call(Event, From, StateName, State) ->
ssl_connection:handle_call(Event, From, StateName, State, ?MODULE).
handle_common_event(internal, #alert{} = Alert, StateName,
#state{negotiated_version = Version} = State) ->
ssl_connection:handle_own_alert(Alert, Version, StateName, State);
-
%%% DTLS record protocol level handshake messages
handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE,
fragment = Data},
@@ -498,19 +518,14 @@ handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE,
negotiated_version = Version} = State0) ->
try
case dtls_handshake:get_dtls_handshake(Version, Data, Buffers0) of
- {more_data, Buffers} ->
+ {[], Buffers} ->
{Record, State} = next_record(State0#state{protocol_buffers = Buffers}),
next_event(StateName, Record, State);
{Packets, Buffers} ->
State = State0#state{protocol_buffers = Buffers},
Events = dtls_handshake_events(Packets),
- case StateName of
- connection ->
- ssl_connection:hibernate_after(StateName, State, Events);
- _ ->
- {next_state, StateName,
- State#state{unprocessed_handshake_events = unprocessed_events(Events)}, Events}
- end
+ {next_state, StateName,
+ State#state{unprocessed_handshake_events = unprocessed_events(Events)}, Events}
end
catch throw:#alert{} = Alert ->
ssl_connection:handle_own_alert(Alert, Version, StateName, State0)
@@ -534,6 +549,13 @@ handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, Sta
handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) ->
{next_state, StateName, State}.
+handle_state_timeout(flight_retransmission_timeout, StateName,
+ #state{flight_state = {retransmit, NextTimeout}} = State0) ->
+ {State1, Actions} = send_handshake_flight(State0#state{flight_state = {retransmit, NextTimeout}},
+ retransmit_epoch(StateName, State0)),
+ {Record, State} = next_record(State1),
+ next_event(StateName, Record, State, Actions).
+
send(Transport, {_, {{_,_}, _} = Socket}, Data) ->
send(Transport, Socket, Data);
send(Transport, Socket, Data) ->
@@ -645,7 +667,8 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, _}, User,
allow_renegotiate = SSLOptions#ssl_options.client_renegotiation,
start_or_recv_from = undefined,
protocol_cb = ?MODULE,
- flight_buffer = new_flight()
+ flight_buffer = new_flight(),
+ flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}
}.
next_dtls_record(Data, #state{protocol_buffers = #protocol_buffers{
@@ -714,14 +737,14 @@ next_event(connection = StateName, no_record,
#state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->
case next_record_if_active(State0) of
{no_record, State} ->
- ssl_connection:hibernate_after(StateName, State, Actions);
+ ssl_connection:hibernate_after(StateName, State, Actions);
{#ssl_tls{epoch = CurrentEpoch} = Record, State} ->
{next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]};
{#ssl_tls{epoch = Epoch,
type = ?HANDSHAKE,
version = _Version}, State1} = _Record when Epoch == CurrentEpoch-1 ->
- State = send_handshake_flight(State1, Epoch),
- {next_state, StateName, State, Actions};
+ {State, MoreActions} = send_handshake_flight(State1, Epoch),
+ {next_state, StateName, State, Actions ++ MoreActions};
{#ssl_tls{epoch = _Epoch,
version = _Version}, State} ->
%% TODO maybe buffer later epoch
@@ -772,17 +795,20 @@ next_flight(Flight) ->
Flight#{handshakes => [],
change_cipher_spec => undefined,
handshakes_after_change_cipher_spec => []}.
-
start_flight(#state{transport_cb = gen_udp,
- flight_state = {retransmit_timer, Timeout}} = State) ->
- Ref = erlang:make_ref(),
- _ = erlang:send_after(Timeout, self(), {Ref, flight_retransmission_timeout}),
- State#state{flight_state = {waiting, Ref, new_timeout(Timeout)}};
-
+ flight_state = {retransmit, Timeout}} = State) ->
+ start_retransmision_timer(Timeout, State);
+start_flight(#state{transport_cb = gen_udp,
+ flight_state = connection} = State) ->
+ {State, []};
start_flight(State) ->
%% No retransmision needed i.e DTLS over SCTP
- State#state{flight_state = reliable}.
+ {State#state{flight_state = reliable}, []}.
+
+start_retransmision_timer(Timeout, State) ->
+ {State#state{flight_state = {retransmit, new_timeout(Timeout)}},
+ [{state_timeout, Timeout, flight_retransmission_timeout}]}.
new_timeout(N) when N =< 30 ->
N * 2;
@@ -806,13 +832,13 @@ renegotiate(#state{role = server,
connection_states = CS0} = State0, Actions) ->
HelloRequest = ssl_handshake:hello_request(),
CS = CS0#{write_msg_seq => 0},
- State1 = send_handshake(HelloRequest,
- State0#state{connection_states =
- CS}),
+ {State1, MoreActions} = send_handshake(HelloRequest,
+ State0#state{connection_states =
+ CS}),
Hs0 = ssl_handshake:init_handshake_history(),
{Record, State} = next_record(State1#state{tls_handshake_history = Hs0,
protocol_buffers = #protocol_buffers{}}),
- next_event(hello, Record, State, Actions).
+ next_event(hello, Record, State, Actions ++ MoreActions).
handle_alerts([], Result) ->
Result;
@@ -823,15 +849,11 @@ handle_alerts([Alert | Alerts], {next_state, StateName, State}) ->
handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) ->
handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)).
-retransmit_epoch(StateName, #state{connection_states = ConnectionStates}) ->
+retransmit_epoch(_StateName, #state{connection_states = ConnectionStates}) ->
#{epoch := Epoch} =
ssl_record:current_connection_state(ConnectionStates, write),
- case StateName of
- connection ->
- Epoch-1;
- _ ->
- Epoch
- end.
+ Epoch.
+
update_handshake_history(#hello_verify_request{}, _, Hist) ->
Hist;
@@ -846,3 +868,4 @@ unprocessed_events(Events) ->
%% handshake events left to process before we should
%% process more TLS-records received on the socket.
erlang:length(Events)-1.
+