aboutsummaryrefslogtreecommitdiffstats
path: root/lib/ssl/src/dtls_connection.erl
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r--lib/ssl/src/dtls_connection.erl1164
1 files changed, 717 insertions, 447 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index 745db788e7..03725089dd 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -39,21 +39,19 @@
-export([start_fsm/8, start_link/7, init/1]).
%% State transition handling
--export([next_record/1, next_event/3, next_event/4]).
+-export([next_record/1, next_event/3, next_event/4, handle_common_event/4]).
%% Handshake handling
--export([renegotiate/2,
- reinit_handshake_data/1,
- send_handshake/2, queue_handshake/2, queue_change_cipher/2,
- select_sni_extension/1]).
+-export([renegotiate/2, send_handshake/2,
+ queue_handshake/2, queue_change_cipher/2,
+ reinit_handshake_data/1, select_sni_extension/1, empty_connection_state/2]).
%% Alert and close handling
--export([encode_alert/3,send_alert/2, close/5]).
+-export([encode_alert/3,send_alert/2, close/5, protocol_name/0]).
%% Data handling
-
--export([encode_data/3, passive_receive/2, next_record_if_active/1, handle_common_event/4,
- send/3, socket/5]).
+-export([encode_data/3, passive_receive/2, next_record_if_active/1,
+ send/3, socket/5, setopts/3, getopts/3]).
%% gen_statem state functions
-export([init/3, error/3, downgrade/3, %% Initiation and take down states
@@ -64,6 +62,9 @@
%%====================================================================
%% Internal application API
+%%====================================================================
+%%====================================================================
+%% Setup
%%====================================================================
start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_, Tracker} = Opts,
User, {CbModule, _,_, _} = CbInfo,
@@ -79,9 +80,227 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_, Tracker}
Error
end.
-send_handshake(Handshake, #state{connection_states = ConnectionStates} = States) ->
+%%--------------------------------------------------------------------
+-spec start_link(atom(), host(), inet:port_number(), port(), list(), pid(), tuple()) ->
+ {ok, pid()} | ignore | {error, reason()}.
+%%
+%% Description: Creates a gen_statem process which calls Module:init/1 to
+%% initialize.
+%%--------------------------------------------------------------------
+start_link(Role, Host, Port, Socket, Options, User, CbInfo) ->
+ {ok, proc_lib:spawn_link(?MODULE, init, [[Role, Host, Port, Socket, Options, User, CbInfo]])}.
+
+init([Role, Host, Port, Socket, Options, User, CbInfo]) ->
+ process_flag(trap_exit, true),
+ State0 = initial_state(Role, Host, Port, Socket, Options, User, CbInfo),
+ try
+ State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0),
+ gen_statem:enter_loop(?MODULE, [], init, State)
+ catch
+ throw:Error ->
+ gen_statem:enter_loop(?MODULE, [], error, {Error,State0})
+ end.
+%%====================================================================
+%% State transition handling
+%%====================================================================
+next_record(#state{unprocessed_handshake_events = N} = State) when N > 0 ->
+ {no_record, State#state{unprocessed_handshake_events = N-1}};
+
+next_record(#state{protocol_buffers =
+ #protocol_buffers{dtls_cipher_texts = [#ssl_tls{epoch = Epoch} = CT | Rest]}
+ = Buffers,
+ connection_states = #{current_read := #{epoch := Epoch}} = ConnectionStates} = State) ->
+ CurrentRead = dtls_record:get_connection_state_by_epoch(Epoch, ConnectionStates, read),
+ case dtls_record:replay_detect(CT, CurrentRead) of
+ false ->
+ decode_cipher_text(State#state{connection_states = ConnectionStates}) ;
+ true ->
+ %% Ignore replayed record
+ next_record(State#state{protocol_buffers =
+ Buffers#protocol_buffers{dtls_cipher_texts = Rest},
+ connection_states = ConnectionStates})
+ end;
+next_record(#state{protocol_buffers =
+ #protocol_buffers{dtls_cipher_texts = [#ssl_tls{epoch = Epoch} | Rest]}
+ = Buffers,
+ connection_states = #{current_read := #{epoch := CurrentEpoch}} = ConnectionStates} = State)
+ when Epoch > CurrentEpoch ->
+ %% TODO Buffer later Epoch message, drop it for now
+ next_record(State#state{protocol_buffers =
+ Buffers#protocol_buffers{dtls_cipher_texts = Rest},
+ connection_states = ConnectionStates});
+next_record(#state{protocol_buffers =
+ #protocol_buffers{dtls_cipher_texts = [ _ | Rest]}
+ = Buffers,
+ connection_states = ConnectionStates} = State) ->
+ %% Drop old epoch message
+ next_record(State#state{protocol_buffers =
+ Buffers#protocol_buffers{dtls_cipher_texts = Rest},
+ connection_states = ConnectionStates});
+next_record(#state{role = server,
+ socket = {Listener, {Client, _}},
+ transport_cb = gen_udp} = State) ->
+ dtls_udp_listener:active_once(Listener, Client, self()),
+ {no_record, State};
+next_record(#state{role = client,
+ socket = {_Server, Socket} = DTLSSocket,
+ close_tag = CloseTag,
+ transport_cb = Transport} = State) ->
+ case dtls_socket:setopts(Transport, Socket, [{active,once}]) of
+ ok ->
+ {no_record, State};
+ _ ->
+ self() ! {CloseTag, DTLSSocket},
+ {no_record, State}
+ end;
+next_record(State) ->
+ {no_record, State}.
+
+next_event(StateName, Record, State) ->
+ next_event(StateName, Record, State, []).
+
+next_event(connection = StateName, no_record,
+ #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->
+ case next_record_if_active(State0) of
+ {no_record, State} ->
+ ssl_connection:hibernate_after(StateName, State, Actions);
+ {#ssl_tls{epoch = CurrentEpoch,
+ type = ?HANDSHAKE,
+ version = Version} = Record, State1} ->
+ State = dtls_version(StateName, Version, State1),
+ {next_state, StateName, State,
+ [{next_event, internal, {protocol_record, Record}} | Actions]};
+ {#ssl_tls{epoch = CurrentEpoch} = Record, State} ->
+ {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]};
+ {#ssl_tls{epoch = Epoch,
+ type = ?HANDSHAKE,
+ version = _Version}, State1} = _Record when Epoch == CurrentEpoch-1 ->
+ {State2, MoreActions} = send_handshake_flight(State1, CurrentEpoch),
+ {NextRecord, State} = next_record(State2),
+ next_event(StateName, NextRecord, State, Actions ++ MoreActions);
+ %% From FLIGHT perspective CHANGE_CIPHER_SPEC is treated as a handshake
+ {#ssl_tls{epoch = Epoch,
+ type = ?CHANGE_CIPHER_SPEC,
+ version = _Version}, State1} = _Record when Epoch == CurrentEpoch-1 ->
+ {State2, MoreActions} = send_handshake_flight(State1, CurrentEpoch),
+ {NextRecord, State} = next_record(State2),
+ next_event(StateName, NextRecord, State, Actions ++ MoreActions);
+ {#ssl_tls{epoch = _Epoch,
+ version = _Version}, State1} ->
+ %% TODO maybe buffer later epoch
+ {Record, State} = next_record(State1),
+ next_event(StateName, Record, State, Actions);
+ {#alert{} = Alert, State} ->
+ {next_state, StateName, State, [{next_event, internal, Alert} | Actions]}
+ end;
+next_event(connection = StateName, Record,
+ #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->
+ case Record of
+ #ssl_tls{epoch = CurrentEpoch,
+ type = ?HANDSHAKE,
+ version = Version} = Record ->
+ State = dtls_version(StateName, Version, State0),
+ {next_state, StateName, State,
+ [{next_event, internal, {protocol_record, Record}} | Actions]};
+ #ssl_tls{epoch = CurrentEpoch} ->
+ {next_state, StateName, State0, [{next_event, internal, {protocol_record, Record}} | Actions]};
+ #ssl_tls{epoch = Epoch,
+ type = ?HANDSHAKE,
+ version = _Version} when Epoch == CurrentEpoch-1 ->
+ {State1, MoreActions} = send_handshake_flight(State0, CurrentEpoch),
+ {NextRecord, State} = next_record(State1),
+ next_event(StateName, NextRecord, State, Actions ++ MoreActions);
+ %% From FLIGHT perspective CHANGE_CIPHER_SPEC is treated as a handshake
+ #ssl_tls{epoch = Epoch,
+ type = ?CHANGE_CIPHER_SPEC,
+ version = _Version} when Epoch == CurrentEpoch-1 ->
+ {State1, MoreActions} = send_handshake_flight(State0, CurrentEpoch),
+ {NextRecord, State} = next_record(State1),
+ next_event(StateName, NextRecord, State, Actions ++ MoreActions);
+ _ ->
+ next_event(StateName, no_record, State0, Actions)
+ end;
+next_event(StateName, Record,
+ #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->
+ case Record of
+ no_record ->
+ {next_state, StateName, State0, Actions};
+ #ssl_tls{epoch = CurrentEpoch,
+ version = Version} = Record ->
+ State = dtls_version(StateName, Version, State0),
+ {next_state, StateName, State,
+ [{next_event, internal, {protocol_record, Record}} | Actions]};
+ #ssl_tls{epoch = _Epoch,
+ version = _Version} = _Record ->
+ %% TODO maybe buffer later epoch
+ {Record, State} = next_record(State0),
+ next_event(StateName, Record, State, Actions);
+ #alert{} = Alert ->
+ {next_state, StateName, State0, [{next_event, internal, Alert} | Actions]}
+ end.
+
+handle_common_event(internal, #alert{} = Alert, StateName,
+ #state{negotiated_version = Version} = State) ->
+ handle_own_alert(Alert, Version, StateName, State);
+%%% DTLS record protocol level handshake messages
+handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE,
+ fragment = Data},
+ StateName,
+ #state{protocol_buffers = Buffers0,
+ negotiated_version = Version} = State0) ->
+ try
+ case dtls_handshake:get_dtls_handshake(Version, Data, Buffers0) of
+ {[], Buffers} ->
+ {Record, State} = next_record(State0#state{protocol_buffers = Buffers}),
+ next_event(StateName, Record, State);
+ {Packets, Buffers} ->
+ State = State0#state{protocol_buffers = Buffers},
+ Events = dtls_handshake_events(Packets),
+ {next_state, StateName,
+ State#state{unprocessed_handshake_events = unprocessed_events(Events)}, Events}
+ end
+ catch throw:#alert{} = Alert ->
+ handle_own_alert(Alert, Version, StateName, State0)
+ end;
+%%% DTLS record protocol level application data messages
+handle_common_event(internal, #ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, StateName, State) ->
+ {next_state, StateName, State, [{next_event, internal, {application_data, Data}}]};
+%%% DTLS record protocol level change cipher messages
+handle_common_event(internal, #ssl_tls{type = ?CHANGE_CIPHER_SPEC, fragment = Data}, StateName, State) ->
+ {next_state, StateName, State, [{next_event, internal, #change_cipher_spec{type = Data}}]};
+%%% DTLS record protocol level Alert messages
+handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, StateName,
+ #state{negotiated_version = Version} = State) ->
+ case decode_alerts(EncAlerts) of
+ Alerts = [_|_] ->
+ handle_alerts(Alerts, {next_state, StateName, State});
+ #alert{} = Alert ->
+ handle_own_alert(Alert, Version, StateName, State)
+ end;
+%% Ignore unknown TLS record level protocol messages
+handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) ->
+ {next_state, StateName, State}.
+
+%%====================================================================
+%% Handshake handling
+%%====================================================================
+
+renegotiate(#state{role = client} = State, Actions) ->
+ %% Handle same way as if server requested
+ %% the renegotiation
+ {next_state, connection, State,
+ [{next_event, internal, #hello_request{}} | Actions]};
+
+renegotiate(#state{role = server} = State0, Actions) ->
+ HelloRequest = ssl_handshake:hello_request(),
+ State1 = prepare_flight(State0),
+ {State2, MoreActions} = send_handshake(HelloRequest, State1),
+ {Record, State} = next_record(State2),
+ next_event(hello, Record, State, Actions ++ MoreActions).
+
+send_handshake(Handshake, #state{connection_states = ConnectionStates} = State) ->
#{epoch := Epoch} = ssl_record:current_connection_state(ConnectionStates, write),
- send_handshake_flight(queue_handshake(Handshake, States), Epoch).
+ send_handshake_flight(queue_handshake(Handshake, State), Epoch).
queue_handshake(Handshake0, #state{tls_handshake_history = Hist0,
negotiated_version = Version,
@@ -104,93 +323,21 @@ queue_handshake(Handshake0, #state{tls_handshake_history = Hist0,
next_sequence => Seq +1},
tls_handshake_history = Hist}.
-
-send_handshake_flight(#state{socket = Socket,
- transport_cb = Transport,
- flight_buffer = #{handshakes := Flight,
- change_cipher_spec := undefined},
- negotiated_version = Version,
- connection_states = ConnectionStates0} = State0, Epoch) ->
- %% TODO remove hardcoded Max size
- {Encoded, ConnectionStates} =
- encode_handshake_flight(lists:reverse(Flight), Version, 1400, Epoch, ConnectionStates0),
- send(Transport, Socket, Encoded),
- start_flight(State0#state{connection_states = ConnectionStates});
-
-send_handshake_flight(#state{socket = Socket,
- transport_cb = Transport,
- flight_buffer = #{handshakes := [_|_] = Flight0,
- change_cipher_spec := ChangeCipher,
- handshakes_after_change_cipher_spec := []},
- negotiated_version = Version,
- connection_states = ConnectionStates0} = State0, Epoch) ->
- {HsBefore, ConnectionStates1} =
- encode_handshake_flight(lists:reverse(Flight0), Version, 1400, Epoch, ConnectionStates0),
- {EncChangeCipher, ConnectionStates} = encode_change_cipher(ChangeCipher, Version, Epoch, ConnectionStates1),
-
- send(Transport, Socket, [HsBefore, EncChangeCipher]),
- start_flight(State0#state{connection_states = ConnectionStates});
-
-send_handshake_flight(#state{socket = Socket,
- transport_cb = Transport,
- flight_buffer = #{handshakes := [_|_] = Flight0,
- change_cipher_spec := ChangeCipher,
- handshakes_after_change_cipher_spec := Flight1},
- negotiated_version = Version,
- connection_states = ConnectionStates0} = State0, Epoch) ->
- {HsBefore, ConnectionStates1} =
- encode_handshake_flight(lists:reverse(Flight0), Version, 1400, Epoch-1, ConnectionStates0),
- {EncChangeCipher, ConnectionStates2} =
- encode_change_cipher(ChangeCipher, Version, Epoch-1, ConnectionStates1),
- {HsAfter, ConnectionStates} =
- encode_handshake_flight(lists:reverse(Flight1), Version, 1400, Epoch, ConnectionStates2),
- send(Transport, Socket, [HsBefore, EncChangeCipher, HsAfter]),
- start_flight(State0#state{connection_states = ConnectionStates});
-
-send_handshake_flight(#state{socket = Socket,
- transport_cb = Transport,
- flight_buffer = #{handshakes := [],
- change_cipher_spec := ChangeCipher,
- handshakes_after_change_cipher_spec := Flight1},
- negotiated_version = Version,
- connection_states = ConnectionStates0} = State0, Epoch) ->
- {EncChangeCipher, ConnectionStates1} =
- encode_change_cipher(ChangeCipher, Version, Epoch-1, ConnectionStates0),
- {HsAfter, ConnectionStates} =
- encode_handshake_flight(lists:reverse(Flight1), Version, 1400, Epoch, ConnectionStates1),
- send(Transport, Socket, [EncChangeCipher, HsAfter]),
- start_flight(State0#state{connection_states = ConnectionStates}).
-
queue_change_cipher(ChangeCipher, #state{flight_buffer = Flight,
connection_states = ConnectionStates0} = State) ->
ConnectionStates =
dtls_record:next_epoch(ConnectionStates0, write),
State#state{flight_buffer = Flight#{change_cipher_spec => ChangeCipher},
connection_states = ConnectionStates}.
-
-send_alert(Alert, #state{negotiated_version = Version,
- socket = Socket,
- transport_cb = Transport,
- connection_states = ConnectionStates0} = State0) ->
- {BinMsg, ConnectionStates} =
- encode_alert(Alert, Version, ConnectionStates0),
- send(Transport, Socket, BinMsg),
- State0#state{connection_states = ConnectionStates}.
-
-close(downgrade, _,_,_,_) ->
- ok;
-%% Other
-close(_, Socket, Transport, _,_) ->
- dtls_socket:close(Transport,Socket).
-
reinit_handshake_data(#state{protocol_buffers = Buffers} = State) ->
State#state{premaster_secret = undefined,
public_key_info = undefined,
tls_handshake_history = ssl_handshake:init_handshake_history(),
flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT},
- protocol_buffers =
+ flight_buffer = new_flight(),
+ protocol_buffers =
Buffers#protocol_buffers{
- dtls_handshake_next_seq = 0,
+ dtls_handshake_next_seq = 0,
dtls_handshake_next_fragments = [],
dtls_handshake_later_fragments = []
}}.
@@ -200,42 +347,82 @@ select_sni_extension(#client_hello{extensions = HelloExtensions}) ->
select_sni_extension(_) ->
undefined.
-socket(Pid, Transport, Socket, Connection, _) ->
- dtls_socket:socket(Pid, Transport, Socket, Connection).
+empty_connection_state(ConnectionEnd, BeastMitigation) ->
+ Empty = ssl_record:empty_connection_state(ConnectionEnd, BeastMitigation),
+ dtls_record:empty_connection_state(Empty).
%%====================================================================
-%% tls_connection_sup API
+%% Alert and close handling
+%%====================================================================
+encode_alert(#alert{} = Alert, Version, ConnectionStates) ->
+ dtls_record:encode_alert_record(Alert, Version, ConnectionStates).
+
+send_alert(Alert, #state{negotiated_version = Version,
+ socket = Socket,
+ transport_cb = Transport,
+ connection_states = ConnectionStates0} = State0) ->
+ {BinMsg, ConnectionStates} =
+ encode_alert(Alert, Version, ConnectionStates0),
+ send(Transport, Socket, BinMsg),
+ State0#state{connection_states = ConnectionStates}.
+
+close(downgrade, _,_,_,_) ->
+ ok;
+%% Other
+close(_, Socket, Transport, _,_) ->
+ dtls_socket:close(Transport,Socket).
+
+protocol_name() ->
+ "DTLS".
+
%%====================================================================
+%% Data handling
+%%====================================================================
-%%--------------------------------------------------------------------
--spec start_link(atom(), host(), inet:port_number(), port(), list(), pid(), tuple()) ->
- {ok, pid()} | ignore | {error, reason()}.
-%%
-%% Description: Creates a gen_fsm process which calls Module:init/1 to
-%% initialize. To ensure a synchronized start-up procedure, this function
-%% does not return until Module:init/1 has returned.
-%%--------------------------------------------------------------------
-start_link(Role, Host, Port, Socket, Options, User, CbInfo) ->
- {ok, proc_lib:spawn_link(?MODULE, init, [[Role, Host, Port, Socket, Options, User, CbInfo]])}.
+encode_data(Data, Version, ConnectionStates0)->
+ dtls_record:encode_data(Data, Version, ConnectionStates0).
-init([Role, Host, Port, Socket, Options, User, CbInfo]) ->
- process_flag(trap_exit, true),
- State0 = initial_state(Role, Host, Port, Socket, Options, User, CbInfo),
- try
- State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0),
- gen_statem:enter_loop(?MODULE, [], init, State)
- catch
- throw:Error ->
- gen_statem:enter_loop(?MODULE, [], error, {Error,State0})
+passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName) ->
+ case Buffer of
+ <<>> ->
+ {Record, State} = next_record(State0),
+ next_event(StateName, Record, State);
+ _ ->
+ {Record, State} = ssl_connection:read_application_data(<<>>, State0),
+ next_event(StateName, Record, State)
end.
+next_record_if_active(State =
+ #state{socket_options =
+ #socket_options{active = false}}) ->
+ {no_record ,State};
-callback_mode() ->
- state_functions.
+next_record_if_active(State) ->
+ next_record(State).
+
+send(Transport, {_, {{_,_}, _} = Socket}, Data) ->
+ send(Transport, Socket, Data);
+send(Transport, Socket, Data) ->
+ dtls_socket:send(Transport, Socket, Data).
+
+socket(Pid, Transport, Socket, Connection, _) ->
+ dtls_socket:socket(Pid, Transport, Socket, Connection).
+
+setopts(Transport, Socket, Other) ->
+ dtls_socket:setopts(Transport, Socket, Other).
+
+getopts(Transport, Socket, Tag) ->
+ dtls_socket:getopts(Transport, Socket, Tag).
%%--------------------------------------------------------------------
%% State functions
%%--------------------------------------------------------------------
-
+%%--------------------------------------------------------------------
+-spec init(gen_statem:event_type(),
+ {start, timeout()} | term(), #state{}) ->
+ gen_statem:state_function_result().
+%%--------------------------------------------------------------------
+init(enter, _, State) ->
+ {keep_state, State};
init({call, From}, {start, Timeout},
#state{host = Host, port = Port, role = client,
ssl_options = SslOpts,
@@ -250,7 +437,7 @@ init({call, From}, {start, Timeout},
Cache, CacheCb, Renegotiation, Cert),
Version = Hello#client_hello.client_version,
- HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions),
+ HelloVersion = dtls_record:hello_version(Version, SslOpts#ssl_options.versions),
State1 = prepare_flight(State0#state{negotiated_version = Version}),
{State2, Actions} = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),
State3 = State2#state{negotiated_version = Version, %% Requested version
@@ -263,19 +450,32 @@ init({call, From}, {start, Timeout},
{Record, State} = next_record(State3),
next_event(hello, Record, State, Actions);
init({call, _} = Type, Event, #state{role = server, transport_cb = gen_udp} = State) ->
- ssl_connection:init(Type, Event,
- State#state{flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}},
- ?MODULE);
+ Result = gen_handshake(?FUNCTION_NAME, Type, Event,
+ State#state{flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT},
+ protocol_specific = #{current_cookie_secret => dtls_v1:cookie_secret(),
+ previous_cookie_secret => <<>>,
+ ignored_alerts => 0,
+ max_ignored_alerts => 10}}),
+ erlang:send_after(dtls_v1:cookie_timeout(), self(), new_cookie_secret),
+ Result;
+
init({call, _} = Type, Event, #state{role = server} = State) ->
%% I.E. DTLS over sctp
- ssl_connection:init(Type, Event, State#state{flight_state = reliable}, ?MODULE);
+ gen_handshake(?FUNCTION_NAME, Type, Event, State#state{flight_state = reliable});
init(Type, Event, State) ->
- ssl_connection:init(Type, Event, State, ?MODULE).
-
+ gen_handshake(?FUNCTION_NAME, Type, Event, State).
+
+%%--------------------------------------------------------------------
+-spec error(gen_statem:event_type(),
+ {start, timeout()} | term(), #state{}) ->
+ gen_statem:state_function_result().
+%%--------------------------------------------------------------------
+error(enter, _, State) ->
+ {keep_state, State};
error({call, From}, {start, _Timeout}, {Error, State}) ->
{stop_and_reply, normal, {reply, From, {error, Error}}, State};
-error({call, From}, Msg, State) ->
- handle_call(Msg, From, error, State);
+error({call, _} = Call, Msg, State) ->
+ gen_handshake(?FUNCTION_NAME, Call, Msg, State);
error(_, _, _) ->
{keep_state_and_data, [postpone]}.
@@ -285,29 +485,46 @@ error(_, _, _) ->
#state{}) ->
gen_statem:state_function_result().
%%--------------------------------------------------------------------
+hello(enter, _, #state{role = server} = State) ->
+ {keep_state, State};
+hello(enter, _, #state{role = client} = State0) ->
+ {State, Actions} = handle_flight_timer(State0),
+ {keep_state, State, Actions};
hello(internal, #client_hello{cookie = <<>>,
client_version = Version} = Hello, #state{role = server,
transport_cb = Transport,
- socket = Socket} = State0) ->
- %% TODO: not hard code key
+ socket = Socket,
+ protocol_specific = #{current_cookie_secret := Secret}} = State0) ->
{ok, {IP, Port}} = dtls_socket:peername(Transport, Socket),
- Cookie = dtls_handshake:cookie(<<"secret">>, IP, Port, Hello),
- VerifyRequest = dtls_handshake:hello_verify_request(Cookie, Version),
+ Cookie = dtls_handshake:cookie(Secret, IP, Port, Hello),
+ %% FROM RFC 6347 regarding HelloVerifyRequest message:
+ %% The server_version field has the same syntax as in TLS. However, in
+ %% order to avoid the requirement to do version negotiation in the
+ %% initial handshake, DTLS 1.2 server implementations SHOULD use DTLS
+ %% version 1.0 regardless of the version of TLS that is expected to be
+ %% negotiated.
+ VerifyRequest = dtls_handshake:hello_verify_request(Cookie, ?HELLO_VERIFY_REQUEST_VERSION),
State1 = prepare_flight(State0#state{negotiated_version = Version}),
{State2, Actions} = send_handshake(VerifyRequest, State1),
{Record, State} = next_record(State2),
- next_event(hello, Record, State#state{tls_handshake_history = ssl_handshake:init_handshake_history()}, Actions);
+ next_event(?FUNCTION_NAME, Record, State#state{tls_handshake_history = ssl_handshake:init_handshake_history()}, Actions);
hello(internal, #client_hello{cookie = Cookie} = Hello, #state{role = server,
transport_cb = Transport,
- socket = Socket} = State0) ->
+ socket = Socket,
+ protocol_specific = #{current_cookie_secret := Secret,
+ previous_cookie_secret := PSecret}} = State0) ->
{ok, {IP, Port}} = dtls_socket:peername(Transport, Socket),
- %% TODO: not hard code key
- case dtls_handshake:cookie(<<"secret">>, IP, Port, Hello) of
+ case dtls_handshake:cookie(Secret, IP, Port, Hello) of
Cookie ->
handle_client_hello(Hello, State0);
_ ->
- %% Handle bad cookie as new cookie request RFC 6347 4.1.2
- hello(internal, Hello#client_hello{cookie = <<>>}, State0)
+ case dtls_handshake:cookie(PSecret, IP, Port, Hello) of
+ Cookie ->
+ handle_client_hello(Hello, State0);
+ _ ->
+ %% Handle bad cookie as new cookie request RFC 6347 4.1.2
+ hello(internal, Hello#client_hello{cookie = <<>>}, State0)
+ end
end;
hello(internal, #hello_verify_request{cookie = Cookie}, #state{role = client,
host = Host, port = Port,
@@ -319,19 +536,20 @@ hello(internal, #hello_verify_request{cookie = Cookie}, #state{role = client,
session_cache = Cache,
session_cache_cb = CacheCb
} = State0) ->
- State1 = prepare_flight(State0#state{tls_handshake_history = ssl_handshake:init_handshake_history()}),
+
Hello = dtls_handshake:client_hello(Host, Port, Cookie, ConnectionStates0,
SslOpts,
Cache, CacheCb, Renegotiation, OwnCert),
Version = Hello#client_hello.client_version,
- HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions),
- {State2, Actions} = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),
+ State1 = prepare_flight(State0#state{tls_handshake_history = ssl_handshake:init_handshake_history()}),
+
+ {State2, Actions} = send_handshake(Hello, State1),
State3 = State2#state{negotiated_version = Version, %% Requested version
session =
Session0#session{session_id =
Hello#client_hello.session_id}},
{Record, State} = next_record(State3),
- next_event(hello, Record, State, Actions);
+ next_event(?FUNCTION_NAME, Record, State, Actions);
hello(internal, #server_hello{} = Hello,
#state{connection_states = ConnectionStates0,
negotiated_version = ReqVersion,
@@ -340,81 +558,119 @@ hello(internal, #server_hello{} = Hello,
ssl_options = SslOptions} = State) ->
case dtls_handshake:hello(Hello, SslOptions, ConnectionStates0, Renegotiation) of
#alert{} = Alert ->
- ssl_connection:handle_own_alert(Alert, ReqVersion, hello, State);
+ handle_own_alert(Alert, ReqVersion, ?FUNCTION_NAME, State);
{Version, NewId, ConnectionStates, ProtoExt, Protocol} ->
ssl_connection:handle_session(Hello,
Version, NewId, ConnectionStates, ProtoExt, Protocol, State)
end;
hello(internal, {handshake, {#client_hello{cookie = <<>>} = Handshake, _}}, State) ->
%% Initial hello should not be in handshake history
- {next_state, hello, State, [{next_event, internal, Handshake}]};
+ {next_state, ?FUNCTION_NAME, State, [{next_event, internal, Handshake}]};
hello(internal, {handshake, {#hello_verify_request{} = Handshake, _}}, State) ->
%% hello_verify should not be in handshake history
- {next_state, hello, State, [{next_event, internal, Handshake}]};
+ {next_state, ?FUNCTION_NAME, State, [{next_event, internal, Handshake}]};
hello(info, Event, State) ->
- handle_info(Event, hello, State);
+ gen_info(Event, ?FUNCTION_NAME, State);
hello(state_timeout, Event, State) ->
- handle_state_timeout(Event, hello, State);
+ handle_state_timeout(Event, ?FUNCTION_NAME, State);
hello(Type, Event, State) ->
- ssl_connection:hello(Type, Event, State, ?MODULE).
+ gen_handshake(?FUNCTION_NAME, Type, Event, State).
+%%--------------------------------------------------------------------
+-spec abbreviated(gen_statem:event_type(), term(), #state{}) ->
+ gen_statem:state_function_result().
+%%--------------------------------------------------------------------
+abbreviated(enter, _, State0) ->
+ {State, Actions} = handle_flight_timer(State0),
+ {keep_state, State, Actions};
abbreviated(info, Event, State) ->
- handle_info(Event, abbreviated, State);
+ gen_info(Event, ?FUNCTION_NAME, State);
abbreviated(internal = Type,
#change_cipher_spec{type = <<1>>} = Event,
#state{connection_states = ConnectionStates0} = State) ->
ConnectionStates1 = dtls_record:save_current_connection_state(ConnectionStates0, read),
ConnectionStates = dtls_record:next_epoch(ConnectionStates1, read),
- ssl_connection:abbreviated(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE);
+ gen_handshake(?FUNCTION_NAME, Type, Event, State#state{connection_states = ConnectionStates});
abbreviated(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) ->
- ssl_connection:abbreviated(Type, Event,
- prepare_flight(State#state{connection_states = ConnectionStates,
- flight_state = connection}), ?MODULE);
+ gen_handshake(?FUNCTION_NAME, Type, Event,
+ prepare_flight(State#state{connection_states = ConnectionStates,
+ flight_state = connection}));
abbreviated(state_timeout, Event, State) ->
- handle_state_timeout(Event, abbreviated, State);
+ handle_state_timeout(Event, ?FUNCTION_NAME, State);
abbreviated(Type, Event, State) ->
- ssl_connection:abbreviated(Type, Event, State, ?MODULE).
-
+ gen_handshake(?FUNCTION_NAME, Type, Event, State).
+%%--------------------------------------------------------------------
+-spec certify(gen_statem:event_type(), term(), #state{}) ->
+ gen_statem:state_function_result().
+%%--------------------------------------------------------------------
+certify(enter, _, State0) ->
+ {State, Actions} = handle_flight_timer(State0),
+ {keep_state, State, Actions};
certify(info, Event, State) ->
- handle_info(Event, certify, State);
+ gen_info(Event, ?FUNCTION_NAME, State);
certify(internal = Type, #server_hello_done{} = Event, State) ->
ssl_connection:certify(Type, Event, prepare_flight(State), ?MODULE);
+certify(internal, #change_cipher_spec{type = <<1>>}, State0) ->
+ {State1, Actions0} = send_handshake_flight(State0, retransmit_epoch(?FUNCTION_NAME, State0)),
+ {Record, State2} = next_record(State1),
+ {next_state, ?FUNCTION_NAME, State, Actions} = next_event(?FUNCTION_NAME, Record, State2, Actions0),
+ %% This will reset the retransmission timer by repeating the enter state event
+ {repeat_state, State, Actions};
certify(state_timeout, Event, State) ->
- handle_state_timeout(Event, certify, State);
+ handle_state_timeout(Event, ?FUNCTION_NAME, State);
certify(Type, Event, State) ->
- ssl_connection:certify(Type, Event, State, ?MODULE).
+ gen_handshake(?FUNCTION_NAME, Type, Event, State).
+%%--------------------------------------------------------------------
+-spec cipher(gen_statem:event_type(), term(), #state{}) ->
+ gen_statem:state_function_result().
+%%--------------------------------------------------------------------
+cipher(enter, _, State0) ->
+ {State, Actions} = handle_flight_timer(State0),
+ {keep_state, State, Actions};
cipher(info, Event, State) ->
- handle_info(Event, cipher, State);
+ gen_info(Event, ?FUNCTION_NAME, State);
cipher(internal = Type, #change_cipher_spec{type = <<1>>} = Event,
#state{connection_states = ConnectionStates0} = State) ->
ConnectionStates1 = dtls_record:save_current_connection_state(ConnectionStates0, read),
ConnectionStates = dtls_record:next_epoch(ConnectionStates1, read),
- ssl_connection:cipher(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE);
+ ssl_connection:?FUNCTION_NAME(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE);
cipher(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) ->
- ssl_connection:cipher(Type, Event,
- prepare_flight(State#state{connection_states = ConnectionStates,
- flight_state = connection}),
- ?MODULE);
+ ssl_connection:?FUNCTION_NAME(Type, Event,
+ prepare_flight(State#state{connection_states = ConnectionStates,
+ flight_state = connection}),
+ ?MODULE);
cipher(state_timeout, Event, State) ->
- handle_state_timeout(Event, cipher, State);
+ handle_state_timeout(Event, ?FUNCTION_NAME, State);
cipher(Type, Event, State) ->
- ssl_connection:cipher(Type, Event, State, ?MODULE).
+ ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE).
+%%--------------------------------------------------------------------
+-spec connection(gen_statem:event_type(),
+ #hello_request{} | #client_hello{}| term(), #state{}) ->
+ gen_statem:state_function_result().
+%%--------------------------------------------------------------------
+connection(enter, _, State) ->
+ {keep_state, State};
connection(info, Event, State) ->
- handle_info(Event, connection, State);
+ gen_info(Event, ?FUNCTION_NAME, State);
connection(internal, #hello_request{}, #state{host = Host, port = Port,
- session = #session{own_certificate = Cert} = Session0,
- session_cache = Cache, session_cache_cb = CacheCb,
- ssl_options = SslOpts,
- connection_states = ConnectionStates0,
- renegotiation = {Renegotiation, _}} = State0) ->
+ session = #session{own_certificate = Cert} = Session0,
+ session_cache = Cache, session_cache_cb = CacheCb,
+ ssl_options = SslOpts,
+ connection_states = ConnectionStates0,
+ renegotiation = {Renegotiation, _}} = State0) ->
+
Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts,
Cache, CacheCb, Renegotiation, Cert),
- {State1, Actions} = send_handshake(Hello, State0),
+ Version = Hello#client_hello.client_version,
+ HelloVersion = dtls_record:hello_version(Version, SslOpts#ssl_options.versions),
+ State1 = prepare_flight(State0),
+ {State2, Actions} = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),
{Record, State} =
next_record(
- State1#state{session = Session0#session{session_id
+ State2#state{flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT},
+ session = Session0#session{session_id
= Hello#client_hello.session_id}}),
next_event(hello, Record, State, Actions);
connection(internal, #client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) ->
@@ -424,125 +680,35 @@ connection(internal, #client_hello{} = Hello, #state{role = server, allow_renego
%% initiated renegotiation we will disallow many client initiated
%% renegotiations immediately after each other.
erlang:send_after(?WAIT_TO_ALLOW_RENEGOTIATION, self(), allow_renegotiate),
- {next_state, hello, State#state{allow_renegotiate = false}, [{next_event, internal, Hello}]};
+ {next_state, hello, State#state{allow_renegotiate = false, renegotiation = {true, peer}},
+ [{next_event, internal, Hello}]};
connection(internal, #client_hello{}, #state{role = server, allow_renegotiate = false} = State0) ->
Alert = ?ALERT_REC(?WARNING, ?NO_RENEGOTIATION),
State1 = send_alert(Alert, State0),
{Record, State} = ssl_connection:prepare_connection(State1, ?MODULE),
- next_event(connection, Record, State);
+ next_event(?FUNCTION_NAME, Record, State);
connection(Type, Event, State) ->
- ssl_connection:connection(Type, Event, State, ?MODULE).
-
-downgrade(Type, Event, State) ->
- ssl_connection:downgrade(Type, Event, State, ?MODULE).
+ ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE).
+%%TODO does this make sense for DTLS ?
%%--------------------------------------------------------------------
-%% Description: This function is called by a gen_fsm when it receives any
-%% other message than a synchronous or asynchronous event
-%% (or a system message).
+-spec downgrade(gen_statem:event_type(), term(), #state{}) ->
+ gen_statem:state_function_result().
%%--------------------------------------------------------------------
+downgrade(enter, _, State) ->
+ {keep_state, State};
+downgrade(Type, Event, State) ->
+ ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE).
-%% raw data from socket, unpack records
-handle_info({Protocol, _, _, _, Data}, StateName,
- #state{data_tag = Protocol} = State0) ->
- case next_dtls_record(Data, State0) of
- {Record, State} ->
- next_event(StateName, Record, State);
- #alert{} = Alert ->
- ssl_connection:handle_normal_shutdown(Alert, StateName, State0),
- {stop, {shutdown, own_alert}}
- end;
-handle_info({CloseTag, Socket}, StateName,
- #state{socket = Socket, close_tag = CloseTag,
- negotiated_version = Version} = State) ->
- %% Note that as of DTLS 1.2 (TLS 1.1),
- %% failure to properly close a connection no longer requires that a
- %% session not be resumed. This is a change from DTLS 1.0 to conform
- %% with widespread implementation practice.
- case Version of
- {254, N} when N =< 253 ->
- ok;
- _ ->
- %% As invalidate_sessions here causes performance issues,
- %% we will conform to the widespread implementation
- %% practice and go aginst the spec
- %%invalidate_session(Role, Host, Port, Session)
- ok
- end,
- ssl_connection:handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State),
- {stop, {shutdown, transport_closed}};
-handle_info(Msg, StateName, State) ->
- ssl_connection:handle_info(Msg, StateName, State).
-
-handle_call(Event, From, StateName, State) ->
- ssl_connection:handle_call(Event, From, StateName, State, ?MODULE).
-
-handle_common_event(internal, #alert{} = Alert, StateName,
- #state{negotiated_version = Version} = State) ->
- ssl_connection:handle_own_alert(Alert, Version, StateName, State);
-%%% DTLS record protocol level handshake messages
-handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE,
- fragment = Data},
- StateName,
- #state{protocol_buffers = Buffers0,
- negotiated_version = Version} = State0) ->
- try
- case dtls_handshake:get_dtls_handshake(Version, Data, Buffers0) of
- {[], Buffers} ->
- {Record, State} = next_record(State0#state{protocol_buffers = Buffers}),
- next_event(StateName, Record, State);
- {Packets, Buffers} ->
- State = State0#state{protocol_buffers = Buffers},
- Events = dtls_handshake_events(Packets),
- {next_state, StateName,
- State#state{unprocessed_handshake_events = unprocessed_events(Events)}, Events}
- end
- catch throw:#alert{} = Alert ->
- ssl_connection:handle_own_alert(Alert, Version, StateName, State0)
- end;
-%%% DTLS record protocol level application data messages
-handle_common_event(internal, #ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, StateName, State) ->
- {next_state, StateName, State, [{next_event, internal, {application_data, Data}}]};
-%%% DTLS record protocol level change cipher messages
-handle_common_event(internal, #ssl_tls{type = ?CHANGE_CIPHER_SPEC, fragment = Data}, StateName, State) ->
- {next_state, StateName, State, [{next_event, internal, #change_cipher_spec{type = Data}}]};
-%%% DTLS record protocol level Alert messages
-handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, StateName,
- #state{negotiated_version = Version} = State) ->
- case decode_alerts(EncAlerts) of
- Alerts = [_|_] ->
- handle_alerts(Alerts, {next_state, StateName, State});
- #alert{} = Alert ->
- ssl_connection:handle_own_alert(Alert, Version, StateName, State)
- end;
-%% Ignore unknown TLS record level protocol messages
-handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) ->
- {next_state, StateName, State}.
-
-handle_state_timeout(flight_retransmission_timeout, StateName,
- #state{flight_state = {retransmit, NextTimeout}} = State0) ->
- {State1, Actions} = send_handshake_flight(State0#state{flight_state = {retransmit, NextTimeout}},
- retransmit_epoch(StateName, State0)),
- {Record, State} = next_record(State1),
- next_event(StateName, Record, State, Actions).
-
-send(Transport, {_, {{_,_}, _} = Socket}, Data) ->
- send(Transport, Socket, Data);
-send(Transport, Socket, Data) ->
- dtls_socket:send(Transport, Socket, Data).
%%--------------------------------------------------------------------
-%% Description:This function is called by a gen_fsm when it is about
-%% to terminate. It should be the opposite of Module:init/1 and do any
-%% necessary cleaning up. When it returns, the gen_fsm terminates with
-%% Reason. The return value is ignored.
+%% gen_statem callbacks
%%--------------------------------------------------------------------
+callback_mode() ->
+ [state_functions, state_enter].
+
terminate(Reason, StateName, State) ->
ssl_connection:terminate(Reason, StateName, State).
-%%--------------------------------------------------------------------
-%% code_change(OldVsn, StateName, State, Extra) -> {ok, StateName, NewState}
-%% Description: Convert process state when code is changed
-%%--------------------------------------------------------------------
code_change(_OldVsn, StateName, State, _Extra) ->
{ok, StateName, State}.
@@ -552,55 +718,6 @@ format_status(Type, Data) ->
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
-handle_client_hello(#client_hello{client_version = ClientVersion} = Hello,
- #state{connection_states = ConnectionStates0,
- port = Port, session = #session{own_certificate = Cert} = Session0,
- renegotiation = {Renegotiation, _},
- session_cache = Cache,
- session_cache_cb = CacheCb,
- negotiated_protocol = CurrentProtocol,
- key_algorithm = KeyExAlg,
- ssl_options = SslOpts} = State0) ->
-
- case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb,
- ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of
- #alert{} = Alert ->
- ssl_connection:handle_own_alert(Alert, ClientVersion, hello, State0);
- {Version, {Type, Session},
- ConnectionStates, Protocol0, ServerHelloExt, HashSign} ->
- Protocol = case Protocol0 of
- undefined -> CurrentProtocol;
- _ -> Protocol0
- end,
-
- State = prepare_flight(State0#state{connection_states = ConnectionStates,
- negotiated_version = Version,
- hashsign_algorithm = HashSign,
- session = Session,
- negotiated_protocol = Protocol}),
-
- ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt},
- State, ?MODULE)
- end.
-
-encode_handshake_flight(Flight, Version, MaxFragmentSize, Epoch, ConnectionStates) ->
- Fragments = lists:map(fun(Handshake) ->
- dtls_handshake:fragment_handshake(Handshake, MaxFragmentSize)
- end, Flight),
- dtls_record:encode_handshake(Fragments, Version, Epoch, ConnectionStates).
-
-encode_change_cipher(#change_cipher_spec{}, Version, Epoch, ConnectionStates) ->
- dtls_record:encode_change_cipher_spec(Version, Epoch, ConnectionStates).
-
-encode_data(Data, Version, ConnectionStates0)->
- dtls_record:encode_data(Data, Version, ConnectionStates0).
-
-encode_alert(#alert{} = Alert, Version, ConnectionStates) ->
- dtls_record:encode_alert_record(Alert, Version, ConnectionStates).
-
-decode_alerts(Bin) ->
- ssl_alert:decode(Bin).
-
initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, _}, User,
{CbModule, DataTag, CloseTag, ErrorTag}) ->
#ssl_options{beast_mitigation = BeastMitigation} = SSLOptions,
@@ -641,10 +758,12 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, _}, User,
flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}
}.
-next_dtls_record(Data, #state{protocol_buffers = #protocol_buffers{
+next_dtls_record(Data, StateName, #state{protocol_buffers = #protocol_buffers{
dtls_record_buffer = Buf0,
dtls_cipher_texts = CT0} = Buffers} = State0) ->
- case dtls_record:get_dtls_records(Data, Buf0) of
+ case dtls_record:get_dtls_records(Data,
+ acceptable_record_versions(StateName, State0),
+ Buf0) of
{Records, Buf1} ->
CT1 = CT0 ++ Records,
next_record(State0#state{protocol_buffers =
@@ -654,13 +773,18 @@ next_dtls_record(Data, #state{protocol_buffers = #protocol_buffers{
Alert
end.
-next_record(#state{unprocessed_handshake_events = N} = State) when N > 0 ->
- {no_record, State#state{unprocessed_handshake_events = N-1}};
-
-next_record(#state{protocol_buffers =
- #protocol_buffers{dtls_cipher_texts = [CT | Rest]}
- = Buffers,
- connection_states = ConnStates0} = State) ->
+acceptable_record_versions(hello, _) ->
+ [dtls_record:protocol_version(Vsn) || Vsn <- ?ALL_DATAGRAM_SUPPORTED_VERSIONS];
+acceptable_record_versions(_, #state{negotiated_version = Version}) ->
+ [Version].
+
+dtls_handshake_events(Packets) ->
+ lists:map(fun(Packet) ->
+ {next_event, internal, {handshake, Packet}}
+ end, Packets).
+
+decode_cipher_text(#state{protocol_buffers = #protocol_buffers{dtls_cipher_texts = [ CT | Rest]} = Buffers,
+ connection_states = ConnStates0} = State) ->
case dtls_record:decode_cipher_text(CT, ConnStates0) of
{Plain, ConnStates} ->
{Plain, State#state{protocol_buffers =
@@ -668,83 +792,185 @@ next_record(#state{protocol_buffers =
connection_states = ConnStates}};
#alert{} = Alert ->
{Alert, State}
- end;
-next_record(#state{role = server,
- socket = {Listener, {Client, _}},
- transport_cb = gen_udp} = State) ->
- dtls_udp_listener:active_once(Listener, Client, self()),
- {no_record, State};
-next_record(#state{role = client,
- socket = {_Server, Socket},
- transport_cb = Transport} = State) ->
- dtls_socket:setopts(Transport, Socket, [{active,once}]),
- {no_record, State};
-next_record(State) ->
- {no_record, State}.
+ end.
-next_record_if_active(State =
- #state{socket_options =
- #socket_options{active = false}}) ->
- {no_record ,State};
+dtls_version(hello, Version, #state{role = server} = State) ->
+ State#state{negotiated_version = Version}; %%Inital version
+dtls_version(_,_, State) ->
+ State.
-next_record_if_active(State) ->
- next_record(State).
+handle_client_hello(#client_hello{client_version = ClientVersion} = Hello,
+ #state{connection_states = ConnectionStates0,
+ port = Port, session = #session{own_certificate = Cert} = Session0,
+ renegotiation = {Renegotiation, _},
+ session_cache = Cache,
+ session_cache_cb = CacheCb,
+ negotiated_protocol = CurrentProtocol,
+ key_algorithm = KeyExAlg,
+ ssl_options = SslOpts} = State0) ->
+
+ case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb,
+ ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of
+ #alert{} = Alert ->
+ handle_own_alert(Alert, ClientVersion, hello, State0);
+ {Version, {Type, Session},
+ ConnectionStates, Protocol0, ServerHelloExt, HashSign} ->
+ Protocol = case Protocol0 of
+ undefined -> CurrentProtocol;
+ _ -> Protocol0
+ end,
-passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName) ->
- case Buffer of
- <<>> ->
- {Record, State} = next_record(State0),
- next_event(StateName, Record, State);
- _ ->
- {Record, State} = ssl_connection:read_application_data(<<>>, State0),
- next_event(StateName, Record, State)
+ State = prepare_flight(State0#state{connection_states = ConnectionStates,
+ negotiated_version = Version,
+ hashsign_algorithm = HashSign,
+ client_hello_version = ClientVersion,
+ session = Session,
+ negotiated_protocol = Protocol}),
+
+ ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt},
+ State, ?MODULE)
end.
-next_event(StateName, Record, State) ->
- next_event(StateName, Record, State, []).
-next_event(connection = StateName, no_record,
- #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->
- case next_record_if_active(State0) of
- {no_record, State} ->
- ssl_connection:hibernate_after(StateName, State, Actions);
- {#ssl_tls{epoch = CurrentEpoch} = Record, State} ->
- {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]};
- {#ssl_tls{epoch = Epoch,
- type = ?HANDSHAKE,
- version = _Version}, State1} = _Record when Epoch == CurrentEpoch-1 ->
- {State, MoreActions} = send_handshake_flight(State1, Epoch),
- {next_state, StateName, State, Actions ++ MoreActions};
- {#ssl_tls{epoch = _Epoch,
- version = _Version}, State} ->
- %% TODO maybe buffer later epoch
- {next_state, StateName, State, Actions};
- {#alert{} = Alert, State} ->
- {next_state, StateName, State, [{next_event, internal, Alert} | Actions]}
- end;
-next_event(StateName, Record,
- #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State, Actions) ->
- case Record of
- no_record ->
- {next_state, StateName, State, Actions};
- #ssl_tls{epoch = CurrentEpoch,
- version = Version} = Record ->
- {next_state, StateName,
- dtls_version(StateName, Version, State),
- [{next_event, internal, {protocol_record, Record}} | Actions]};
- #ssl_tls{epoch = _Epoch,
- version = _Version} = _Record ->
- %% TODO maybe buffer later epoch
- {next_state, StateName, State, Actions};
+%% raw data from socket, unpack records
+handle_info({Protocol, _, _, _, Data}, StateName,
+ #state{data_tag = Protocol} = State0) ->
+ case next_dtls_record(Data, StateName, State0) of
+ {Record, State} ->
+ next_event(StateName, Record, State);
#alert{} = Alert ->
- {next_state, StateName, State, [{next_event, internal, Alert} | Actions]}
+ ssl_connection:handle_normal_shutdown(Alert, StateName, State0),
+ {stop, {shutdown, own_alert}}
+ end;
+handle_info({CloseTag, Socket}, StateName,
+ #state{socket = Socket,
+ socket_options = #socket_options{active = Active},
+ protocol_buffers = #protocol_buffers{dtls_cipher_texts = CTs},
+ close_tag = CloseTag,
+ negotiated_version = Version} = State) ->
+ %% Note that as of DTLS 1.2 (TLS 1.1),
+ %% failure to properly close a connection no longer requires that a
+ %% session not be resumed. This is a change from DTLS 1.0 to conform
+ %% with widespread implementation practice.
+ case (Active == false) andalso (CTs =/= []) of
+ false ->
+ case Version of
+ {254, N} when N =< 253 ->
+ ok;
+ _ ->
+ %% As invalidate_sessions here causes performance issues,
+ %% we will conform to the widespread implementation
+ %% practice and go aginst the spec
+ %%invalidate_session(Role, Host, Port, Session)
+ ok
+ end,
+ ssl_connection:handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State),
+ {stop, {shutdown, transport_closed}};
+ true ->
+ %% Fixes non-delivery of final DTLS record in {active, once}.
+ %% Basically allows the application the opportunity to set {active, once} again
+ %% and then receive the final message.
+ next_event(StateName, no_record, State)
+ end;
+
+handle_info(new_cookie_secret, StateName,
+ #state{protocol_specific = #{current_cookie_secret := Secret} = CookieInfo} = State) ->
+ erlang:send_after(dtls_v1:cookie_timeout(), self(), new_cookie_secret),
+ {next_state, StateName, State#state{protocol_specific =
+ CookieInfo#{current_cookie_secret => dtls_v1:cookie_secret(),
+ previous_cookie_secret => Secret}}};
+handle_info(Msg, StateName, State) ->
+ ssl_connection:StateName(info, Msg, State, ?MODULE).
+
+handle_state_timeout(flight_retransmission_timeout, StateName,
+ #state{flight_state = {retransmit, NextTimeout}} = State0) ->
+ {State1, Actions0} = send_handshake_flight(State0#state{flight_state = {retransmit, NextTimeout}},
+ retransmit_epoch(StateName, State0)),
+ {Record, State2} = next_record(State1),
+ {next_state, StateName, State, Actions} = next_event(StateName, Record, State2, Actions0),
+ %% This will reset the retransmission timer by repeating the enter state event
+ {repeat_state, State, Actions}.
+
+handle_alerts([], Result) ->
+ Result;
+handle_alerts(_, {stop,_} = Stop) ->
+ Stop;
+handle_alerts([Alert | Alerts], {next_state, StateName, State}) ->
+ handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State));
+handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) ->
+ handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)).
+
+handle_own_alert(Alert, Version, StateName, #state{transport_cb = gen_udp,
+ role = Role,
+ ssl_options = Options} = State0) ->
+ case ignore_alert(Alert, State0) of
+ {true, State} ->
+ log_ignore_alert(Options#ssl_options.log_alert, StateName, Alert, Role),
+ {next_state, StateName, State};
+ {false, State} ->
+ ssl_connection:handle_own_alert(Alert, Version, StateName, State)
+ end;
+handle_own_alert(Alert, Version, StateName, State) ->
+ ssl_connection:handle_own_alert(Alert, Version, StateName, State).
+
+encode_handshake_flight(Flight, Version, MaxFragmentSize, Epoch, ConnectionStates) ->
+ Fragments = lists:map(fun(Handshake) ->
+ dtls_handshake:fragment_handshake(Handshake, MaxFragmentSize)
+ end, Flight),
+ dtls_record:encode_handshake(Fragments, Version, Epoch, ConnectionStates).
+
+encode_change_cipher(#change_cipher_spec{}, Version, Epoch, ConnectionStates) ->
+ dtls_record:encode_change_cipher_spec(Version, Epoch, ConnectionStates).
+
+decode_alerts(Bin) ->
+ ssl_alert:decode(Bin).
+
+gen_handshake(StateName, Type, Event,
+ #state{negotiated_version = Version} = State) ->
+ try ssl_connection:StateName(Type, Event, State, ?MODULE) of
+ Result ->
+ Result
+ catch
+ _:_ ->
+ ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,
+ malformed_handshake_data),
+ Version, StateName, State)
end.
-dtls_version(hello, Version, #state{role = server} = State) ->
- State#state{negotiated_version = Version}; %%Inital version
-dtls_version(_,_, State) ->
- State.
+gen_info(Event, connection = StateName, #state{negotiated_version = Version} = State) ->
+ try handle_info(Event, StateName, State) of
+ Result ->
+ Result
+ catch
+ _:_ ->
+ ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?INTERNAL_ERROR,
+ malformed_data),
+ Version, StateName, State)
+ end;
+
+gen_info(Event, StateName, #state{negotiated_version = Version} = State) ->
+ try handle_info(Event, StateName, State) of
+ Result ->
+ Result
+ catch
+ _:_ ->
+ ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,
+ malformed_handshake_data),
+ Version, StateName, State)
+ end.
+unprocessed_events(Events) ->
+ %% The first handshake event will be processed immediately
+ %% as it is entered first in the event queue and
+ %% when it is processed there will be length(Events)-1
+ %% handshake events left to process before we should
+ %% process more TLS-records received on the socket.
+ erlang:length(Events)-1.
+update_handshake_history(#hello_verify_request{}, _, Hist) ->
+ Hist;
+update_handshake_history(_, Handshake, Hist) ->
+ %% DTLS never needs option "v2_hello_compatible" to be true
+ ssl_handshake:update_handshake_history(Hist, iolist_to_binary(Handshake), false).
prepare_flight(#state{flight_buffer = Flight,
connection_states = ConnectionStates0,
protocol_buffers =
@@ -765,14 +991,14 @@ next_flight(Flight) ->
Flight#{handshakes => [],
change_cipher_spec => undefined,
handshakes_after_change_cipher_spec => []}.
-
-start_flight(#state{transport_cb = gen_udp,
- flight_state = {retransmit, Timeout}} = State) ->
+
+handle_flight_timer(#state{transport_cb = gen_udp,
+ flight_state = {retransmit, Timeout}} = State) ->
start_retransmision_timer(Timeout, State);
-start_flight(#state{transport_cb = gen_udp,
+handle_flight_timer(#state{transport_cb = gen_udp,
flight_state = connection} = State) ->
{State, []};
-start_flight(State) ->
+handle_flight_timer(State) ->
%% No retransmision needed i.e DTLS over SCTP
{State#state{flight_state = reliable}, []}.
@@ -785,57 +1011,101 @@ new_timeout(N) when N =< 30 ->
new_timeout(_) ->
60.
-dtls_handshake_events(Packets) ->
- lists:map(fun(Packet) ->
- {next_event, internal, {handshake, Packet}}
- end, Packets).
+send_handshake_flight(#state{socket = Socket,
+ transport_cb = Transport,
+ flight_buffer = #{handshakes := Flight,
+ change_cipher_spec := undefined},
+ negotiated_version = Version,
+ connection_states = ConnectionStates0} = State0, Epoch) ->
+ %% TODO remove hardcoded Max size
+ {Encoded, ConnectionStates} =
+ encode_handshake_flight(lists:reverse(Flight), Version, 1400, Epoch, ConnectionStates0),
+ send(Transport, Socket, Encoded),
+ {State0#state{connection_states = ConnectionStates}, []};
-renegotiate(#state{role = client} = State, Actions) ->
- %% Handle same way as if server requested
- %% the renegotiation
- Hs0 = ssl_handshake:init_handshake_history(),
- {next_state, connection, State#state{tls_handshake_history = Hs0,
- protocol_buffers = #protocol_buffers{}},
- [{next_event, internal, #hello_request{}} | Actions]};
+send_handshake_flight(#state{socket = Socket,
+ transport_cb = Transport,
+ flight_buffer = #{handshakes := [_|_] = Flight0,
+ change_cipher_spec := ChangeCipher,
+ handshakes_after_change_cipher_spec := []},
+ negotiated_version = Version,
+ connection_states = ConnectionStates0} = State0, Epoch) ->
+ {HsBefore, ConnectionStates1} =
+ encode_handshake_flight(lists:reverse(Flight0), Version, 1400, Epoch, ConnectionStates0),
+ {EncChangeCipher, ConnectionStates} = encode_change_cipher(ChangeCipher, Version, Epoch, ConnectionStates1),
-renegotiate(#state{role = server,
- connection_states = CS0} = State0, Actions) ->
- HelloRequest = ssl_handshake:hello_request(),
- CS = CS0#{write_msg_seq => 0},
- {State1, MoreActions} = send_handshake(HelloRequest,
- State0#state{connection_states =
- CS}),
- Hs0 = ssl_handshake:init_handshake_history(),
- {Record, State} = next_record(State1#state{tls_handshake_history = Hs0,
- protocol_buffers = #protocol_buffers{}}),
- next_event(hello, Record, State, Actions ++ MoreActions).
+ send(Transport, Socket, [HsBefore, EncChangeCipher]),
+ {State0#state{connection_states = ConnectionStates}, []};
-handle_alerts([], Result) ->
- Result;
-handle_alerts(_, {stop,_} = Stop) ->
- Stop;
-handle_alerts([Alert | Alerts], {next_state, StateName, State}) ->
- handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State));
-handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) ->
- handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)).
+send_handshake_flight(#state{socket = Socket,
+ transport_cb = Transport,
+ flight_buffer = #{handshakes := [_|_] = Flight0,
+ change_cipher_spec := ChangeCipher,
+ handshakes_after_change_cipher_spec := Flight1},
+ negotiated_version = Version,
+ connection_states = ConnectionStates0} = State0, Epoch) ->
+ {HsBefore, ConnectionStates1} =
+ encode_handshake_flight(lists:reverse(Flight0), Version, 1400, Epoch-1, ConnectionStates0),
+ {EncChangeCipher, ConnectionStates2} =
+ encode_change_cipher(ChangeCipher, Version, Epoch-1, ConnectionStates1),
+ {HsAfter, ConnectionStates} =
+ encode_handshake_flight(lists:reverse(Flight1), Version, 1400, Epoch, ConnectionStates2),
+ send(Transport, Socket, [HsBefore, EncChangeCipher, HsAfter]),
+ {State0#state{connection_states = ConnectionStates}, []};
+
+send_handshake_flight(#state{socket = Socket,
+ transport_cb = Transport,
+ flight_buffer = #{handshakes := [],
+ change_cipher_spec := ChangeCipher,
+ handshakes_after_change_cipher_spec := Flight1},
+ negotiated_version = Version,
+ connection_states = ConnectionStates0} = State0, Epoch) ->
+ {EncChangeCipher, ConnectionStates1} =
+ encode_change_cipher(ChangeCipher, Version, Epoch-1, ConnectionStates0),
+ {HsAfter, ConnectionStates} =
+ encode_handshake_flight(lists:reverse(Flight1), Version, 1400, Epoch, ConnectionStates1),
+ send(Transport, Socket, [EncChangeCipher, HsAfter]),
+ {State0#state{connection_states = ConnectionStates}, []}.
retransmit_epoch(_StateName, #state{connection_states = ConnectionStates}) ->
#{epoch := Epoch} =
ssl_record:current_connection_state(ConnectionStates, write),
Epoch.
-
-update_handshake_history(#hello_verify_request{}, _, Hist) ->
- Hist;
-update_handshake_history(_, Handshake, Hist) ->
- %% DTLS never needs option "v2_hello_compatible" to be true
- ssl_handshake:update_handshake_history(Hist, iolist_to_binary(Handshake), false).
-
-unprocessed_events(Events) ->
- %% The first handshake event will be processed immediately
- %% as it is entered first in the event queue and
- %% when it is processed there will be length(Events)-1
- %% handshake events left to process before we should
- %% process more TLS-records received on the socket.
- erlang:length(Events)-1.
-
+ignore_alert(#alert{level = ?FATAL}, #state{protocol_specific = #{ignored_alerts := N,
+ max_ignored_alerts := N}} = State) ->
+ {false, State};
+ignore_alert(#alert{level = ?FATAL} = Alert,
+ #state{protocol_specific = #{ignored_alerts := N} = PS} = State) ->
+ case is_ignore_alert(Alert) of
+ true ->
+ {true, State#state{protocol_specific = PS#{ignored_alerts => N+1}}};
+ false ->
+ {false, State}
+ end;
+ignore_alert(_, State) ->
+ {false, State}.
+
+%% RFC 6347 4.1.2.7. Handling Invalid Records
+%% recommends to silently ignore invalid DTLS records when
+%% upd is the transport. Note we do not support compression so no need
+%% include ?DECOMPRESSION_FAILURE
+is_ignore_alert(#alert{description = ?BAD_RECORD_MAC}) ->
+ true;
+is_ignore_alert(#alert{description = ?RECORD_OVERFLOW}) ->
+ true;
+is_ignore_alert(#alert{description = ?DECODE_ERROR}) ->
+ true;
+is_ignore_alert(#alert{description = ?DECRYPT_ERROR}) ->
+ true;
+is_ignore_alert(#alert{description = ?ILLEGAL_PARAMETER}) ->
+ true;
+is_ignore_alert(_) ->
+ false.
+
+log_ignore_alert(true, StateName, Alert, Role) ->
+ Txt = ssl_alert:alert_txt(Alert),
+ error_logger:format("DTLS over UDP ~p: In state ~p ignored to send ALERT ~s as DoS-attack mitigation \n",
+ [Role, StateName, Txt]);
+log_ignore_alert(false, _, _,_) ->
+ ok.