diff options
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 186 |
1 files changed, 139 insertions, 47 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index b8be686b99..479f68f4bb 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -42,22 +42,17 @@ -export([next_record/1, next_event/3]). %% Handshake handling --export([%%renegotiate/2, - send_handshake/2, queue_handshake/2, queue_change_cipher/2]). +-export([renegotiate/2, + reinit_handshake_data/1, + send_handshake/2, queue_handshake/2, queue_change_cipher/2, + select_sni_extension/1]). %% Alert and close handling --export([%%send_alert/2, handle_own_alert/4, handle_close_alert/3, - handle_normal_shutdown/3 %%, close/5 - %%alert_user/6, alert_user/9 - ]). +-export([send_alert/2, close/5]). %% Data handling --export([%%write_application_data/3, - read_application_data/2, - passive_receive/2, next_record_if_active/1%, - %%handle_common_event/4, - %handle_packet/3 +-export([passive_receive/2, next_record_if_active/1, handle_common_event/4 ]). %% gen_statem state functions @@ -65,9 +60,7 @@ hello/3, certify/3, cipher/3, abbreviated/3, %% Handshake states connection/3]). %% gen_statem callbacks --export([terminate/3, code_change/4, format_status/2]). - --define(GEN_STATEM_CB_MODE, state_functions). +-export([callback_mode/0, terminate/3, code_change/4, format_status/2]). %%==================================================================== %% Internal application API @@ -104,10 +97,11 @@ send_handshake(Handshake, State) -> send_handshake_flight(queue_handshake(Handshake, State)). queue_flight_buffer(Msg, #state{negotiated_version = Version, - connection_states = #connection_states{ - current_write = - #connection_state{epoch = Epoch}}, + connection_states = ConnectionStates, flight_buffer = Flight} = State) -> + ConnectionState = + ssl_record:current_connection_state(ConnectionStates, write), + Epoch = maps:get(epoch, ConnectionState), State#state{flight_buffer = Flight ++ [{Version, Epoch, Msg}]}. queue_handshake(Handshake, #state{negotiated_version = Version, @@ -141,6 +135,25 @@ send_alert(Alert, #state{negotiated_version = Version, Transport:send(Socket, BinMsg), State0#state{connection_states = ConnectionStates}. +close(downgrade, _,_,_,_) -> + ok; +%% Other +close(_, Socket, Transport, _,_) -> + Transport:close(Socket). + +reinit_handshake_data(#state{protocol_buffers = Buffers} = State) -> + State#state{premaster_secret = undefined, + public_key_info = undefined, + tls_handshake_history = ssl_handshake:init_handshake_history(), + protocol_buffers = + Buffers#protocol_buffers{dtls_fragment_state = + dtls_handshake:dtls_handshake_new_flight(0)}}. + +select_sni_extension(#client_hello{extensions = HelloExtensions}) -> + HelloExtensions#hello_extensions.sni; +select_sni_extension(_) -> + undefined. + %%==================================================================== %% tls_connection_sup API %%==================================================================== @@ -161,12 +174,15 @@ init([Role, Host, Port, Socket, Options, User, CbInfo]) -> State0 = initial_state(Role, Host, Port, Socket, Options, User, CbInfo), try State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0), - gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, init, State) + gen_statem:enter_loop(?MODULE, [], init, State) catch throw:Error -> - gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, error, {Error,State0}) + gen_statem:enter_loop(?MODULE, [], error, {Error,State0}) end. +callback_mode() -> + state_functions. + %%-------------------------------------------------------------------- %% State functionsconnection/2 %%-------------------------------------------------------------------- @@ -231,7 +247,7 @@ hello(internal, #client_hello{client_version = ClientVersion, case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb, ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of #alert{} = Alert -> - handle_own_alert(Alert, ClientVersion, hello, State); + ssl_connection:handle_own_alert(Alert, ClientVersion, hello, State); {Version, {Type, Session}, ConnectionStates, Protocol0, ServerHelloExt, HashSign} -> Protocol = case Protocol0 of @@ -255,7 +271,7 @@ hello(internal, #server_hello{} = Hello, ssl_options = SslOptions} = State) -> case dtls_handshake:hello(Hello, SslOptions, ConnectionStates0, Renegotiation) of #alert{} = Alert -> - handle_own_alert(Alert, ReqVersion, hello, State); + ssl_connection:handle_own_alert(Alert, ReqVersion, hello, State); {Version, NewId, ConnectionStates, ProtoExt, Protocol} -> ssl_connection:handle_session(Hello, Version, NewId, ConnectionStates, ProtoExt, Protocol, State) @@ -334,7 +350,7 @@ handle_info({Protocol, _, Data}, StateName, {Record, State} -> next_event(StateName, Record, State); #alert{} = Alert -> - handle_normal_shutdown(Alert, StateName, State0), + ssl_connection:handle_normal_shutdown(Alert, StateName, State0), {stop, {shutdown, own_alert}} end; handle_info({CloseTag, Socket}, StateName, @@ -354,7 +370,7 @@ handle_info({CloseTag, Socket}, StateName, %%invalidate_session(Role, Host, Port, Session) ok end, - handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State), + ssl_connection:handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State), {stop, {shutdown, transport_closed}}; handle_info(Msg, StateName, State) -> ssl_connection:handle_info(Msg, StateName, State). @@ -362,6 +378,51 @@ handle_info(Msg, StateName, State) -> handle_call(Event, From, StateName, State) -> ssl_connection:handle_call(Event, From, StateName, State, ?MODULE). +handle_common_event(internal, #alert{} = Alert, StateName, + #state{negotiated_version = Version} = State) -> + ssl_connection:handle_own_alert(Alert, Version, StateName, State); + +%%% DTLS record protocol level handshake messages +handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE} = Record, + StateName, + #state{protocol_buffers = + #protocol_buffers{dtls_packets = Packets0, + dtls_fragment_state = HsState0} = Buffers, + negotiated_version = Version} = State0) -> + try + {Packets1, HsState} = dtls_handshake:get_dtls_handshake(Record, HsState0), + State = + State0#state{protocol_buffers = + Buffers#protocol_buffers{dtls_fragment_state = HsState}}, + Events = dtls_handshake_events(Packets0 ++ Packets1), + case StateName of + connection -> + ssl_connection:hibernate_after(StateName, State, Events); + _ -> + {next_state, StateName, State, Events} + end + catch throw:#alert{} = Alert -> + ssl_connection:handle_own_alert(Alert, Version, StateName, State0) + end; +%%% DTLS record protocol level application data messages +handle_common_event(internal, #ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, StateName, State) -> + {next_state, StateName, State, [{next_event, internal, {application_data, Data}}]}; +%%% DTLS record protocol level change cipher messages +handle_common_event(internal, #ssl_tls{type = ?CHANGE_CIPHER_SPEC, fragment = Data}, StateName, State) -> + {next_state, StateName, State, [{next_event, internal, #change_cipher_spec{type = Data}}]}; +%%% DTLS record protocol level Alert messages +handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, StateName, + #state{negotiated_version = Version} = State) -> + case decode_alerts(EncAlerts) of + Alerts = [_|_] -> + handle_alerts(Alerts, {next_state, StateName, State}); + #alert{} = Alert -> + ssl_connection:handle_own_alert(Alert, Version, StateName, State) + end; +%% Ignore unknown TLS record level protocol messages +handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) -> + {next_state, StateName, State}. + %%-------------------------------------------------------------------- %% Description:This function is called by a gen_fsm when it is about %% to terminate. It should be the opposite of Module:init/1 and do any @@ -376,7 +437,7 @@ terminate(Reason, StateName, State) -> %% Description: Convert process state when code is changed %%-------------------------------------------------------------------- code_change(_OldVsn, StateName, State, _Extra) -> - {?GEN_STATEM_CB_MODE, StateName, State}. + {ok, StateName, State}. format_status(Type, Data) -> ssl_connection:format_status(Type, Data). @@ -384,10 +445,21 @@ format_status(Type, Data) -> %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- + +dtls_handshake_events([]) -> + throw(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake)); +dtls_handshake_events(Packets) -> + lists:map(fun(Packet) -> + {next_event, internal, {handshake, Packet}} + end, Packets). + + encode_handshake(Handshake, Version, ConnectionStates0, Hist0) -> {Seq, ConnectionStates} = sequence(ConnectionStates0), {EncHandshake, Frag} = dtls_handshake:encode_handshake(Handshake, Version, Seq), - Hist = ssl_handshake:update_handshake_history(Hist0, EncHandshake), + %% DTLS does not have an equivalent version to SSLv2. So v2 hello compatibility + %% will always be false + Hist = ssl_handshake:update_handshake_history(Hist0, EncHandshake, false), {Frag, ConnectionStates, Hist}. encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> @@ -423,12 +495,12 @@ encode_handshake_record(_Version, _Epoch, _Space, _MsgType, _MsgSeq, _Len, <<>>, encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, Offset, MRS, Encoded0, CS0) -> MaxFragmentLen = Space - 25, - case Bin of - <<BinFragment:MaxFragmentLen/bytes, Rest/binary>> -> - ok; + {BinFragment, Rest} = + case Bin of + <<BinFragment0:MaxFragmentLen/bytes, Rest0/binary>> -> + {BinFragment0, Rest0}; _ -> - BinFragment = Bin, - Rest = <<>> + {Bin, <<>>} end, FragLength = byte_size(BinFragment), Frag = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(Offset), ?uint24(FragLength), BinFragment], @@ -459,13 +531,13 @@ finish_pack_records({[], Acc}) -> finish_pack_records({Buf, Acc}) -> lists:reverse([lists:reverse(Buf)|Acc]). -%% decode_alerts(Bin) -> -%% ssl_alert:decode(Bin). +decode_alerts(Bin) -> + ssl_alert:decode(Bin). initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User, {CbModule, DataTag, CloseTag, ErrorTag}) -> #ssl_options{beast_mitigation = BeastMitigation} = SSLOptions, - ConnectionStates = ssl_record:init_connection_states(Role, BeastMitigation), + ConnectionStates = dtls_record:init_connection_states(Role, BeastMitigation), SessionCacheCb = case application:get_env(ssl, session_cb) of {ok, Cb} when is_atom(Cb) -> @@ -548,7 +620,7 @@ passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName) -> {Record, State} = next_record(State0), next_event(StateName, Record, State); _ -> - {Record, State} = read_application_data(<<>>, State0), + {Record, State} = ssl_connection:read_application_data(<<>>, State0), next_event(StateName, Record, State) end. @@ -560,7 +632,7 @@ next_event(connection = StateName, no_record, State0, Actions) -> {no_record, State} -> ssl_connection:hibernate_after(StateName, State, Actions); {#ssl_tls{} = Record, State} -> - {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; + {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]}; {#alert{} = Alert, State} -> {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} end; @@ -569,20 +641,11 @@ next_event(StateName, Record, State, Actions) -> no_record -> {next_state, StateName, State, Actions}; #ssl_tls{} = Record -> - {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; + {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]}; #alert{} = Alert -> {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} end. -read_application_data(_,State) -> - {#ssl_tls{fragment = <<"place holder">>}, State}. - -handle_own_alert(_,_,_, State) -> %% Place holder - {stop, {shutdown, own_alert}, State}. - -handle_normal_shutdown(_, _, _State) -> %% Place holder - ok. - %% TODO This generates dialyzer warnings, has to be handled differently. %% handle_packet(Address, Port, Packet) -> %% try dtls_record:get_dtls_records(Packet, <<>>) of @@ -631,5 +694,34 @@ handle_normal_shutdown(_, _, _State) -> %% Place holder %% address_to_bin({A,B,C,D,E,F,G,H}, Port) -> %% <<A:16,B:16,C:16,D:16,E:16,F:16,G:16,H:16,Port:16>>. -sequence(#connection_states{dtls_write_msg_seq = Seq} = CS) -> - {Seq, CS#connection_states{dtls_write_msg_seq = Seq + 1}}. +sequence(#{write_msg_seq := Seq} = ConnectionState) -> + {Seq, ConnectionState#{write_msg_seq => Seq + 1}}. + +renegotiate(#state{role = client} = State, Actions) -> + %% Handle same way as if server requested + %% the renegotiation + Hs0 = ssl_handshake:init_handshake_history(), + {next_state, connection, State#state{tls_handshake_history = Hs0, + protocol_buffers = #protocol_buffers{}}, + [{next_event, internal, #hello_request{}} | Actions]}; + +renegotiate(#state{role = server, + connection_states = CS0} = State0, Actions) -> + HelloRequest = ssl_handshake:hello_request(), + CS = CS0#{write_msg_seq => 0}, + State1 = send_handshake(HelloRequest, + State0#state{connection_states = + CS}), + Hs0 = ssl_handshake:init_handshake_history(), + {Record, State} = next_record(State1#state{tls_handshake_history = Hs0, + protocol_buffers = #protocol_buffers{}}), + next_event(hello, Record, State, Actions). + +handle_alerts([], Result) -> + Result; +handle_alerts(_, {stop,_} = Stop) -> + Stop; +handle_alerts([Alert | Alerts], {next_state, StateName, State}) -> + handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)); +handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) -> + handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)). |