diff options
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 78 |
1 files changed, 55 insertions, 23 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index 1c55e370cb..073cb4009b 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -444,21 +444,21 @@ init({call, From}, {start, Timeout}, {Record, State} = next_record(State3), next_event(hello, Record, State, Actions); init({call, _} = Type, Event, #state{role = server, transport_cb = gen_udp} = State) -> - Result = ssl_connection:?FUNCTION_NAME(Type, Event, - State#state{flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}, - protocol_specific = #{current_cookie_secret => dtls_v1:cookie_secret(), - previous_cookie_secret => <<>>, - ignored_alerts => 0, - max_ignored_alerts => 10}}, - ?MODULE), + Result = gen_handshake(?FUNCTION_NAME, Type, Event, + State#state{flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}, + protocol_specific = #{current_cookie_secret => dtls_v1:cookie_secret(), + previous_cookie_secret => <<>>, + ignored_alerts => 0, + max_ignored_alerts => 10}}), erlang:send_after(dtls_v1:cookie_timeout(), self(), new_cookie_secret), Result; init({call, _} = Type, Event, #state{role = server} = State) -> %% I.E. DTLS over sctp - ssl_connection:?FUNCTION_NAME(Type, Event, State#state{flight_state = reliable}, ?MODULE); + gen_handshake(?FUNCTION_NAME, Type, Event, State#state{flight_state = reliable}); init(Type, Event, State) -> - ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE). + gen_handshake(?FUNCTION_NAME, Type, Event, State). + %%-------------------------------------------------------------------- -spec error(gen_statem:event_type(), {start, timeout()} | term(), #state{}) -> @@ -469,7 +469,7 @@ error(enter, _, State) -> error({call, From}, {start, _Timeout}, {Error, State}) -> {stop_and_reply, normal, {reply, From, {error, Error}}, State}; error({call, _} = Call, Msg, State) -> - ssl_connection:?FUNCTION_NAME(Call, Msg, State, ?MODULE); + gen_handshake(?FUNCTION_NAME, Call, Msg, State); error(_, _, _) -> {keep_state_and_data, [postpone]}. @@ -564,11 +564,11 @@ hello(internal, {handshake, {#hello_verify_request{} = Handshake, _}}, State) -> %% hello_verify should not be in handshake history {next_state, ?FUNCTION_NAME, State, [{next_event, internal, Handshake}]}; hello(info, Event, State) -> - handle_info(Event, ?FUNCTION_NAME, State); + gen_info(Event, ?FUNCTION_NAME, State); hello(state_timeout, Event, State) -> handle_state_timeout(Event, ?FUNCTION_NAME, State); hello(Type, Event, State) -> - ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE). + gen_handshake(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec abbreviated(gen_statem:event_type(), term(), #state{}) -> @@ -578,22 +578,21 @@ abbreviated(enter, _, State0) -> {State, Actions} = handle_flight_timer(State0), {keep_state, State, Actions}; abbreviated(info, Event, State) -> - handle_info(Event, ?FUNCTION_NAME, State); + gen_info(Event, ?FUNCTION_NAME, State); abbreviated(internal = Type, #change_cipher_spec{type = <<1>>} = Event, #state{connection_states = ConnectionStates0} = State) -> ConnectionStates1 = dtls_record:save_current_connection_state(ConnectionStates0, read), ConnectionStates = dtls_record:next_epoch(ConnectionStates1, read), - ssl_connection:?FUNCTION_NAME(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE); + gen_handshake(?FUNCTION_NAME, Type, Event, State#state{connection_states = ConnectionStates}); abbreviated(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) -> - ssl_connection:?FUNCTION_NAME(Type, Event, - prepare_flight(State#state{connection_states = ConnectionStates, - flight_state = connection}), ?MODULE); + gen_handshake(?FUNCTION_NAME, Type, Event, + prepare_flight(State#state{connection_states = ConnectionStates, + flight_state = connection})); abbreviated(state_timeout, Event, State) -> handle_state_timeout(Event, ?FUNCTION_NAME, State); abbreviated(Type, Event, State) -> - ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE). - + gen_handshake(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec certify(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). @@ -602,13 +601,13 @@ certify(enter, _, State0) -> {State, Actions} = handle_flight_timer(State0), {keep_state, State, Actions}; certify(info, Event, State) -> - handle_info(Event, ?FUNCTION_NAME, State); + gen_info(Event, ?FUNCTION_NAME, State); certify(internal = Type, #server_hello_done{} = Event, State) -> ssl_connection:certify(Type, Event, prepare_flight(State), ?MODULE); certify(state_timeout, Event, State) -> handle_state_timeout(Event, ?FUNCTION_NAME, State); certify(Type, Event, State) -> - ssl_connection:?FUNCTION_NAME(Type, Event, State, ?MODULE). + gen_handshake(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec cipher(gen_statem:event_type(), term(), #state{}) -> @@ -618,7 +617,7 @@ cipher(enter, _, State0) -> {State, Actions} = handle_flight_timer(State0), {keep_state, State, Actions}; cipher(info, Event, State) -> - handle_info(Event, ?FUNCTION_NAME, State); + gen_info(Event, ?FUNCTION_NAME, State); cipher(internal = Type, #change_cipher_spec{type = <<1>>} = Event, #state{connection_states = ConnectionStates0} = State) -> ConnectionStates1 = dtls_record:save_current_connection_state(ConnectionStates0, read), @@ -642,7 +641,7 @@ cipher(Type, Event, State) -> connection(enter, _, State) -> {keep_state, State}; connection(info, Event, State) -> - handle_info(Event, ?FUNCTION_NAME, State); + gen_info(Event, ?FUNCTION_NAME, State); connection(internal, #hello_request{}, #state{host = Host, port = Port, session = #session{own_certificate = Cert} = Session0, session_cache = Cache, session_cache_cb = CacheCb, @@ -905,6 +904,39 @@ encode_change_cipher(#change_cipher_spec{}, Version, Epoch, ConnectionStates) -> decode_alerts(Bin) -> ssl_alert:decode(Bin). +gen_handshake(StateName, Type, Event, + #state{negotiated_version = Version} = State) -> + try ssl_connection:StateName(Type, Event, State, ?MODULE) of + Result -> + Result + catch + _:_ -> + ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, + malformed_handshake_data), + Version, StateName, State) + end. + +gen_info(Event, connection = StateName, #state{negotiated_version = Version} = State) -> + try handle_info(Event, StateName, State) of + Result -> + Result + catch + _:_ -> + ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?INTERNAL_ERROR, + malformed_data), + Version, StateName, State) + end; + +gen_info(Event, StateName, #state{negotiated_version = Version} = State) -> + try handle_info(Event, StateName, State) of + Result -> + Result + catch + _:_ -> + ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, + malformed_handshake_data), + Version, StateName, State) + end. unprocessed_events(Events) -> %% The first handshake event will be processed immediately %% as it is entered first in the event queue and |