aboutsummaryrefslogtreecommitdiffstats
path: root/lib/ssl/src/tls_connection.erl
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ssl/src/tls_connection.erl')
-rw-r--r--lib/ssl/src/tls_connection.erl74
1 files changed, 54 insertions, 20 deletions
diff --git a/lib/ssl/src/tls_connection.erl b/lib/ssl/src/tls_connection.erl
index 9880befa94..8b828f3421 100644
--- a/lib/ssl/src/tls_connection.erl
+++ b/lib/ssl/src/tls_connection.erl
@@ -68,10 +68,8 @@
hello/3, certify/3, cipher/3, abbreviated/3, %% Handshake states
connection/3]).
%% gen_statem callbacks
--export([terminate/3, code_change/4, format_status/2]).
+-export([callback_mode/0, terminate/3, code_change/4, format_status/2]).
--define(GEN_STATEM_CB_MODE, state_functions).
-
%%====================================================================
%% Internal application API
%%====================================================================
@@ -169,11 +167,14 @@ init([Role, Host, Port, Socket, Options, User, CbInfo]) ->
State0 = initial_state(Role, Host, Port, Socket, Options, User, CbInfo),
try
State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0),
- gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, init, State)
+ gen_statem:enter_loop(?MODULE, [], init, State)
catch throw:Error ->
- gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, error, {Error, State0})
+ gen_statem:enter_loop(?MODULE, [], error, {Error, State0})
end.
+callback_mode() ->
+ state_functions.
+
%%--------------------------------------------------------------------
%% State functions
%%--------------------------------------------------------------------
@@ -213,7 +214,7 @@ init({call, From}, {start, Timeout},
{Record, State} = next_record(State1),
next_event(hello, Record, State);
init(Type, Event, State) ->
- ssl_connection:init(Type, Event, State, ?MODULE).
+ gen_handshake(ssl_connection, init, Type, Event, State).
%%--------------------------------------------------------------------
-spec error(gen_statem:event_type(),
@@ -257,13 +258,13 @@ hello(internal, #client_hello{client_version = ClientVersion,
_ -> Protocol0
end,
- ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt},
+ gen_handshake(ssl_connection, hello, internal, {common_client_hello, Type, ServerHelloExt},
State#state{connection_states = ConnectionStates,
negotiated_version = Version,
hashsign_algorithm = HashSign,
session = Session,
client_ecc = {EllipticCurves, EcPointFormats},
- negotiated_protocol = Protocol}, ?MODULE)
+ negotiated_protocol = Protocol})
end;
hello(internal, #server_hello{} = Hello,
#state{connection_states = ConnectionStates0,
@@ -279,36 +280,36 @@ hello(internal, #server_hello{} = Hello,
Version, NewId, ConnectionStates, ProtoExt, Protocol, State)
end;
hello(info, Event, State) ->
- handle_info(Event, hello, State);
+ gen_info(Event, hello, State);
hello(Type, Event, State) ->
- ssl_connection:hello(Type, Event, State, ?MODULE).
+ gen_handshake(ssl_connection, hello, Type, Event, State).
%%--------------------------------------------------------------------
-spec abbreviated(gen_statem:event_type(), term(), #state{}) ->
gen_statem:state_function_result().
%%--------------------------------------------------------------------
abbreviated(info, Event, State) ->
- handle_info(Event, abbreviated, State);
+ gen_info(Event, abbreviated, State);
abbreviated(Type, Event, State) ->
- ssl_connection:abbreviated(Type, Event, State, ?MODULE).
+ gen_handshake(ssl_connection, abbreviated, Type, Event, State).
%%--------------------------------------------------------------------
-spec certify(gen_statem:event_type(), term(), #state{}) ->
gen_statem:state_function_result().
%%--------------------------------------------------------------------
certify(info, Event, State) ->
- handle_info(Event, certify, State);
+ gen_info(Event, certify, State);
certify(Type, Event, State) ->
- ssl_connection:certify(Type, Event, State, ?MODULE).
+ gen_handshake(ssl_connection, certify, Type, Event, State).
%%--------------------------------------------------------------------
-spec cipher(gen_statem:event_type(), term(), #state{}) ->
gen_statem:state_function_result().
%%--------------------------------------------------------------------
cipher(info, Event, State) ->
- handle_info(Event, cipher, State);
+ gen_info(Event, cipher, State);
cipher(Type, Event, State) ->
- ssl_connection:cipher(Type, Event, State, ?MODULE).
+ gen_handshake(ssl_connection, cipher, Type, Event, State).
%%--------------------------------------------------------------------
-spec connection(gen_statem:event_type(),
@@ -316,7 +317,7 @@ cipher(Type, Event, State) ->
gen_statem:state_function_result().
%%--------------------------------------------------------------------
connection(info, Event, State) ->
- handle_info(Event, connection, State);
+ gen_info(Event, connection, State);
connection(internal, #hello_request{},
#state{role = client, host = Host, port = Port,
session = #session{own_certificate = Cert} = Session0,
@@ -432,11 +433,16 @@ handle_common_event(internal, #ssl_tls{type = ?CHANGE_CIPHER_SPEC, fragment = Da
%%% TLS record protocol level Alert messages
handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, StateName,
#state{negotiated_version = Version} = State) ->
- case decode_alerts(EncAlerts) of
+ try decode_alerts(EncAlerts) of
Alerts = [_|_] ->
handle_alerts(Alerts, {next_state, StateName, State});
+ [] ->
+ handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, empty_alert), Version, StateName, State);
#alert{} = Alert ->
handle_own_alert(Alert, Version, StateName, State)
+ catch
+ _:_ ->
+ handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, alert_decode_error), Version, StateName, State)
end;
%% Ignore unknown TLS record level protocol messages
handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) ->
@@ -457,9 +463,9 @@ format_status(Type, Data) ->
%%--------------------------------------------------------------------
code_change(_OldVsn, StateName, State0, {Direction, From, To}) ->
State = convert_state(State0, Direction, From, To),
- {?GEN_STATEM_CB_MODE, StateName, State};
+ {ok, StateName, State};
code_change(_OldVsn, StateName, State, _) ->
- {?GEN_STATEM_CB_MODE, StateName, State}.
+ {ok, StateName, State}.
%%--------------------------------------------------------------------
%%% Internal functions
@@ -1039,3 +1045,31 @@ handle_sni_extension(#client_hello{extensions = HelloExtensions}, State0) ->
end;
handle_sni_extension(_, State) ->
State.
+
+gen_handshake(GenConnection, StateName, Type, Event, #state{negotiated_version = Version} = State) ->
+ try GenConnection:StateName(Type, Event, State, ?MODULE) of
+ Result ->
+ Result
+ catch
+ _:_ ->
+ handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data), Version, StateName, State)
+ end.
+
+gen_info(Event, connection = StateName, #state{negotiated_version = Version} = State) ->
+ try handle_info(Event, StateName, State) of
+ Result ->
+ Result
+ catch
+ _:_ ->
+ handle_own_alert(?ALERT_REC(?FATAL, ?INTERNAL_ERROR, malformed_data), Version, StateName, State)
+ end;
+
+gen_info(Event, StateName, #state{negotiated_version = Version} = State) ->
+ try handle_info(Event, StateName, State) of
+ Result ->
+ Result
+ catch
+ _:_ ->
+ handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake_data), Version, StateName, State)
+ end.
+