diff options
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 105 |
1 files changed, 51 insertions, 54 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index 40289f5e2f..513b4b5be0 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -39,7 +39,7 @@ -export([start_fsm/8, start_link/7, init/1]). %% State transition handling --export([next_record/1, next_event/3]). +-export([next_record/1, next_event/3, next_event/4]). %% Handshake handling -export([renegotiate/2, @@ -201,6 +201,7 @@ reinit_handshake_data(#state{protocol_buffers = Buffers} = State) -> State#state{premaster_secret = undefined, public_key_info = undefined, tls_handshake_history = ssl_handshake:init_handshake_history(), + flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}, protocol_buffers = Buffers#protocol_buffers{ dtls_handshake_next_seq = 0, @@ -246,7 +247,7 @@ callback_mode() -> state_functions. %%-------------------------------------------------------------------- -%% State functionsconnection/2 +%% State functions %%-------------------------------------------------------------------- init({call, From}, {start, Timeout}, @@ -265,7 +266,7 @@ init({call, From}, {start, Timeout}, Version = Hello#client_hello.client_version, HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions), State1 = prepare_flight(State0#state{negotiated_version = Version}), - State2 = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}), + {State2, Actions} = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}), State3 = State2#state{negotiated_version = Version, %% Requested version session = Session0#session{session_id = Hello#client_hello.session_id}, @@ -274,7 +275,7 @@ init({call, From}, {start, Timeout}, flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT} }, {Record, State} = next_record(State3), - next_event(hello, Record, State); + next_event(hello, Record, State, Actions); init({call, _} = Type, Event, #state{role = server, transport_cb = gen_udp} = State) -> ssl_connection:init(Type, Event, State#state{flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}}, @@ -307,9 +308,9 @@ hello(internal, #client_hello{cookie = <<>>, Cookie = dtls_handshake:cookie(<<"secret">>, IP, Port, Hello), VerifyRequest = dtls_handshake:hello_verify_request(Cookie, Version), State1 = prepare_flight(State0#state{negotiated_version = Version}), - State2 = send_handshake(VerifyRequest, State1), + {State2, Actions} = send_handshake(VerifyRequest, State1), {Record, State} = next_record(State2), - next_event(hello, Record, State#state{tls_handshake_history = ssl_handshake:init_handshake_history()}); + next_event(hello, Record, State#state{tls_handshake_history = ssl_handshake:init_handshake_history()}, Actions); hello(internal, #client_hello{cookie = Cookie} = Hello, #state{role = server, transport_cb = Transport, socket = Socket} = State0) -> @@ -338,13 +339,13 @@ hello(internal, #hello_verify_request{cookie = Cookie}, #state{role = client, Cache, CacheCb, Renegotiation, OwnCert), Version = Hello#client_hello.client_version, HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions), - State2 = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}), + {State2, Actions} = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}), State3 = State2#state{negotiated_version = Version, %% Requested version session = Session0#session{session_id = Hello#client_hello.session_id}}, {Record, State} = next_record(State3), - next_event(hello, Record, State); + next_event(hello, Record, State, Actions); hello(internal, #server_hello{} = Hello, #state{connection_states = ConnectionStates0, negotiated_version = ReqVersion, @@ -361,13 +362,13 @@ hello(internal, #server_hello{} = Hello, hello(internal, {handshake, {#client_hello{cookie = <<>>} = Handshake, _}}, State) -> %% Initial hello should not be in handshake history {next_state, hello, State, [{next_event, internal, Handshake}]}; - hello(internal, {handshake, {#hello_verify_request{} = Handshake, _}}, State) -> %% hello_verify should not be in handshake history {next_state, hello, State, [{next_event, internal, Handshake}]}; - hello(info, Event, State) -> handle_info(Event, hello, State); +hello(state_timeout, Event, State) -> + handle_state_timeout(Event, hello, State); hello(Type, Event, State) -> ssl_connection:hello(Type, Event, State, ?MODULE). @@ -380,7 +381,11 @@ abbreviated(internal = Type, ConnectionStates = dtls_record:next_epoch(ConnectionStates1, read), ssl_connection:abbreviated(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE); abbreviated(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) -> - ssl_connection:cipher(Type, Event, prepare_flight(State#state{connection_states = ConnectionStates}), ?MODULE); + ssl_connection:abbreviated(Type, Event, + prepare_flight(State#state{connection_states = ConnectionStates, + flight_state = connection}), ?MODULE); +abbreviated(state_timeout, Event, State) -> + handle_state_timeout(Event, abbreviated, State); abbreviated(Type, Event, State) -> ssl_connection:abbreviated(Type, Event, State, ?MODULE). @@ -388,6 +393,8 @@ certify(info, Event, State) -> handle_info(Event, certify, State); certify(internal = Type, #server_hello_done{} = Event, State) -> ssl_connection:certify(Type, Event, prepare_flight(State), ?MODULE); +certify(state_timeout, Event, State) -> + handle_state_timeout(Event, certify, State); certify(Type, Event, State) -> ssl_connection:certify(Type, Event, State, ?MODULE). @@ -400,7 +407,11 @@ cipher(internal = Type, #change_cipher_spec{type = <<1>>} = Event, ssl_connection:cipher(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE); cipher(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) -> ssl_connection:cipher(Type, Event, - prepare_flight(State#state{connection_states = ConnectionStates}), ?MODULE); + prepare_flight(State#state{connection_states = ConnectionStates, + flight_state = connection}), + ?MODULE); +cipher(state_timeout, Event, State) -> + handle_state_timeout(Event, cipher, State); cipher(Type, Event, State) -> ssl_connection:cipher(Type, Event, State, ?MODULE). @@ -414,12 +425,12 @@ connection(internal, #hello_request{}, #state{host = Host, port = Port, renegotiation = {Renegotiation, _}} = State0) -> Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, Cache, CacheCb, Renegotiation, Cert), - State1 = send_handshake(Hello, State0), + {State1, Actions} = send_handshake(Hello, State0), {Record, State} = next_record( State1#state{session = Session0#session{session_id = Hello#client_hello.session_id}}), - next_event(hello, Record, State); + next_event(hello, Record, State, Actions); connection(internal, #client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) -> %% Mitigate Computational DoS attack %% http://www.educatedguesswork.org/2011/10/ssltls_and_computational_dos.html @@ -439,7 +450,6 @@ connection(Type, Event, State) -> downgrade(Type, Event, State) -> ssl_connection:downgrade(Type, Event, State, ?MODULE). - %%-------------------------------------------------------------------- %% Description: This function is called by a gen_fsm when it receives any %% other message than a synchronous or asynchronous event @@ -447,16 +457,6 @@ downgrade(Type, Event, State) -> %%-------------------------------------------------------------------- %% raw data from socket, unpack records -handle_info({_,flight_retransmission_timeout}, connection, _) -> - {keep_state_and_data, []}; -handle_info({Ref, flight_retransmission_timeout}, StateName, - #state{flight_state = {waiting, Ref, NextTimeout}} = State0) -> - State1 = send_handshake_flight(State0#state{flight_state = {retransmit, NextTimeout}}, - retransmit_epoch(StateName, State0)), - {Record, State} = next_record(State1), - next_event(StateName, Record, State); -handle_info({_, flight_retransmission_timeout}, _, _) -> - {keep_state_and_data, []}; handle_info({Protocol, _, _, _, Data}, StateName, #state{data_tag = Protocol} = State0) -> case next_dtls_record(Data, State0) of @@ -494,7 +494,6 @@ handle_call(Event, From, StateName, State) -> handle_common_event(internal, #alert{} = Alert, StateName, #state{negotiated_version = Version} = State) -> ssl_connection:handle_own_alert(Alert, Version, StateName, State); - %%% DTLS record protocol level handshake messages handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE, fragment = Data}, @@ -509,13 +508,8 @@ handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE, {Packets, Buffers} -> State = State0#state{protocol_buffers = Buffers}, Events = dtls_handshake_events(Packets), - case StateName of - connection -> - ssl_connection:hibernate_after(StateName, State, Events); - _ -> - {next_state, StateName, - State#state{unprocessed_handshake_events = unprocessed_events(Events)}, Events} - end + {next_state, StateName, + State#state{unprocessed_handshake_events = unprocessed_events(Events)}, Events} end catch throw:#alert{} = Alert -> ssl_connection:handle_own_alert(Alert, Version, StateName, State0) @@ -539,6 +533,13 @@ handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, Sta handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) -> {next_state, StateName, State}. +handle_state_timeout(flight_retransmission_timeout, StateName, + #state{flight_state = {retransmit, NextTimeout}} = State0) -> + {State1, Actions} = send_handshake_flight(State0#state{flight_state = {retransmit, NextTimeout}}, + retransmit_epoch(StateName, State0)), + {Record, State} = next_record(State1), + next_event(StateName, Record, State, Actions). + send(Transport, {_, {{_,_}, _} = Socket}, Data) -> send(Transport, Socket, Data); send(Transport, Socket, Data) -> @@ -650,7 +651,8 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, _}, User, allow_renegotiate = SSLOptions#ssl_options.client_renegotiation, start_or_recv_from = undefined, protocol_cb = ?MODULE, - flight_buffer = new_flight() + flight_buffer = new_flight(), + flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT} }. next_dtls_record(Data, #state{protocol_buffers = #protocol_buffers{ @@ -719,14 +721,14 @@ next_event(connection = StateName, no_record, #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) -> case next_record_if_active(State0) of {no_record, State} -> - ssl_connection:hibernate_after(StateName, State, Actions); + ssl_connection:hibernate_after(StateName, State, Actions); {#ssl_tls{epoch = CurrentEpoch} = Record, State} -> {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]}; {#ssl_tls{epoch = Epoch, type = ?HANDSHAKE, version = _Version}, State1} = _Record when Epoch == CurrentEpoch-1 -> - State = send_handshake_flight(State1, Epoch), - {next_state, StateName, State, Actions}; + {State, MoreActions} = send_handshake_flight(State1, Epoch), + {next_state, StateName, State, Actions ++ MoreActions}; {#ssl_tls{epoch = _Epoch, version = _Version}, State} -> %% TODO maybe buffer later epoch @@ -782,16 +784,15 @@ start_flight(#state{transport_cb = gen_udp, flight_state = {retransmit, Timeout}} = State) -> start_retransmision_timer(Timeout, State); start_flight(#state{transport_cb = gen_udp, - flight_state = {waiting, _, Timeout}} = State) -> - start_retransmision_timer(Timeout, State); + flight_state = connection} = State) -> + {State, []}; start_flight(State) -> %% No retransmision needed i.e DTLS over SCTP - State#state{flight_state = reliable}. + {State#state{flight_state = reliable}, []}. start_retransmision_timer(Timeout, State) -> - Ref = erlang:make_ref(), - _ = erlang:send_after(Timeout, self(), {Ref, flight_retransmission_timeout}), - State#state{flight_state = {waiting, Ref, new_timeout(Timeout)}}. + {State#state{flight_state = {retransmit, new_timeout(Timeout)}}, + [{state_timeout, Timeout, flight_retransmission_timeout}]}. new_timeout(N) when N =< 30 -> N * 2; @@ -815,13 +816,13 @@ renegotiate(#state{role = server, connection_states = CS0} = State0, Actions) -> HelloRequest = ssl_handshake:hello_request(), CS = CS0#{write_msg_seq => 0}, - State1 = send_handshake(HelloRequest, - State0#state{connection_states = - CS}), + {State1, MoreActions} = send_handshake(HelloRequest, + State0#state{connection_states = + CS}), Hs0 = ssl_handshake:init_handshake_history(), {Record, State} = next_record(State1#state{tls_handshake_history = Hs0, protocol_buffers = #protocol_buffers{}}), - next_event(hello, Record, State, Actions). + next_event(hello, Record, State, Actions ++ MoreActions). handle_alerts([], Result) -> Result; @@ -832,15 +833,11 @@ handle_alerts([Alert | Alerts], {next_state, StateName, State}) -> handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) -> handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)). -retransmit_epoch(StateName, #state{connection_states = ConnectionStates}) -> +retransmit_epoch(_StateName, #state{connection_states = ConnectionStates}) -> #{epoch := Epoch} = ssl_record:current_connection_state(ConnectionStates, write), - case StateName of - connection -> - Epoch-1; - _ -> - Epoch - end. + Epoch. + update_handshake_history(#hello_verify_request{}, _, Hist) -> Hist; |