From f995d04a0575cdd110a96741bc733eb95d063113 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?P=C3=A9ter=20Dimitrov?= Date: Fri, 16 Nov 2018 11:19:42 +0100 Subject: ssl: Improve TLS 1.3 state machine - Use internal event to transition to the first state of the TLS 1.3 state machine. - Add gen_handshake_1_3/4 and gen_info_1_3/4. Change-Id: I17f12110356c7be4a8dddf9a616df7f181b0ef37 --- lib/ssl/src/tls_connection.erl | 179 ++++++++++++++++++++++++++--------------- 1 file changed, 115 insertions(+), 64 deletions(-) (limited to 'lib/ssl/src/tls_connection.erl') diff --git a/lib/ssl/src/tls_connection.erl b/lib/ssl/src/tls_connection.erl index 9f98572691..be262c48e8 100644 --- a/lib/ssl/src/tls_connection.erl +++ b/lib/ssl/src/tls_connection.erl @@ -225,7 +225,8 @@ handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE, fragment = Data}, negotiated_version = Version, ssl_options = Options} = State0) -> try - {Packets, Buf} = tls_handshake:get_tls_handshake(Version,Data,Buf0, Options), + EffectiveVersion = effective_version(Version, Options), + {Packets, Buf} = tls_handshake:get_tls_handshake(EffectiveVersion,Data,Buf0, Options), State1 = State0#state{protocol_buffers = Buffers#protocol_buffers{tls_handshake_buffer = Buf}}, @@ -498,7 +499,7 @@ init({call, From}, {start, Timeout}, session_cache = Cache, session_cache_cb = CacheCb } = State0) -> - KeyShare = maybe_generate_key_share(SslOpts), + KeyShare = maybe_generate_client_shares(SslOpts), Timer = ssl_connection:start_or_recv_cancel_timer(Timeout, From), Hello = tls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, Cache, CacheCb, Renegotiation, Cert, KeyShare), @@ -570,41 +571,36 @@ hello(internal, #client_hello{client_version = ClientVersion} = Hello, negotiated_protocol = CurrentProtocol, key_algorithm = KeyExAlg, ssl_options = SslOpts} = State) -> - case tls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb, - ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of - #alert{} = Alert -> - ssl_connection:handle_own_alert(Alert, ClientVersion, hello, - State#state{negotiated_version - = ClientVersion}); - {Version, {Type, Session}, - ConnectionStates, Protocol0, ServerHelloExt, HashSign} when Version < {3,4} -> - Protocol = case Protocol0 of - undefined -> CurrentProtocol; - _ -> Protocol0 - end, - gen_handshake(?FUNCTION_NAME, internal, {common_client_hello, Type, ServerHelloExt}, - State#state{connection_states = ConnectionStates, - negotiated_version = Version, - hashsign_algorithm = HashSign, - client_hello_version = ClientVersion, - session = Session, - negotiated_protocol = Protocol}); - %% TLS 1.3 - {Version, {Type, Session}, - ConnectionStates, Protocol0, ServerHelloExt, HashSign} -> - Protocol = case Protocol0 of - undefined -> CurrentProtocol; - _ -> Protocol0 - end, - tls_connection_1_3:gen_handshake(?FUNCTION_NAME, - internal, - {common_client_hello, Type, ServerHelloExt}, - State#state{connection_states = ConnectionStates, - negotiated_version = Version, - hashsign_algorithm = HashSign, - client_hello_version = ClientVersion, - session = Session, - negotiated_protocol = Protocol}) + case choose_tls_version(SslOpts, Hello) of + 'tls_v1.3' -> + %% Continue in TLS 1.3 'start' state + {next_state, start, State, [{next_event, internal, Hello}]}; + 'tls_v1.2' -> + case tls_handshake:hello(Hello, + SslOpts, + {Port, Session0, Cache, CacheCb, + ConnectionStates0, Cert, KeyExAlg}, + Renegotiation) of + #alert{} = Alert -> + ssl_connection:handle_own_alert(Alert, ClientVersion, hello, + State#state{negotiated_version + = ClientVersion}); + {Version, {Type, Session}, + ConnectionStates, Protocol0, ServerHelloExt, HashSign} -> + Protocol = case Protocol0 of + undefined -> CurrentProtocol; + _ -> Protocol0 + end, + gen_handshake(?FUNCTION_NAME, + internal, + {common_client_hello, Type, ServerHelloExt}, + State#state{connection_states = ConnectionStates, + negotiated_version = Version, + hashsign_algorithm = HashSign, + client_hello_version = ClientVersion, + session = Session, + negotiated_protocol = Protocol}) + end end; hello(internal, #server_hello{} = Hello, #state{connection_states = ConnectionStates0, @@ -724,108 +720,108 @@ downgrade(Type, Event, State) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- start(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); start(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec negotiated(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- negotiated(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); negotiated(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec recvd_ch(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- recvd_ch(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); recvd_ch(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_cert(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_cert(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_cert(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_cv(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_cv(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_cv(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_eoed(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_eoed(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_eoed(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_finished(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_finished(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_finished(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_flight2(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_flight2(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_flight2(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec connected(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- connected(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); connected(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_cert_cr(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_cert_cr(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_cert_cr(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_ee(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_ee(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_ee(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %%-------------------------------------------------------------------- -spec wait_sh(gen_statem:event_type(), term(), #state{}) -> gen_statem:state_function_result(). %%-------------------------------------------------------------------- wait_sh(info, Event, State) -> - gen_info(Event, ?FUNCTION_NAME, State); + gen_info_1_3(Event, ?FUNCTION_NAME, State); wait_sh(Type, Event, State) -> - gen_handshake(?FUNCTION_NAME, Type, Event, State). + gen_handshake_1_3(?FUNCTION_NAME, Type, Event, State). %-------------------------------------------------------------------- %% gen_statem callbacks @@ -851,7 +847,6 @@ initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trac #ssl_options{beast_mitigation = BeastMitigation, erl_dist = IsErlDist} = SSLOptions, ConnectionStates = tls_record:init_connection_states(Role, BeastMitigation), - ErlDistData = erl_dist_data(IsErlDist), SessionCacheCb = case application:get_env(ssl, session_cb) of {ok, Cb} when is_atom(Cb) -> @@ -1037,6 +1032,18 @@ gen_handshake(StateName, Type, Event, Version, StateName, State) end. +gen_handshake_1_3(StateName, Type, Event, + #state{negotiated_version = Version} = State) -> + try tls_connection_1_3:StateName(Type, Event, State, ?MODULE) of + Result -> + Result + catch + _:_ -> + ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, + malformed_handshake_data), + Version, StateName, State) + end. + gen_info(Event, connection = StateName, #state{negotiated_version = Version} = State) -> try handle_info(Event, StateName, State) of Result -> @@ -1058,6 +1065,29 @@ gen_info(Event, StateName, #state{negotiated_version = Version} = State) -> malformed_handshake_data), Version, StateName, State) end. + +gen_info_1_3(Event, connected = StateName, #state{negotiated_version = Version} = State) -> + try handle_info(Event, StateName, State) of + Result -> + Result + catch + _:_ -> + ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?INTERNAL_ERROR, + malformed_data), + Version, StateName, State) + end; + +gen_info_1_3(Event, StateName, #state{negotiated_version = Version} = State) -> + try handle_info(Event, StateName, State) of + Result -> + Result + catch + _:_ -> + ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, + malformed_handshake_data), + Version, StateName, State) + end. + unprocessed_events(Events) -> %% The first handshake event will be processed immediately @@ -1103,12 +1133,33 @@ ensure_sender_terminate(_, #state{protocol_specific = #{sender := Sender}}) -> end, spawn(Kill). -maybe_generate_key_share(#ssl_options{ +maybe_generate_client_shares(#ssl_options{ versions = [Version|_], supported_groups = #supported_groups{ supported_groups = Groups}}) when Version =:= {3,4} -> ssl_cipher:generate_client_shares(Groups); -maybe_generate_key_share(_) -> +maybe_generate_client_shares(_) -> undefined. + +choose_tls_version(#ssl_options{versions = Versions}, + #client_hello{ + extensions = #{client_hello_versions := + #client_hello_versions{versions = ClientVersions} + } + }) -> + case ssl_handshake:select_supported_version(ClientVersions, Versions) of + {3,4} -> + 'tls_v1.3'; + _Else -> + 'tls_v1.2' + end; +choose_tls_version(_, _) -> + 'tls_v1.2'. + + +effective_version(undefined, #ssl_options{versions = [Version|_]}) -> + Version; +effective_version(Version, _) -> + Version. -- cgit v1.2.3