diff options
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
| -rw-r--r-- | lib/ssl/src/dtls_connection.erl | 354 | 
1 files changed, 145 insertions, 209 deletions
| diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index e490de7eeb..82d6faee42 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -21,7 +21,7 @@  %% Internal application API --behaviour(gen_fsm). +-behaviour(gen_statem).  -include("dtls_connection.hrl").  -include("dtls_handshake.hrl"). @@ -36,37 +36,38 @@  %% Internal application API  %% Setup --export([start_fsm/8]). +-export([start_fsm/8, start_link/7, init/1]).  %% State transition handling	  --export([next_record/1, next_state/4%,  -	 %%next_state_connection/2 -	]). +-export([next_record/1, next_event/3]).  %% Handshake handling --export([%%renegotiate/1,  +-export([%%renegotiate/2,   	 send_handshake/2, send_change_cipher/2]). +  %% Alert and close handling --export([send_alert/2, handle_own_alert/4, %%handle_close_alert/3, -	 handle_normal_shutdown/3 -	 %%handle_unexpected_message/3, -	 %%alert_user/5, alert_user/8 +-export([%%send_alert/2, handle_own_alert/4, handle_close_alert/3, +	 handle_normal_shutdown/3 %%, close/5 +	 %%alert_user/6, alert_user/9  	]).  %% Data handling  -export([%%write_application_data/3,  -	 read_application_data/2%%, -%%	 passive_receive/2,  next_record_if_active/1 +	 read_application_data/2, +	 %%passive_receive/2, +	 next_record_if_active/1 %%, +	 %%handle_common_event/4  	]). -%% Called by tls_connection_sup --export([start_link/7]).  +%% gen_statem state functions +-export([init/3, error/3, downgrade/3, %% Initiation and take down states +	 hello/3, certify/3, cipher/3, abbreviated/3, %% Handshake states  +	 connection/3]).  +%% gen_statem callbacks +-export([terminate/3, code_change/4, format_status/2]). -%% gen_fsm callbacks --export([init/1, hello/2, certify/2, cipher/2, -	 abbreviated/2, connection/2, handle_event/3, -         handle_sync_event/4, handle_info/3, terminate/3, code_change/4]). +-define(GEN_STATEM_CB_MODE, state_functions).  %%====================================================================  %% Internal application API @@ -141,75 +142,74 @@ send_change_cipher(Msg, #state{connection_states = ConnectionStates0,  start_link(Role, Host, Port, Socket, Options, User, CbInfo) ->      {ok, proc_lib:spawn_link(?MODULE, init, [[Role, Host, Port, Socket, Options, User, CbInfo]])}. -init([Role, Host, Port, Socket, {SSLOpts0, _} = Options,  User, CbInfo]) -> +init([Role, Host, Port, Socket, Options,  User, CbInfo]) ->      process_flag(trap_exit, true),      State0 =  initial_state(Role, Host, Port, Socket, Options, User, CbInfo), -    Handshake = ssl_handshake:init_handshake_history(), -    TimeStamp = erlang:monotonic_time(), -    try ssl_config:init(SSLOpts0, Role) of -	{ok, Ref, CertDbHandle, FileRefHandle, CacheHandle,  CRLDbInfo, OwnCert, Key, DHParams} -> -	    Session = State0#state.session, -	    State = State0#state{ -		      tls_handshake_history = Handshake, -		      session = Session#session{own_certificate = OwnCert, -						time_stamp = TimeStamp}, -		      file_ref_db = FileRefHandle, -		      cert_db_ref = Ref, -		      cert_db = CertDbHandle, -		      crl_db = CRLDbInfo, -		      session_cache = CacheHandle, -		      private_key = Key, -		      diffie_hellman_params = DHParams}, -	    gen_fsm:enter_loop(?MODULE, [], hello, State, get_timeout(State)) +    try +	State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0), +	gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, init, State)      catch  	throw:Error -> -	    gen_fsm:enter_loop(?MODULE, [], error, {Error,State0}, get_timeout(State0)) +	    gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, error, {Error,State0})      end.  %%-------------------------------------------------------------------- -%% Description:There should be one instance of this function for each -%% possible state name. Whenever a gen_fsm receives an event sent -%% using gen_fsm:send_event/2, the instance of this function with the -%% same name as the current state name StateName is called to handle -%% the event. It is also called if a timeout occurs. -%% -hello(start, #state{host = Host, port = Port, role = client, -		    ssl_options = SslOpts, -		    session = #session{own_certificate = Cert} = Session0, -		    session_cache = Cache, session_cache_cb = CacheCb, -		    transport_cb = Transport, socket = Socket, -		    connection_states = ConnectionStates0, -		    renegotiation = {Renegotiation, _}} = State0) -> +%% State functionsconnection/2 +%%-------------------------------------------------------------------- + +init({call, From}, {start, Timeout},  +     #state{host = Host, port = Port, role = client, +	    ssl_options = SslOpts, +	    session = #session{own_certificate = Cert} = Session0, +	    transport_cb = Transport, socket = Socket, +	    connection_states = ConnectionStates0, +	    renegotiation = {Renegotiation, _}, +	    session_cache = Cache, +	    session_cache_cb = CacheCb +	   } = State0) -> +    Timer = ssl_connection:start_or_recv_cancel_timer(Timeout, From),      Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, -					Cache, CacheCb, Renegotiation, Cert), +				       Cache, CacheCb, Renegotiation, Cert),      Version = Hello#client_hello.client_version, +    HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions),      Handshake0 = ssl_handshake:init_handshake_history(),      {BinMsg, ConnectionStates, Handshake} = -        encode_handshake(Hello, Version, ConnectionStates0, Handshake0), +        encode_handshake(Hello,  HelloVersion, ConnectionStates0, Handshake0),      Transport:send(Socket, BinMsg),      State1 = State0#state{connection_states = ConnectionStates,  			  negotiated_version = Version, %% Requested version  			  session =  			      Session0#session{session_id = Hello#client_hello.session_id}, -			  tls_handshake_history = Handshake}, +			  tls_handshake_history = Handshake, +			  start_or_recv_from = From, +			  timer = Timer},      {Record, State} = next_record(State1), -    next_state(hello, hello, Record, State); - -hello(Hello = #client_hello{client_version = ClientVersion}, -      State = #state{connection_states = ConnectionStates0, -		     port = Port, session = #session{own_certificate = Cert} = Session0, -		     renegotiation = {Renegotiation, _}, -		     session_cache = Cache, -		     session_cache_cb = CacheCb, -		     ssl_options = SslOpts}) -> +    next_event(hello, Record, State); +init(Type, Event, State) -> +    ssl_connection:init(Type, Event, State, ?MODULE). +  +error({call, From}, {start, _Timeout}, {Error, State}) -> +    {stop_and_reply, normal, {reply, From, {error, Error}}, State}; +error({call, From}, Msg, State) -> +    handle_call(Msg, From, error, State); +error(_, _, _) -> +     {keep_state_and_data, [postpone]}. + +hello(internal, #client_hello{client_version = ClientVersion} = Hello, +      #state{connection_states = ConnectionStates0, +	     port = Port, session = #session{own_certificate = Cert} = Session0, +	     renegotiation = {Renegotiation, _}, +	     session_cache = Cache, +	     session_cache_cb = CacheCb, +	     ssl_options = SslOpts} = State) ->      case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb,  					      ConnectionStates0, Cert}, Renegotiation) of          {Version, {Type, Session},  	 ConnectionStates,  	 #hello_extensions{ec_point_formats = EcPointFormats,  			   elliptic_curves = EllipticCurves} = ServerHelloExt, HashSign} -> -            ssl_connection:hello({common_client_hello, Type, ServerHelloExt, HashSign}, +            ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt, HashSign},  				 State#state{connection_states  = ConnectionStates,  					     negotiated_version = Version,  					     session = Session, @@ -217,7 +217,7 @@ hello(Hello = #client_hello{client_version = ClientVersion},          #alert{} = Alert ->              handle_own_alert(Alert, ClientVersion, hello, State)      end; -hello(Hello, +hello(internal, Hello,        #state{connection_states = ConnectionStates0,  	     negotiated_version = ReqVersion,  	     role = client, @@ -230,20 +230,30 @@ hello(Hello,  	    ssl_connection:handle_session(Hello,   					  Version, NewId, ConnectionStates, ProtoExt, Protocol, State)      end; - -hello(Msg, State) -> -    ssl_connection:hello(Msg, State, ?MODULE). - -abbreviated(Msg, State) -> -    ssl_connection:abbreviated(Msg, State, ?MODULE). - -certify(Msg, State) -> -    ssl_connection:certify(Msg, State, ?MODULE). - -cipher(Msg, State) -> -     ssl_connection:cipher(Msg, State, ?MODULE). - -connection(#hello_request{}, #state{host = Host, port = Port, +hello(info, Event, State) -> +    handle_info(Event, hello, State); + +hello(Type, Event, State) -> +    ssl_connection:hello(Type, Event, State, ?MODULE). + +abbreviated(info, Event, State) -> +    handle_info(Event, abbreviated, State); +abbreviated(Type, Event, State) -> +    ssl_connection:abbreviated(Type, Event, State, ?MODULE). + +certify(info, Event, State) -> +    handle_info(Event, certify, State); +certify(Type, Event, State) -> +    ssl_connection:certify(Type, Event, State, ?MODULE). + +cipher(info, Event, State) -> +    handle_info(Event, cipher, State); +cipher(Type, Event, State) -> +     ssl_connection:cipher(Type, Event, State, ?MODULE). + +connection(info, Event, State) -> +    handle_info(Event, connection, State); +connection(internal, #hello_request{}, #state{host = Host, port = Port,  				    session = #session{own_certificate = Cert} = Session0,  				    session_cache = Cache, session_cache_cb = CacheCb,  				    ssl_options = SslOpts, @@ -257,40 +267,30 @@ connection(#hello_request{}, #state{host = Host, port = Port,  	next_record(  	  State1#state{session = Session0#session{session_id  						  = Hello#client_hello.session_id}}), -    next_state(connection, hello, Record, State); +    next_event(hello, Record, State); -connection(#client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) -> +connection(internal, #client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) ->      %% Mitigate Computational DoS attack      %% http://www.educatedguesswork.org/2011/10/ssltls_and_computational_dos.html      %% http://www.thc.org/thc-ssl-dos/ Rather than disabling client      %% initiated renegotiation we will disallow many client initiated      %% renegotiations immediately after each other.      erlang:send_after(?WAIT_TO_ALLOW_RENEGOTIATION, self(), allow_renegotiate), -    hello(Hello, State#state{allow_renegotiate = false}); +    {next_state, hello, State#state{allow_renegotiate = false}, [{next_event, internal, Hello}]}; + -connection(#client_hello{}, #state{role = server, allow_renegotiate = false} = State0) -> +connection(internal, #client_hello{}, #state{role = server, allow_renegotiate = false} = State0) ->      Alert = ?ALERT_REC(?WARNING, ?NO_RENEGOTIATION), -    State = send_alert(Alert, State0), -    next_state_connection(connection, State); +    State1 = send_alert(Alert, State0), +    {Record, State} = ssl_connection:prepare_connection(State1, ?MODULE), +    next_event(connection, Record, State); -connection(Msg, State) -> -     ssl_connection:connection(Msg, State, tls_connection). +connection(Type, Event, State) -> +     ssl_connection:connection(Type, Event, State, ?MODULE). -%%-------------------------------------------------------------------- -%% Description: Whenever a gen_fsm receives an event sent using -%% gen_fsm:send_all_state_event/2, this function is called to handle -%% the event. Not currently used! -%%-------------------------------------------------------------------- -handle_event(_Event, StateName, State) -> -    {next_state, StateName, State, get_timeout(State)}. +downgrade(Type, Event, State) -> +     ssl_connection:downgrade(Type, Event, State, ?MODULE). -%%-------------------------------------------------------------------- -%% Description: Whenever a gen_fsm receives an event sent using -%% gen_fsm:sync_send_all_state_event/2,3, this function is called to handle -%% the event. -%%-------------------------------------------------------------------- -handle_sync_event(Event, From, StateName, State) -> -    ssl_connection:handle_sync_event(Event, From, StateName, State).  %%--------------------------------------------------------------------  %% Description: This function is called by a gen_fsm when it receives any @@ -301,26 +301,25 @@ handle_sync_event(Event, From, StateName, State) ->  %% raw data from socket, unpack records  handle_info({Protocol, _, Data}, StateName,              #state{data_tag = Protocol} = State0) -> -    %% Simplify for now to avoid dialzer warnings before implementation is  compleate -    %% case next_tls_record(Data, State0) of -    %% 	{Record, State} -> -    %% 	    next_state(StateName, StateName, Record, State); -    %% 	#alert{} = Alert -> -    %% 	    handle_normal_shutdown(Alert, StateName, State0),  -    %% 	    {stop, {shutdown, own_alert}, State0} -    %% end; -    {Record, State} = next_tls_record(Data, State0),  -    next_state(StateName, StateName, Record, State); - +     case next_tls_record(Data, State0) of +	{Record, State} -> +	    next_event(StateName, Record, State); +	#alert{} = Alert -> +	    handle_normal_shutdown(Alert, StateName, State0),  +	    {stop, {shutdown, own_alert}} +     end;  handle_info({CloseTag, Socket}, StateName,              #state{socket = Socket, close_tag = CloseTag,  		   negotiated_version = _Version} = State) ->      handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State), -    {stop, {shutdown, transport_closed}, State}; +    {stop, {shutdown, transport_closed}};  handle_info(Msg, StateName, State) ->      ssl_connection:handle_info(Msg, StateName, State). +handle_call(Event, From, StateName, State) -> +    ssl_connection:handle_call(Event, From, StateName, State, ?MODULE). +  %%--------------------------------------------------------------------  %% Description:This function is called by a gen_fsm when it is about  %% to terminate. It should be the opposite of Module:init/1 and do any @@ -335,7 +334,10 @@ terminate(Reason, StateName, State) ->  %% Description: Convert process state when code is changed  %%--------------------------------------------------------------------  code_change(_OldVsn, StateName, State, _Extra) -> -    {ok, StateName, State}. +    {?GEN_STATEM_CB_MODE, StateName, State}. + +format_status(Type, Data) -> +    ssl_connection:format_status(Type, Data).  %%--------------------------------------------------------------------  %%% Internal functions @@ -372,96 +374,28 @@ next_record(#state{socket = Socket,  next_record(State) ->      {no_record, State}. -next_state(Current,_, #alert{} = Alert, #state{negotiated_version = Version} = State) -> -    handle_own_alert(Alert, Version, Current, State); - -next_state(_,Next, no_record, State) -> -    {next_state, Next, State, get_timeout(State)}; - -%% next_state(_,Next, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, State) -> -%%     Alerts = decode_alerts(EncAlerts), -%%     handle_alerts(Alerts,  {next_state, Next, State, get_timeout(State)}); - -next_state(Current, Next, #ssl_tls{type = ?HANDSHAKE, fragment = Data}, -	   State0 = #state{protocol_buffers = -			       #protocol_buffers{dtls_handshake_buffer = Buf0} = Buffers, -			   negotiated_version = Version}) -> -    Handle =  -   	fun({#hello_request{} = Packet, _}, {next_state, connection = SName, State}) -> -   		%% This message should not be included in handshake -   		%% message hashes. Starts new handshake (renegotiation) -		Hs0 = ssl_handshake:init_handshake_history(), -		?MODULE:SName(Packet, State#state{tls_handshake_history=Hs0, -   						  renegotiation = {true, peer}}); -   	   ({#hello_request{} = Packet, _}, {next_state, SName, State}) -> -   		%% This message should not be included in handshake -   		%% message hashes. Already in negotiation so it will be ignored! -   		?MODULE:SName(Packet, State); -	   ({#client_hello{} = Packet, Raw}, {next_state, connection = SName, State}) -> -		Version = Packet#client_hello.client_version, -		Hs0 = ssl_handshake:init_handshake_history(), -		Hs1 = ssl_handshake:update_handshake_history(Hs0, Raw), -		?MODULE:SName(Packet, State#state{tls_handshake_history=Hs1, -   						  renegotiation = {true, peer}}); -	   ({Packet, Raw}, {next_state, SName, State = #state{tls_handshake_history=Hs0}}) -> -		Hs1 = ssl_handshake:update_handshake_history(Hs0, Raw), -		?MODULE:SName(Packet, State#state{tls_handshake_history=Hs1}); -   	   (_, StopState) -> StopState -   	end, -    try -	{Packets, Buf} = tls_handshake:get_tls_handshake(Version,Data,Buf0), -	State = State0#state{protocol_buffers = -				 Buffers#protocol_buffers{dtls_packets = Packets, -							  dtls_handshake_buffer = Buf}}, -	handle_dtls_handshake(Handle, Next, State) -    catch throw:#alert{} = Alert -> -	    handle_own_alert(Alert, Version, Current, State0) -    end; -next_state(_, StateName, #ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, State0) -> -    %% Simplify for now to avoid dialzer warnings before implementation is  compleate -    %% case read_application_data(Data, State0) of -    %% 	Stop = {stop,_,_} -> -    %% 	    Stop; -    %% 	{Record, State} -> -    %% 	    next_state(StateName, StateName, Record, State) -    %% end; -    {Record, State} = read_application_data(Data, State0), -    next_state(StateName, StateName, Record, State); -	 -next_state(Current, Next, #ssl_tls{type = ?CHANGE_CIPHER_SPEC, fragment = <<1>>} =  - 	   _ChangeCipher,  - 	   #state{connection_states = ConnectionStates0} = State0) -> -    ConnectionStates1 = -	ssl_record:activate_pending_connection_state(ConnectionStates0, read), -    {Record, State} = next_record(State0#state{connection_states = ConnectionStates1}), -    next_state(Current, Next, Record, State); -next_state(Current, Next, #ssl_tls{type = _Unknown}, State0) -> -    %% Ignore unknown type  -    {Record, State} = next_record(State0), -    next_state(Current, Next, Record, State). - -handle_dtls_handshake(Handle, StateName, -		     #state{protocol_buffers = -				#protocol_buffers{dtls_packets = [Packet]} = Buffers} = State) -> -    FsmReturn = {next_state, StateName, State#state{protocol_buffers = -							Buffers#protocol_buffers{dtls_packets = []}}}, -    Handle(Packet, FsmReturn); - -handle_dtls_handshake(Handle, StateName, -		     #state{protocol_buffers = -				#protocol_buffers{dtls_packets = [Packet | Packets]} = Buffers} = -			 State0) -> -    FsmReturn = {next_state, StateName, State0#state{protocol_buffers = -							 Buffers#protocol_buffers{dtls_packets = -										      Packets}}}, -    case Handle(Packet, FsmReturn) of -	{next_state, NextStateName, State, _Timeout} -> -	    handle_dtls_handshake(Handle, NextStateName, State); -	{stop, _,_} = Stop -> -	    Stop -    end. +next_event(StateName, Record, State) -> +    next_event(StateName, Record, State, []). +next_event(connection = StateName, no_record, State0, Actions) -> +    case next_record_if_active(State0) of +	{no_record, State} -> +	    ssl_connection:hibernate_after(StateName, State, Actions); +	{#ssl_tls{} = Record, State} -> +	    {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; +	{#alert{} = Alert, State} -> +	    {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} +    end; +next_event(StateName, Record, State, Actions) -> +    case Record of +	no_record -> +	    {next_state, StateName, State, Actions}; +	#ssl_tls{} = Record -> +	    {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; +	#alert{} = Alert -> +	    {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} +    end.  send_flight(Fragments, #state{transport_cb = Transport, socket = Socket,  			      protocol_buffers = _PBuffers} = State) -> @@ -514,21 +448,23 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User,  	   renegotiation = {false, first},  	   allow_renegotiate = SSLOptions#ssl_options.client_renegotiation,  	   start_or_recv_from = undefined, -	   send_queue = queue:new(),  	   protocol_cb = ?MODULE  	  }.  read_application_data(_,State) ->      {#ssl_tls{fragment = <<"place holder">>}, State}. -	 + +next_tls_record(<<>>, _State) -> +    #alert{}; %% Place holder  next_tls_record(_, State) ->      {#ssl_tls{fragment = <<"place holder">>}, State}. -get_timeout(_) -> %% Place holder -    infinity. - -next_state_connection(_, State) -> %% Place holder -    {next_state, connection, State, get_timeout(State)}. -  sequence(_) ->       %%TODO real imp      1. +next_record_if_active(State =  +		      #state{socket_options =  +			     #socket_options{active = false}}) ->  +    {no_record ,State}; + +next_record_if_active(State) -> +    next_record(State). | 
