diff options
| author | Ingela Anderton Andin <[email protected]> | 2016-09-20 20:58:34 +0200 | 
|---|---|---|
| committer | Ingela Anderton Andin <[email protected]> | 2016-12-05 10:59:51 +0100 | 
| commit | 1e6942e97339ff39a0436834c260bf50c3d3a481 (patch) | |
| tree | a7b277fe925fc74dd186faf64f3d24ac066de5bd /lib/ssl/src | |
| parent | 39fef8eec26c2903ef11ba8829e1f625906b3e11 (diff) | |
| download | otp-1e6942e97339ff39a0436834c260bf50c3d3a481.tar.gz otp-1e6942e97339ff39a0436834c260bf50c3d3a481.tar.bz2 otp-1e6942e97339ff39a0436834c260bf50c3d3a481.zip | |
ssl: Implement DTLS state machine
Beta DTLS, not production ready. Only very basically tested, and
not everything in the SPEC is implemented and some things
are hard coded that should not be, so this implementation can not be consider
secure.
Refactor "TLS connection state" and socket handling, to facilitate
DTLS implementation.
Create dtls "listner" (multiplexor) process that spawns
DTLS connection process handlers.
Handle DTLS fragmentation.
Framework for handling retransmissions.
Replay Detection is not implemented yet.
Alerts currently always handled as in TLS.
Diffstat (limited to 'lib/ssl/src')
26 files changed, 1816 insertions, 1188 deletions
| diff --git a/lib/ssl/src/Makefile b/lib/ssl/src/Makefile index b625db0656..e96b1b971f 100644 --- a/lib/ssl/src/Makefile +++ b/lib/ssl/src/Makefile @@ -50,6 +50,7 @@ MODULES= \  	ssl_app \  	ssl_dist_sup\  	ssl_sup \ +	dtls_udp_sup \  	inet_tls_dist \  	inet6_tls_dist \  	ssl_certificate\ @@ -71,7 +72,9 @@ MODULES= \  	ssl_crl\  	ssl_crl_cache \  	ssl_crl_hash_dir \ -	ssl_socket \ +	tls_socket \ +	dtls_socket \ +	dtls_udp_listener\  	ssl_listen_tracker_sup \  	tls_record \  	dtls_record \ diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index 4f1f050e4b..68c6f93579 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -48,12 +48,12 @@  	 select_sni_extension/1]).  %% Alert and close handling --export([send_alert/2, close/5]). +-export([encode_alert/3,send_alert/2, close/5]).  %% Data handling --export([passive_receive/2,  next_record_if_active/1, handle_common_event/4 -	]). +-export([encode_data/3, passive_receive/2,  next_record_if_active/1, handle_common_event/4, +	 send/3]).  %% gen_statem state functions  -export([init/3, error/3, downgrade/3, %% Initiation and take down states @@ -93,61 +93,120 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} =  	    Error      end. -send_handshake(Handshake, State) -> -    send_handshake_flight(queue_handshake(Handshake, State)). +send_handshake(Handshake, #state{connection_states = ConnectionStates} = States) -> +    #{epoch := Epoch} = ssl_record:current_connection_state(ConnectionStates, write), +    send_handshake_flight(queue_handshake(Handshake, States), Epoch). + +queue_handshake(Handshake0, #state{tls_handshake_history = Hist0,  +				   negotiated_version = Version, +				   flight_buffer = #{handshakes := HsBuffer0, +						     change_cipher_spec := undefined, +						     next_sequence := Seq} = Flight0} = State) -> +    Handshake = dtls_handshake:encode_handshake(Handshake0, Version, Seq), +    Hist = update_handshake_history(Handshake0, Handshake, Hist0), +    State#state{flight_buffer = Flight0#{handshakes => [Handshake | HsBuffer0], +					 next_sequence => Seq +1}, +		tls_handshake_history = Hist}; + +queue_handshake(Handshake0, #state{tls_handshake_history = Hist0,  +				   negotiated_version = Version, +				   flight_buffer = #{handshakes_after_change_cipher_spec := Buffer0, +						     next_sequence := Seq} = Flight0} = State) -> +    Handshake = dtls_handshake:encode_handshake(Handshake0, Version, Seq), +    Hist = update_handshake_history(Handshake0, Handshake, Hist0), +    State#state{flight_buffer = Flight0#{handshakes_after_change_cipher_spec => [Handshake | Buffer0], +					 next_sequence => Seq +1}, +		tls_handshake_history = Hist}. -queue_flight_buffer(Msg, #state{negotiated_version = Version, -				connection_states = ConnectionStates, -				flight_buffer = Flight} = State) -> -    ConnectionState =  -	ssl_record:current_connection_state(ConnectionStates, write), -    Epoch = maps:get(epoch, ConnectionState), -    State#state{flight_buffer = Flight ++ [{Version, Epoch, Msg}]}. - -queue_handshake(Handshake, #state{negotiated_version = Version, -				  tls_handshake_history = Hist0, -				  connection_states = ConnectionStates0} = State0) -> -    {Frag, ConnectionStates, Hist} = -	encode_handshake(Handshake, Version, ConnectionStates0, Hist0), -    queue_flight_buffer(Frag, State0#state{connection_states = ConnectionStates, -					   tls_handshake_history = Hist}).  send_handshake_flight(#state{socket = Socket,  			     transport_cb = Transport, -			     flight_buffer = Flight, -			     connection_states = ConnectionStates0} = State0) -> - +			     flight_buffer = #{handshakes := Flight, +					       change_cipher_spec := undefined}, +			     negotiated_version = Version, +			     connection_states = ConnectionStates0} = State0, Epoch) -> +    %% TODO remove hardcoded Max size      {Encoded, ConnectionStates} = -	encode_handshake_flight(Flight, ConnectionStates0), +	encode_handshake_flight(lists:reverse(Flight), Version, 1400, Epoch, ConnectionStates0), +    send(Transport, Socket, Encoded), +    start_flight(State0#state{connection_states = ConnectionStates}); + +send_handshake_flight(#state{socket = Socket, +			     transport_cb = Transport, +			     flight_buffer = #{handshakes := [_|_] = Flight0, +					       change_cipher_spec := ChangeCipher, +					       handshakes_after_change_cipher_spec := []}, +			     negotiated_version = Version, +			     connection_states = ConnectionStates0} = State0, Epoch) ->       +    {HsBefore, ConnectionStates1} = +	encode_handshake_flight(lists:reverse(Flight0), Version, 1400, Epoch, ConnectionStates0), +    {EncChangeCipher, ConnectionStates} = encode_change_cipher(ChangeCipher, Version, Epoch, ConnectionStates1), + +    send(Transport, Socket, [HsBefore, EncChangeCipher]), +    start_flight(State0#state{connection_states = ConnectionStates}); -    Transport:send(Socket, Encoded), -    State0#state{flight_buffer = [], connection_states = ConnectionStates}. +send_handshake_flight(#state{socket = Socket, +			     transport_cb = Transport, +			     flight_buffer = #{handshakes := [_|_] = Flight0, +					       change_cipher_spec := ChangeCipher, +					       handshakes_after_change_cipher_spec := Flight1}, +			     negotiated_version = Version, +			     connection_states = ConnectionStates0} = State0, Epoch) ->       +    {HsBefore, ConnectionStates1} = +	encode_handshake_flight(lists:reverse(Flight0), Version, 1400, Epoch-1, ConnectionStates0), +    {EncChangeCipher, ConnectionStates2} =  +	encode_change_cipher(ChangeCipher, Version, Epoch-1, ConnectionStates1), +    {HsAfter, ConnectionStates} = +	encode_handshake_flight(lists:reverse(Flight1), Version, 1400, Epoch, ConnectionStates2), +    send(Transport, Socket, [HsBefore, EncChangeCipher, HsAfter]), +    start_flight(State0#state{connection_states = ConnectionStates}); -queue_change_cipher(Msg, State) -> -    queue_flight_buffer(Msg, State). +send_handshake_flight(#state{socket = Socket, +			     transport_cb = Transport, +			     flight_buffer = #{handshakes := [], +					       change_cipher_spec := ChangeCipher, +					       handshakes_after_change_cipher_spec := Flight1}, +			     negotiated_version = Version, +			     connection_states = ConnectionStates0} = State0, Epoch) -> +    {EncChangeCipher, ConnectionStates1} =  +	encode_change_cipher(ChangeCipher, Version, Epoch-1, ConnectionStates0), +    {HsAfter, ConnectionStates} = +	encode_handshake_flight(lists:reverse(Flight1), Version, 1400, Epoch, ConnectionStates1), +    send(Transport, Socket, [EncChangeCipher, HsAfter]), +    start_flight(State0#state{connection_states = ConnectionStates}). + +queue_change_cipher(ChangeCipher, #state{flight_buffer = Flight, +					 connection_states = ConnectionStates0} = State) ->  +    ConnectionStates =  +	dtls_record:next_epoch(ConnectionStates0, write),  +    State#state{flight_buffer = Flight#{change_cipher_spec => ChangeCipher}, +		connection_states = ConnectionStates}.  send_alert(Alert, #state{negotiated_version = Version,  			 socket = Socket,  			 transport_cb = Transport,  			 connection_states = ConnectionStates0} = State0) ->      {BinMsg, ConnectionStates} = -	ssl_alert:encode(Alert, Version, ConnectionStates0), -    Transport:send(Socket, BinMsg), +	encode_alert(Alert, Version, ConnectionStates0), +    send(Transport, Socket, BinMsg),      State0#state{connection_states = ConnectionStates}.  close(downgrade, _,_,_,_) ->      ok;  %% Other  close(_, Socket, Transport, _,_) -> -    Transport:close(Socket). +    dtls_socket:close(Transport,Socket).  reinit_handshake_data(#state{protocol_buffers = Buffers} = State) ->      State#state{premaster_secret = undefined,  		public_key_info = undefined,  		tls_handshake_history = ssl_handshake:init_handshake_history(),  		protocol_buffers = -		    Buffers#protocol_buffers{dtls_fragment_state = -						 dtls_handshake:dtls_handshake_new_flight(0)}}. +		    Buffers#protocol_buffers{ +		      dtls_handshake_next_seq = 0, +		      dtls_handshake_next_fragments = [], +		      dtls_handshake_later_fragments = [] +		     }}.  select_sni_extension(#client_hello{extensions = HelloExtensions}) ->      HelloExtensions#hello_extensions.sni; @@ -160,7 +219,7 @@ select_sni_extension(_) ->  %%--------------------------------------------------------------------  -spec start_link(atom(), host(), inet:port_number(), port(), list(), pid(), tuple()) -> -    {ok, pid()} | ignore |  {error, reason()}. +			{ok, pid()} | ignore |  {error, reason()}.  %%  %% Description: Creates a gen_fsm process which calls Module:init/1 to  %% initialize. To ensure a synchronized start-up procedure, this function @@ -191,7 +250,6 @@ init({call, From}, {start, Timeout},       #state{host = Host, port = Port, role = client,  	    ssl_options = SslOpts,  	    session = #session{own_certificate = Cert} = Session0, -	    transport_cb = Transport, socket = Socket,  	    connection_states = ConnectionStates0,  	    renegotiation = {Renegotiation, _},  	    session_cache = Cache, @@ -199,23 +257,26 @@ init({call, From}, {start, Timeout},  	   } = State0) ->      Timer = ssl_connection:start_or_recv_cancel_timer(Timeout, From),      Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, -				       Cache, CacheCb, Renegotiation, Cert), -     +					Cache, CacheCb, Renegotiation, Cert), +      Version = Hello#client_hello.client_version,      HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions), -    Handshake0 = ssl_handshake:init_handshake_history(), -    {BinMsg, ConnectionStates, Handshake} = -        encode_handshake(Hello,  HelloVersion, ConnectionStates0, Handshake0), -    Transport:send(Socket, BinMsg), -    State1 = State0#state{connection_states = ConnectionStates, -			  negotiated_version = Version, %% Requested version +    State1 = prepare_flight(State0#state{negotiated_version = Version}), +    State2 = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),   +    State3 = State2#state{negotiated_version = Version, %% Requested version  			  session =  			      Session0#session{session_id = Hello#client_hello.session_id}, -			  tls_handshake_history = Handshake,  			  start_or_recv_from = From,  			  timer = Timer}, -    {Record, State} = next_record(State1), +    {Record, State} = next_record(State3),      next_event(hello, Record, State); +init({call, _} = Type, Event, #state{role = server, transport_cb = gen_udp} = State) -> +    ssl_connection:init(Type, Event,  +			State#state{flight_state = {waiting, undefined, ?INITIAL_RETRANSMIT_TIMEOUT}}, +			?MODULE); +init({call, _} = Type, Event, #state{role = server} = State) -> +    %% I.E. DTLS over sctp +    ssl_connection:init(Type, Event, State#state{flight_state = reliable}, ?MODULE);  init(Type, Event, State) ->      ssl_connection:init(Type, Event, State, ?MODULE). @@ -232,34 +293,53 @@ error(_, _, _) ->  	    #state{}) ->  		   gen_statem:state_function_result().  %%-------------------------------------------------------------------- -hello(internal, #client_hello{client_version = ClientVersion} = Hello, -      State = #state{connection_states = ConnectionStates0, -		     port = Port, session = #session{own_certificate = Cert} = Session0, -		     renegotiation = {Renegotiation, _}, -		     session_cache = Cache, -		     session_cache_cb = CacheCb, -		     negotiated_protocol = CurrentProtocol, -		     key_algorithm = KeyExAlg, -		     ssl_options = SslOpts}) -> - -    case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb, -					      ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of -	#alert{} = Alert -> -	   ssl_connection:handle_own_alert(Alert, ClientVersion, hello, State); -	{Version, {Type, Session}, -	 ConnectionStates, Protocol0, ServerHelloExt, HashSign} -> -	    Protocol = case Protocol0 of -			   undefined -> CurrentProtocol; -			   _ -> Protocol0 -		       end, - -	    ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt}, -				 State#state{connection_states	= ConnectionStates, -					     negotiated_version = Version, -					     hashsign_algorithm = HashSign, -					     session = Session, -					     negotiated_protocol = Protocol}, ?MODULE) +hello(internal, #client_hello{cookie = <<>>, +			      client_version = Version} = Hello, #state{role = server, +									transport_cb = Transport, +									socket = Socket} = State0) -> +    %% TODO: not hard code key +    {ok, {IP, Port}} = dtls_socket:peername(Transport, Socket), +    Cookie = dtls_handshake:cookie(<<"secret">>, IP, Port, Hello), +    VerifyRequest = dtls_handshake:hello_verify_request(Cookie, Version), +    State1 = prepare_flight(State0#state{negotiated_version = Version}), +    State2 = send_handshake(VerifyRequest, State1), +    {Record, State} = next_record(State2), +    next_event(hello, Record, State#state{tls_handshake_history = ssl_handshake:init_handshake_history()}); +hello(internal, #client_hello{cookie = Cookie} = Hello, #state{role = server, +							       transport_cb = Transport, +							       socket = Socket} = State0) -> +    {ok, {IP, Port}} = dtls_socket:peername(Transport, Socket), +    %% TODO: not hard code key +    case dtls_handshake:cookie(<<"secret">>, IP, Port, Hello) of +	Cookie -> +	    handle_client_hello(Hello, State0); +	_ -> +	    %% Handle bad cookie as new cookie request RFC 6347 4.1.2 +	    hello(internal, Hello#client_hello{cookie = <<>>}, State0)       end; +hello(internal, #hello_verify_request{cookie = Cookie}, #state{role = client, +							       host = Host, port = Port,  +							       ssl_options = SslOpts, +							       session = #session{own_certificate = OwnCert}  +							       = Session0, +							       connection_states = ConnectionStates0, +							       renegotiation = {Renegotiation, _}, +							       session_cache = Cache, +							       session_cache_cb = CacheCb +							      } = State0) -> +    State1 = prepare_flight(State0#state{tls_handshake_history = ssl_handshake:init_handshake_history()}), +    Hello = dtls_handshake:client_hello(Host, Port, Cookie, ConnectionStates0, +					SslOpts, +					Cache, CacheCb, Renegotiation, OwnCert), +    Version = Hello#client_hello.client_version, +    HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions), +    State2 = send_handshake(Hello, State1#state{negotiated_version = HelloVersion}),  +    State3 = State2#state{negotiated_version = Version, %% Requested version +			  session = +			      Session0#session{session_id =  +						   Hello#client_hello.session_id}}, +    {Record, State} = next_record(State3), +    next_event(hello, Record, State);  hello(internal, #server_hello{} = Hello,        #state{connection_states = ConnectionStates0,  	     negotiated_version = ReqVersion, @@ -273,24 +353,49 @@ hello(internal, #server_hello{} = Hello,  	    ssl_connection:handle_session(Hello,   					  Version, NewId, ConnectionStates, ProtoExt, Protocol, State)      end; +hello(internal, {handshake, {#client_hello{cookie = <<>>} = Handshake, _}}, State) -> +    %% Initial hello should not be in handshake history +    {next_state, hello, State, [{next_event, internal, Handshake}]}; + +hello(internal, {handshake, {#hello_verify_request{} = Handshake, _}}, State) -> +    %% hello_verify should not be in handshake history +    {next_state, hello, State, [{next_event, internal, Handshake}]}; +  hello(info, Event, State) ->      handle_info(Event, hello, State); -  hello(Type, Event, State) ->      ssl_connection:hello(Type, Event, State, ?MODULE).  abbreviated(info, Event, State) ->      handle_info(Event, abbreviated, State); +abbreviated(internal = Type,  +	    #change_cipher_spec{type = <<1>>} = Event,  +	    #state{connection_states = ConnectionStates0} = State) -> +    ConnectionStates1 = dtls_record:save_current_connection_state(ConnectionStates0, read), +    ConnectionStates = dtls_record:next_epoch(ConnectionStates1, read), +    ssl_connection:abbreviated(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE); +abbreviated(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) -> +    ssl_connection:cipher(Type, Event, prepare_flight(State#state{connection_states = ConnectionStates}), ?MODULE);  abbreviated(Type, Event, State) ->      ssl_connection:abbreviated(Type, Event, State, ?MODULE).  certify(info, Event, State) ->      handle_info(Event, certify, State); +certify(internal = Type, #server_hello_done{} = Event, State) -> +    ssl_connection:certify(Type, Event, prepare_flight(State), ?MODULE);  certify(Type, Event, State) ->      ssl_connection:certify(Type, Event, State, ?MODULE).  cipher(info, Event, State) ->      handle_info(Event, cipher, State); +cipher(internal = Type, #change_cipher_spec{type = <<1>>} = Event,   +       #state{connection_states = ConnectionStates0} = State) -> +    ConnectionStates1 = dtls_record:save_current_connection_state(ConnectionStates0, read), +    ConnectionStates = dtls_record:next_epoch(ConnectionStates1, read), +    ssl_connection:cipher(Type, Event, State#state{connection_states = ConnectionStates}, ?MODULE); +cipher(internal = Type, #finished{} = Event, #state{connection_states = ConnectionStates} = State) -> +    ssl_connection:cipher(Type, Event,  +			  prepare_flight(State#state{connection_states = ConnectionStates}), ?MODULE);  cipher(Type, Event, State) ->       ssl_connection:cipher(Type, Event, State, ?MODULE). @@ -310,7 +415,6 @@ connection(internal, #hello_request{}, #state{host = Host, port = Port,  	  State1#state{session = Session0#session{session_id  						  = Hello#client_hello.session_id}}),      next_event(hello, Record, State); -  connection(internal, #client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) ->      %% Mitigate Computational DoS attack      %% http://www.educatedguesswork.org/2011/10/ssltls_and_computational_dos.html @@ -319,14 +423,11 @@ connection(internal, #client_hello{} = Hello, #state{role = server, allow_renego      %% renegotiations immediately after each other.      erlang:send_after(?WAIT_TO_ALLOW_RENEGOTIATION, self(), allow_renegotiate),      {next_state, hello, State#state{allow_renegotiate = false}, [{next_event, internal, Hello}]}; - -  connection(internal, #client_hello{}, #state{role = server, allow_renegotiate = false} = State0) ->      Alert = ?ALERT_REC(?WARNING, ?NO_RENEGOTIATION),      State1 = send_alert(Alert, State0),      {Record, State} = ssl_connection:prepare_connection(State1, ?MODULE),      next_event(connection, Record, State); -    connection(Type, Event, State) ->       ssl_connection:connection(Type, Event, State, ?MODULE). @@ -341,15 +442,25 @@ downgrade(Type, Event, State) ->  %%--------------------------------------------------------------------  %% raw data from socket, unpack records -handle_info({Protocol, _, Data}, StateName, +handle_info({_,flight_retransmission_timeout}, connection, _) -> +    {next_state, keep_state_and_data}; +handle_info({Ref, flight_retransmission_timeout}, StateName,  +	    #state{flight_state = {waiting, Ref, NextTimeout}} = State0) -> +    State1 = send_handshake_flight(State0#state{flight_state = {retransmit_timer, NextTimeout}},  +				   retransmit_epoch(StateName, State0)), +    {Record, State} = next_record(State1), +    next_event(StateName, Record, State); +handle_info({_, flight_retransmission_timeout}, _, _) -> +    {next_state, keep_state_and_data}; +handle_info({Protocol, _, _, _, Data}, StateName,              #state{data_tag = Protocol} = State0) -> -     case next_tls_record(Data, State0) of +    case next_dtls_record(Data, State0) of  	{Record, State} ->  	    next_event(StateName, Record, State);  	#alert{} = Alert ->  	    ssl_connection:handle_normal_shutdown(Alert, StateName, State0),   	    {stop, {shutdown, own_alert}} -     end; +    end;  handle_info({CloseTag, Socket}, StateName,  	    #state{socket = Socket, close_tag = CloseTag,  		   negotiated_version = Version} = State) -> @@ -380,23 +491,26 @@ handle_common_event(internal, #alert{} = Alert, StateName,      ssl_connection:handle_own_alert(Alert, Version, StateName, State);  %%% DTLS record protocol level handshake messages  -handle_common_event(internal,  #ssl_tls{type = ?HANDSHAKE} = Record,  +handle_common_event(internal, #ssl_tls{type = ?HANDSHAKE, +				       fragment = Data},   		    StateName,  -		    #state{protocol_buffers = -			       #protocol_buffers{dtls_packets = Packets0, -						 dtls_fragment_state = HsState0} = Buffers, +		    #state{protocol_buffers = Buffers0,  			   negotiated_version = Version} = State0) ->      try -	{Packets1, HsState} = dtls_handshake:get_dtls_handshake(Record, HsState0), -	State = -	    State0#state{protocol_buffers = -			     Buffers#protocol_buffers{dtls_fragment_state = HsState}}, -	Events = dtls_handshake_events(Packets0 ++ Packets1), -	case StateName of -	    connection -> -		ssl_connection:hibernate_after(StateName, State, Events); -	    _ -> -		{next_state, StateName, State, Events} +	case dtls_handshake:get_dtls_handshake(Version, Data, Buffers0) of +	    {more_data, Buffers} -> +		{Record, State} = next_record(State0#state{protocol_buffers = Buffers}), +		next_event(StateName, Record, State); +	    {Packets, Buffers} -> +		State = State0#state{protocol_buffers = Buffers}, +		Events = dtls_handshake_events(Packets), +		case StateName of +		    connection -> +			ssl_connection:hibernate_after(StateName, State, Events); +		    _ -> +			{next_state, StateName,  +			 State#state{unprocessed_handshake_events = unprocessed_events(Events)}, Events} +		end  	end      catch throw:#alert{} = Alert ->  	    ssl_connection:handle_own_alert(Alert, Version, StateName, State0) @@ -420,6 +534,10 @@ handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, Sta  handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) ->      {next_state, StateName, State}. +send(Transport, {_, {{_,_}, _} = Socket}, Data) -> +    send(Transport, Socket, Data); +send(Transport, Socket, Data) -> +   dtls_socket:send(Transport, Socket, Data).  %%--------------------------------------------------------------------  %% Description:This function is called by a gen_fsm when it is about  %% to terminate. It should be the opposite of Module:init/1 and do any @@ -442,96 +560,56 @@ format_status(Type, Data) ->  %%--------------------------------------------------------------------  %%% Internal functions  %%-------------------------------------------------------------------- +handle_client_hello(#client_hello{client_version = ClientVersion} = Hello, +		    #state{connection_states = ConnectionStates0, +			   port = Port, session = #session{own_certificate = Cert} = Session0, +			   renegotiation = {Renegotiation, _}, +			   session_cache = Cache, +			   session_cache_cb = CacheCb, +			   negotiated_protocol = CurrentProtocol, +			   key_algorithm = KeyExAlg, +			   ssl_options = SslOpts} = State0) -> +     +    case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb, +					       ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of +	#alert{} = Alert -> +	    ssl_connection:handle_own_alert(Alert, ClientVersion, hello, State0); +	{Version, {Type, Session}, +	 ConnectionStates, Protocol0, ServerHelloExt, HashSign} -> +	    Protocol = case Protocol0 of +			   undefined -> CurrentProtocol; +			   _ -> Protocol0 +		       end, -dtls_handshake_events([]) -> -    throw(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, malformed_handshake)); -dtls_handshake_events(Packets) -> -    lists:map(fun(Packet) -> -		      {next_event, internal, {handshake, Packet}} -	      end, Packets). - - -encode_handshake(Handshake, Version, ConnectionStates0, Hist0) -> -    {Seq, ConnectionStates} = sequence(ConnectionStates0), -    {EncHandshake, Frag} = dtls_handshake:encode_handshake(Handshake, Version, Seq), -    %% DTLS does not have an equivalent version to SSLv2. So v2 hello compatibility -    %% will always be false -    Hist = ssl_handshake:update_handshake_history(Hist0, EncHandshake, false), -    {Frag, ConnectionStates, Hist}. - -encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> -    dtls_record:encode_change_cipher_spec(Version, ConnectionStates). +	    State = prepare_flight(State0#state{connection_states = ConnectionStates, +						negotiated_version = Version, +						hashsign_algorithm = HashSign, +						session = Session, +						negotiated_protocol = Protocol}), +	     +	    ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt}, +				 State, ?MODULE) +    end. -encode_handshake_flight(Flight, ConnectionStates) -> -    MSS = 1400, -    encode_handshake_records(Flight, ConnectionStates, MSS, init_pack_records()). +encode_handshake_flight(Flight, Version, MaxFragmentSize, Epoch, ConnectionStates) -> +    Fragments = lists:map(fun(Handshake) -> +				  dtls_handshake:fragment_handshake(Handshake, MaxFragmentSize) +			  end, Flight), +    dtls_record:encode_handshake(Fragments, Version, Epoch, ConnectionStates). -encode_handshake_records([], CS, _MSS, Recs) -> -    {finish_pack_records(Recs), CS}; +encode_change_cipher(#change_cipher_spec{}, Version, Epoch, ConnectionStates) -> +    dtls_record:encode_change_cipher_spec(Version, Epoch, ConnectionStates). -encode_handshake_records([{Version, _Epoch, Frag = #change_cipher_spec{}}|Tail], ConnectionStates0, MSS, Recs0) -> -    {Encoded, ConnectionStates} = -        encode_change_cipher(Frag, Version, ConnectionStates0), -    Recs = append_pack_records([Encoded], MSS, Recs0), -    encode_handshake_records(Tail, ConnectionStates, MSS, Recs); - -encode_handshake_records([{Version, Epoch, {MsgType, MsgSeq, Bin}}|Tail], CS0, MSS, Recs0 = {Buf0, _}) -> -    Space = MSS - iolist_size(Buf0), -    Len = byte_size(Bin), -    {Encoded, CS} = -	encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, 0, MSS, [], CS0), -    Recs = append_pack_records(Encoded, MSS, Recs0), -    encode_handshake_records(Tail, CS, MSS, Recs). - -%% TODO: move to dtls_handshake???? -encode_handshake_record(_Version, _Epoch, _Space, _MsgType, _MsgSeq, _Len, <<>>, _Offset, _MRS, Encoded, CS) -  when length(Encoded) > 0 -> -    %% make sure we encode at least one segment (for empty messages like Server Hello Done -    {lists:reverse(Encoded), CS}; - -encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, -			Offset, MRS, Encoded0, CS0) -> -    MaxFragmentLen = Space - 25, -    {BinFragment, Rest} = -	case Bin of -	<<BinFragment0:MaxFragmentLen/bytes, Rest0/binary>> -> -	    {BinFragment0, Rest0}; -	_ -> -	    {Bin, <<>>} -    end, -    FragLength = byte_size(BinFragment), -    Frag = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(Offset), ?uint24(FragLength), BinFragment], -    %% TODO Real solution, now avoid dialyzer error {Encoded, CS} = ssl_record:encode_handshake({Epoch, Frag}, Version, CS0), -    {Encoded, CS} = ssl_record:encode_handshake(Frag, Version, CS0), -    encode_handshake_record(Version, Epoch, MRS, MsgType, MsgSeq, Len, Rest, Offset + FragLength, MRS, [Encoded|Encoded0], CS). - -init_pack_records() -> -    {[], []}. - -append_pack_records([], MSS, Recs = {Buf0, Acc0}) -> -    Remaining = MSS - iolist_size(Buf0), -    if Remaining < 12 -> -	    {[], [lists:reverse(Buf0)|Acc0]}; -       true -> -	    Recs -    end; -append_pack_records([Head|Tail], MSS, {Buf0, Acc0}) -> -    TotLen = iolist_size(Buf0) + iolist_size(Head), -    if TotLen > MSS -> -	    append_pack_records(Tail, MSS, {[Head], [lists:reverse(Buf0)|Acc0]}); -       true -> -	    append_pack_records(Tail, MSS, {[Head|Buf0], Acc0}) -    end. +encode_data(Data, Version, ConnectionStates0)-> +    dtls_record:encode_data(Data, Version, ConnectionStates0). -finish_pack_records({[], Acc}) -> -    lists:reverse(Acc); -finish_pack_records({Buf, Acc}) -> -    lists:reverse([lists:reverse(Buf)|Acc]). +encode_alert(#alert{} = Alert, Version, ConnectionStates) -> +    dtls_record:encode_alert_record(Alert, Version, ConnectionStates).  decode_alerts(Bin) ->      ssl_alert:decode(Bin). -initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User, +initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, _}, User,  	      {CbModule, DataTag, CloseTag, ErrorTag}) ->      #ssl_options{beast_mitigation = BeastMitigation} = SSLOptions,      ConnectionStates = dtls_record:init_connection_states(Role, BeastMitigation), @@ -566,10 +644,11 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User,  	   renegotiation = {false, first},  	   allow_renegotiate = SSLOptions#ssl_options.client_renegotiation,  	   start_or_recv_from = undefined, -	   protocol_cb = ?MODULE +	   protocol_cb = ?MODULE, +	   flight_buffer = new_flight()  	  }. -next_tls_record(Data, #state{protocol_buffers = #protocol_buffers{ +next_dtls_record(Data, #state{protocol_buffers = #protocol_buffers{  						   dtls_record_buffer = Buf0,  						   dtls_cipher_texts = CT0} = Buffers} = State0) ->      case dtls_record:get_dtls_records(Data, Buf0) of @@ -578,14 +657,15 @@ next_tls_record(Data, #state{protocol_buffers = #protocol_buffers{  	    next_record(State0#state{protocol_buffers =  					 Buffers#protocol_buffers{dtls_record_buffer = Buf1,  								  dtls_cipher_texts = CT1}}); -  	#alert{} = Alert ->  	    Alert -   end. +    end. -next_record(#state{%%flight = #flight{state = finished},  -		   protocol_buffers = -		       #protocol_buffers{dtls_packets = [], dtls_cipher_texts = [CT | Rest]} +next_record(#state{unprocessed_handshake_events = N} = State) when N > 0 -> +    {no_record, State#state{unprocessed_handshake_events = N-1}}; +					  +next_record(#state{protocol_buffers = +		       #protocol_buffers{dtls_cipher_texts = [CT | Rest]}  		   = Buffers,  		   connection_states = ConnStates0} = State) ->      case dtls_record:decode_cipher_text(CT, ConnStates0) of @@ -596,9 +676,15 @@ next_record(#state{%%flight = #flight{state = finished},  	#alert{} = Alert ->  	    {Alert, State}      end; -next_record(#state{socket = Socket, -		   transport_cb = Transport} = State) -> %% when FlightState =/= finished -    ssl_socket:setopts(Transport, Socket, [{active,once}]), +next_record(#state{role = server, +		   socket = {Listener, {Client, _}}, +		   transport_cb = gen_udp} = State) ->  +    dtls_udp_listener:active_once(Listener, Client, self()), +    {no_record, State}; +next_record(#state{role = client, +		   socket = {_Server, Socket}, +		   transport_cb = Transport} = State) ->  +    dtls_socket:setopts(Transport, Socket, [{active,once}]),      {no_record, State};  next_record(State) ->      {no_record, State}. @@ -624,75 +710,89 @@ passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName) ->  next_event(StateName, Record, State) ->      next_event(StateName, Record, State, []). -next_event(connection = StateName, no_record, State0, Actions) -> +next_event(connection = StateName, no_record, +	   #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State0, Actions) ->      case next_record_if_active(State0) of  	{no_record, State} ->  	    ssl_connection:hibernate_after(StateName, State, Actions); -	{#ssl_tls{} = Record, State} -> +	{#ssl_tls{epoch = CurrentEpoch} = Record, State} ->  	    {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]}; +	{#ssl_tls{epoch = Epoch, +		  type = ?HANDSHAKE, +		  version = _Version}, State1} = _Record when Epoch == CurrentEpoch-1 -> +	    State = send_handshake_flight(State1, Epoch), +	    {next_state, StateName, State, Actions}; +	{#ssl_tls{epoch = _Epoch, +		  version = _Version}, State} -> +	    %% TODO maybe buffer later epoch +	    {next_state, StateName, State, Actions};  	{#alert{} = Alert, State} ->  	    {next_state, StateName, State, [{next_event, internal, Alert} | Actions]}      end; -next_event(StateName, Record, State, Actions) -> +next_event(StateName, Record,  +	   #state{connection_states = #{current_read := #{epoch := CurrentEpoch}}} = State, Actions) ->      case Record of  	no_record ->  	    {next_state, StateName, State, Actions}; -	#ssl_tls{} = Record -> -	    {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]}; +	#ssl_tls{epoch = CurrentEpoch, +		  version = Version} = Record -> +	    {next_state, StateName,  +	     dtls_version(StateName, Version, State),  +	     [{next_event, internal, {protocol_record, Record}} | Actions]}; +	#ssl_tls{epoch = _Epoch, +		 version = _Version} = _Record -> +	    %% TODO maybe buffer later epoch +	    {next_state, StateName, State, Actions};  	#alert{} = Alert ->  	    {next_state, StateName, State, [{next_event, internal, Alert} | Actions]}      end. -%% TODO This generates dialyzer warnings, has to be handled differently. -%% handle_packet(Address, Port, Packet) -> -%%     try dtls_record:get_dtls_records(Packet, <<>>) of -%% 	%% expect client hello -%% 	{[#ssl_tls{type = ?HANDSHAKE, version = {254, _}} = Record], <<>>} -> -%% 	    handle_dtls_client_hello(Address, Port, Record); -%% 	_Other -> -%% 	    {error, not_dtls} -%%     catch -%% 	_Class:_Error -> -%% 	    {error, not_dtls} -%%     end. - -%% handle_dtls_client_hello(Address, Port, -%% 			 #ssl_tls{epoch = Epoch, sequence_number = Seq, -%% 				  version = Version} = Record) -> -%%     {[{Hello, _}], _} = -%% 	dtls_handshake:get_dtls_handshake(Record, -%% 					 dtls_handshake:dtls_handshake_new_flight(undefined)), -%%     #client_hello{client_version = {Major, Minor}, -%% 		  random = Random, -%% 		  session_id = SessionId, -%% 		  cipher_suites = CipherSuites, -%% 		  compression_methods = CompressionMethods} = Hello, -%%     CookieData = [address_to_bin(Address, Port), -%% 		  <<?BYTE(Major), ?BYTE(Minor)>>, -%% 		  Random, SessionId, CipherSuites, CompressionMethods], -%%     Cookie = crypto:hmac(sha, <<"secret">>, CookieData), - -%%     case Hello of -%% 	#client_hello{cookie = Cookie} -> -%% 	    accept; - -%% 	_ -> -%% 	    %% generate HelloVerifyRequest -%% 	    {RequestFragment, _} = dtls_handshake:encode_handshake( -%% 				     dtls_handshake:hello_verify_request(Cookie), -%% 				     Version, 0), -%% 	    HelloVerifyRequest = -%% 		dtls_record:encode_tls_cipher_text(?HANDSHAKE, Version, Epoch, Seq, RequestFragment), -%% 	    {reply, HelloVerifyRequest} -%%     end. - -%% address_to_bin({A,B,C,D}, Port) -> -%%     <<0:80,16#ffff:16,A,B,C,D,Port:16>>; -%% address_to_bin({A,B,C,D,E,F,G,H}, Port) -> -%%     <<A:16,B:16,C:16,D:16,E:16,F:16,G:16,H:16,Port:16>>. - -sequence(#{write_msg_seq := Seq} = ConnectionState) -> -    {Seq, ConnectionState#{write_msg_seq => Seq + 1}}. +dtls_version(hello, Version, #state{role = server} = State) -> +    State#state{negotiated_version = Version}; %%Inital version +dtls_version(_,_, State) -> +    State. + +prepare_flight(#state{flight_buffer = Flight, +		      connection_states = ConnectionStates0, +		      protocol_buffers =  +			  #protocol_buffers{} = Buffers} = State) -> +    ConnectionStates = dtls_record:save_current_connection_state(ConnectionStates0, write), +    State#state{flight_buffer = next_flight(Flight), +		connection_states = ConnectionStates, +		protocol_buffers = Buffers#protocol_buffers{ +				     dtls_handshake_next_fragments = [], +				     dtls_handshake_later_fragments = []}}. +new_flight() -> +    #{next_sequence => 0, +      handshakes => [], +      change_cipher_spec => undefined, +      handshakes_after_change_cipher_spec => []}. + +next_flight(Flight) -> +    Flight#{handshakes => [], +	    change_cipher_spec => undefined, +	    handshakes_after_change_cipher_spec => []}. + +	 +start_flight(#state{transport_cb = gen_udp, +		    flight_state = {retransmit_timer, Timeout}} = State) -> +    Ref = erlang:make_ref(), +    _ = erlang:send_after(Timeout, self(), {Ref, flight_retransmission_timeout}),     +    State#state{flight_state = {waiting, Ref, new_timeout(Timeout)}}; + +start_flight(State) -> +    %% No retransmision needed i.e DTLS over SCTP +    State#state{flight_state = reliable}. + +new_timeout(N) when N =< 30 ->  +    N * 2; +new_timeout(_) ->  +    60. + +dtls_handshake_events(Packets) -> +    lists:map(fun(Packet) -> +		      {next_event, internal, {handshake, Packet}} +	      end, Packets).  renegotiate(#state{role = client} = State, Actions) ->      %% Handle same way as if server requested @@ -722,3 +822,27 @@ handle_alerts([Alert | Alerts], {next_state, StateName, State}) ->       handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State));  handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) ->       handle_alerts(Alerts, ssl_connection:handle_alert(Alert, StateName, State)). + +retransmit_epoch(StateName, #state{connection_states = ConnectionStates}) -> +    #{epoch := Epoch} =  +	ssl_record:current_connection_state(ConnectionStates, write), +    case StateName of +	connection -> +	    Epoch-1; +	_ -> +	    Epoch +    end.	 +	     +update_handshake_history(#hello_verify_request{}, _, Hist) -> +    Hist; +update_handshake_history(_, Handshake, Hist) -> +    %% DTLS never needs option "v2_hello_compatible" to be true +    ssl_handshake:update_handshake_history(Hist, iolist_to_binary(Handshake), false). + +unprocessed_events(Events) -> +    %% The first handshake event will be processed immediately +    %% as it is entered first in the event queue and +    %% when it is processed there will be length(Events)-1 +    %% handshake events left to process before we should +    %% process more TLS-records received on the socket.  +    erlang:length(Events)-1. diff --git a/lib/ssl/src/dtls_connection.hrl b/lib/ssl/src/dtls_connection.hrl index ee3daa3c14..3dd78235d0 100644 --- a/lib/ssl/src/dtls_connection.hrl +++ b/lib/ssl/src/dtls_connection.hrl @@ -29,20 +29,14 @@  -include("ssl_connection.hrl").  -record(protocol_buffers, { -	  dtls_packets = [],              %%::[binary()],  % Not yet handled decode ssl/tls packets. -          dtls_record_buffer = <<>>,      %%:: binary(),   % Buffer of incomplete records -	  dtls_fragment_state,            %%:: [],         % DTLS fragments -          dtls_handshake_buffer = <<>>,   %%:: binary(),   % Buffer of incomplete handshakes -	  dtls_cipher_texts = [],         %%:: [binary()], -	  dtls_cipher_texts_next          %%:: [binary()]  % Received for Epoch not yet active +          dtls_record_buffer = <<>>,      %% Buffer of incomplete records +	  dtls_handshake_next_seq = 0, +	  dtls_flight_last, +	  dtls_handshake_next_fragments = [], %% Fragments of the next handshake message +	  dtls_handshake_later_fragments = [], %% Fragments of handsake messages come after the one in next buffer +	  dtls_cipher_texts = []         %%:: [binary()],  	 }). --record(flight, { -	  last_retransmit, -	  last_read_seq, -	  msl_timer, -	  state, -	  buffer        % buffer of not yet ACKed TLS records -	 }). +-define(INITIAL_RETRANSMIT_TIMEOUT, 1000). %1 sec  -endif. % -ifdef(dtls_connection). diff --git a/lib/ssl/src/dtls_connection_sup.erl b/lib/ssl/src/dtls_connection_sup.erl index dc7601a684..7d7be5743d 100644 --- a/lib/ssl/src/dtls_connection_sup.erl +++ b/lib/ssl/src/dtls_connection_sup.erl @@ -60,7 +60,7 @@ init(_O) ->      StartFunc = {dtls_connection, start_link, []},      Restart = temporary, % E.g. should not be restarted      Shutdown = 4000, -    Modules = [dtls_connection], +    Modules = [dtls_connection, ssl_connection],      Type = worker,      ChildSpec = {Name, StartFunc, Restart, Shutdown, Type, Modules}, diff --git a/lib/ssl/src/dtls_handshake.erl b/lib/ssl/src/dtls_handshake.erl index c6535d5928..af3708ddb7 100644 --- a/lib/ssl/src/dtls_handshake.erl +++ b/lib/ssl/src/dtls_handshake.erl @@ -18,15 +18,15 @@  %% %CopyrightEnd%  -module(dtls_handshake). +-include("dtls_connection.hrl").  -include("dtls_handshake.hrl").  -include("dtls_record.hrl").  -include("ssl_internal.hrl").  -include("ssl_alert.hrl"). --export([client_hello/8, client_hello/9, hello/4, -	 hello_verify_request/1, get_dtls_handshake/2, -	 dtls_handshake_new_flight/1, dtls_handshake_new_epoch/1, -	 encode_handshake/3]). +-export([client_hello/8, client_hello/9, cookie/4, hello/4,  +	 hello_verify_request/2, get_dtls_handshake/3, fragment_handshake/2, +	 handshake_bin/2, encode_handshake/3]).  -type dtls_handshake() :: #client_hello{} | #hello_verify_request{} |   			  ssl_handshake:ssl_handshake(). @@ -62,9 +62,10 @@ client_hello(Host, Port, Cookie, ConnectionStates,      Version =  dtls_record:highest_protocol_version(Versions),      Pending = ssl_record:pending_connection_state(ConnectionStates, read),      SecParams = maps:get(security_parameters, Pending), -    CipherSuites = ssl_handshake:available_suites(UserSuites, Version), +    TLSVersion = dtls_v1:corresponding_tls_version(Version), +    CipherSuites = ssl_handshake:available_suites(UserSuites, TLSVersion), -    Extensions = ssl_handshake:client_hello_extensions(Host, dtls_v1:corresponding_tls_version(Version), CipherSuites, +    Extensions = ssl_handshake:client_hello_extensions(Host, TLSVersion, CipherSuites,  						SslOpts, ConnectionStates, Renegotiation),      Id = ssl_session:client_id({Host, Port, SslOpts}, Cache, CacheCb, OwnCert), @@ -96,72 +97,51 @@ hello(#client_hello{client_version = ClientVersion} = Hello,        #ssl_options{versions = Versions} = SslOpts,        Info, Renegotiation) ->      Version = ssl_handshake:select_version(dtls_record, ClientVersion, Versions), -    %% -    %% TODO: handle Cipher Fallback -    %%      handle_client_hello(Version, Hello, SslOpts, Info, Renegotiation). --spec hello_verify_request(binary()) -> #hello_verify_request{}. +cookie(Key, Address, Port, #client_hello{client_version = {Major, Minor}, +					 random = Random, +					 session_id = SessionId, +					 cipher_suites = CipherSuites, +					 compression_methods = CompressionMethods}) -> +    CookieData = [address_to_bin(Address, Port), +		  <<?BYTE(Major), ?BYTE(Minor)>>, +		  Random, SessionId, CipherSuites, CompressionMethods], +    crypto:hmac(sha, Key, CookieData). + +-spec hello_verify_request(binary(),  dtls_record:dtls_version()) -> #hello_verify_request{}.  %%  %% Description: Creates a hello verify request message sent by server to  %% verify client  %%-------------------------------------------------------------------- -hello_verify_request(Cookie) -> -    %% TODO: DTLS Versions????? -    #hello_verify_request{protocol_version = {254, 255}, cookie = Cookie}. +hello_verify_request(Cookie, Version) -> +    #hello_verify_request{protocol_version = Version, cookie = Cookie}.  %%-------------------------------------------------------------------- -%% %%-------------------------------------------------------------------- -encode_handshake(Handshake, Version, MsgSeq) -> +encode_handshake(Handshake, Version, Seq) ->      {MsgType, Bin} = enc_handshake(Handshake, Version),      Len = byte_size(Bin), -    Enc = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(0), ?uint24(Len), Bin], -    Frag = {MsgType, MsgSeq, Bin}, -    {Enc, Frag}. +    [MsgType, ?uint24(Len), ?uint16(Seq), ?uint24(0), ?uint24(Len), Bin]. -%%-------------------------------------------------------------------- --spec get_dtls_handshake(#ssl_tls{}, #dtls_hs_state{} | undefined) -> -				{[dtls_handshake()], #dtls_hs_state{}} | {retransmit, #dtls_hs_state{}}. -%% -%% Description: Given a DTLS state and new data from ssl_record, collects -%% and returns it as a list of handshake messages, also returns a new -%% DTLS state -%%-------------------------------------------------------------------- -get_dtls_handshake(Records, undefined) -> -    HsState = #dtls_hs_state{highest_record_seq = 0, -			     starting_read_seq = 0, -			     fragments = gb_trees:empty(), -			     completed = []}, -    get_dtls_handshake(Records, HsState); -get_dtls_handshake(Records, HsState0) when is_list(Records) -> -    HsState1 = lists:foldr(fun get_dtls_handshake_aux/2, HsState0, Records), -    get_dtls_handshake_completed(HsState1); -get_dtls_handshake(Record, HsState0) when is_record(Record, ssl_tls) -> -    HsState1 = get_dtls_handshake_aux(Record, HsState0), -    get_dtls_handshake_completed(HsState1). +fragment_handshake(Bin, _) when is_binary(Bin)->  +    %% This is the change_cipher_spec not a "real handshake" but part of the flight +    Bin; +fragment_handshake([MsgType, Len, Seq, _, Len, Bin], Size) -> +    Bins = bin_fragments(Bin, Size), +    handshake_fragments(MsgType, Seq, Len, Bins, []). +handshake_bin([Type, Length, Data], Seq) ->	 +    handshake_bin(Type, Length, Seq, Data). +      %%-------------------------------------------------------------------- --spec dtls_handshake_new_epoch(#dtls_hs_state{}) -> #dtls_hs_state{}. +-spec get_dtls_handshake(dtls_record:dtls_version(), binary(), #protocol_buffers{}) -> +     {[{dtls_handshake(), binary()}], #protocol_buffers{}} | {more_data, #protocol_buffers{}}.  %% -%% Description: Reset the DTLS decoder state for a new Epoch +%% Description: ...  %%-------------------------------------------------------------------- -%% dtls_handshake_new_epoch(<<>>) -> -%%     dtls_hs_state_init(); -dtls_handshake_new_epoch(HsState) -> -    HsState#dtls_hs_state{highest_record_seq = 0, -			  starting_read_seq = HsState#dtls_hs_state.current_read_seq, -			  fragments = gb_trees:empty(), completed = []}. - -%-------------------------------------------------------------------- --spec dtls_handshake_new_flight(integer() | undefined) -> #dtls_hs_state{}. -% -% Description: Init the DTLS decoder state for a new Flight -dtls_handshake_new_flight(ExpectedReadReq) -> -    #dtls_hs_state{current_read_seq = ExpectedReadReq, -		   highest_record_seq = 0, -		   starting_read_seq = 0, -		   fragments = gb_trees:empty(), completed = []}. +get_dtls_handshake(Version, Fragment, ProtocolBuffers) -> +    handle_fragments(Version, Fragment, ProtocolBuffers, []).  %%--------------------------------------------------------------------  %%% Internal functions @@ -170,27 +150,29 @@ handle_client_hello(Version, #client_hello{session_id = SugesstedId,  					   cipher_suites = CipherSuites,  					   compression_methods = Compressions,  					   random = Random, -					   extensions = #hello_extensions{elliptic_curves = Curves, -									  signature_algs = ClientHashSigns} = HelloExt}, +					   extensions = +					       #hello_extensions{elliptic_curves = Curves, +								 signature_algs = ClientHashSigns} = HelloExt},  		    #ssl_options{versions = Versions,  				 signature_algs = SupportedHashSigns} = SslOpts,  		    {Port, Session0, Cache, CacheCb, ConnectionStates0, Cert, _}, Renegotiation) ->      case dtls_record:is_acceptable_version(Version, Versions) of  	true -> +	    TLSVersion = dtls_v1:corresponding_tls_version(Version),  	    AvailableHashSigns = ssl_handshake:available_signature_algs( -				   ClientHashSigns, SupportedHashSigns, Cert, -				   dtls_v1:corresponding_tls_version(Version)), -	    ECCCurve = ssl_handshake:select_curve(Curves, ssl_handshake:supported_ecc(Version)), +				   ClientHashSigns, SupportedHashSigns, Cert,TLSVersion), +	    ECCCurve = ssl_handshake:select_curve(Curves, ssl_handshake:supported_ecc(TLSVersion)),  	    {Type, #session{cipher_suite = CipherSuite} = Session1}  		= ssl_handshake:select_session(SugesstedId, CipherSuites, AvailableHashSigns, Compressions, -					       Port, Session0#session{ecc = ECCCurve}, Version, +					       Port, Session0#session{ecc = ECCCurve}, TLSVersion,  					       SslOpts, Cache, CacheCb, Cert),  	    case CipherSuite of  		no_suite ->  		    ?ALERT_REC(?FATAL, ?INSUFFICIENT_SECURITY);  		_ ->  		    {KeyExAlg,_,_,_} = ssl_cipher:suite_definition(CipherSuite), -		    case ssl_handshake:select_hashsign(ClientHashSigns, Cert, KeyExAlg, SupportedHashSigns, Version) of +		    case ssl_handshake:select_hashsign(ClientHashSigns, Cert, KeyExAlg,  +						       SupportedHashSigns, TLSVersion) of  			#alert{} = Alert ->  			    Alert;  			HashSign -> @@ -228,214 +210,15 @@ handle_server_hello_extensions(Version, SessionId, Random, CipherSuite,  	    {Version, SessionId, ConnectionStates, ProtoExt, Protocol}      end. -get_dtls_handshake_completed(HsState = #dtls_hs_state{completed = Completed}) -> -    {lists:reverse(Completed), HsState#dtls_hs_state{completed = []}}. - -get_dtls_handshake_aux(#ssl_tls{version = Version, - 				sequence_number = SeqNo, - 				fragment = Data}, HsState) -> -    get_dtls_handshake_aux(Version, SeqNo, Data, HsState). - -get_dtls_handshake_aux(Version, SeqNo, - 		       <<?BYTE(Type), ?UINT24(Length), - 			 ?UINT16(MessageSeq), - 			 ?UINT24(FragmentOffset), ?UINT24(FragmentLength), - 			 Body:FragmentLength/binary, Rest/binary>>, - 		       HsState0) -> -    case reassemble_dtls_fragment(SeqNo, Type, Length, MessageSeq, -				  FragmentOffset, FragmentLength, -				  Body, HsState0) of - 	{HsState1, HighestSeqNo, MsgBody} -> - 	    HsState2 = dec_dtls_fragment(Version, HighestSeqNo, Type, Length, MessageSeq, MsgBody, HsState1), - 	    HsState3 = process_dtls_fragments(Version, HsState2), - 	    get_dtls_handshake_aux(Version, SeqNo, Rest, HsState3); - - 	HsState2 -> - 	    HsState3 = process_dtls_fragments(Version, HsState2), - 	    get_dtls_handshake_aux(Version, SeqNo, Rest, HsState3) -     end; - -get_dtls_handshake_aux(_Version, _SeqNo, <<>>, HsState) -> -    HsState. - -dec_dtls_fragment(Version, SeqNo, Type, Length, MessageSeq, MsgBody, - 		  HsState = #dtls_hs_state{highest_record_seq = HighestSeqNo, completed = Acc}) -> -    Raw = <<?BYTE(Type), ?UINT24(Length), ?UINT16(MessageSeq), ?UINT24(0), ?UINT24(Length), MsgBody/binary>>, -    H = decode_handshake(Version, Type, MsgBody), -    HsState#dtls_hs_state{completed = [{H,Raw}|Acc], highest_record_seq = erlang:max(HighestSeqNo, SeqNo)}. - -process_dtls_fragments(Version, - 		       HsState0 = #dtls_hs_state{current_read_seq = CurrentReadSeq, - 						 fragments = Fragments0}) -> -    case gb_trees:is_empty(Fragments0) of - 	true -> - 	    HsState0; - 	_ -> - 	    case gb_trees:smallest(Fragments0) of - 		{CurrentReadSeq, {SeqNo, Type, Length, CurrentReadSeq, {Length, [{0, Length}], MsgBody}}} -> - 		    HsState1 = dtls_hs_state_process_seq(HsState0), - 		    HsState2 = dec_dtls_fragment(Version, SeqNo, Type, Length, CurrentReadSeq, MsgBody, HsState1), - 		    process_dtls_fragments(Version, HsState2); - 		_ -> - 		    HsState0 - 	    end -     end. - -dtls_hs_state_process_seq(HsState0 = #dtls_hs_state{current_read_seq = CurrentReadSeq, - 						    fragments = Fragments0}) -> -    Fragments1 = gb_trees:delete_any(CurrentReadSeq, Fragments0), -    HsState0#dtls_hs_state{current_read_seq = CurrentReadSeq + 1, - 			   fragments = Fragments1}. - -dtls_hs_state_add_fragment(MessageSeq, Fragment, HsState0 = #dtls_hs_state{fragments = Fragments0}) -> -    Fragments1 = gb_trees:enter(MessageSeq, Fragment, Fragments0), -    HsState0#dtls_hs_state{fragments = Fragments1}. - -reassemble_dtls_fragment(SeqNo, Type, Length, MessageSeq, 0, Length, - 			 Body, HsState0 = #dtls_hs_state{current_read_seq = undefined}) -  when Type == ?CLIENT_HELLO; -       Type == ?SERVER_HELLO; -        Type == ?HELLO_VERIFY_REQUEST -> -    %% First message, should be client hello -    %% return the current message and set the next expected Sequence -    %% -    %% Note: this could (should?) be restricted further, ClientHello and -    %%       HelloVerifyRequest have to have message_seq = 0, ServerHello -    %%       can have a message_seq of 0 or 1 -    %% -    {HsState0#dtls_hs_state{current_read_seq = MessageSeq + 1}, SeqNo, Body}; - -reassemble_dtls_fragment(_SeqNo, _Type, Length, _MessageSeq, _, Length, -			 _Body, HsState = #dtls_hs_state{current_read_seq = undefined}) -> -    %% not what we expected, drop it -    HsState; - -reassemble_dtls_fragment(SeqNo, _Type, Length, MessageSeq, 0, Length, - 			 Body, HsState0 = - 			     #dtls_hs_state{starting_read_seq = StartingReadSeq}) -  when MessageSeq < StartingReadSeq -> -    %% this has to be the start of a new flight, let it through -    %% -    %% Note: this could (should?) be restricted further, the first message of a -    %%       new flight has to have message_seq = 0 -    %% -    HsState = dtls_hs_state_process_seq(HsState0), -    {HsState, SeqNo, Body}; - -reassemble_dtls_fragment(_SeqNo, _Type, Length, MessageSeq, 0, Length, - 			 _Body, HsState = #dtls_hs_state{current_read_seq = CurrentReadSeq}) -  when MessageSeq < CurrentReadSeq -> -    HsState; - -reassemble_dtls_fragment(SeqNo, _Type, Length, MessageSeq, 0, Length, - 			 Body, HsState0 = #dtls_hs_state{current_read_seq = MessageSeq}) -> -    %% Message fully contained and it's the current seq -    HsState1 = dtls_hs_state_process_seq(HsState0), -    {HsState1, SeqNo, Body}; - -reassemble_dtls_fragment(SeqNo, Type, Length, MessageSeq, 0, Length, - 			 Body, HsState) -> -    %% Message fully contained and it's the NOT the current seq -> buffer -    Fragment = {SeqNo, Type, Length, MessageSeq, - 		dtls_fragment_init(Length, 0, Length, Body)}, -    dtls_hs_state_add_fragment(MessageSeq, Fragment, HsState); - -reassemble_dtls_fragment(_SeqNo, _Type, Length, MessageSeq, FragmentOffset, FragmentLength, - 			 _Body, - 			 HsState = #dtls_hs_state{current_read_seq = CurrentReadSeq}) -  when FragmentOffset + FragmentLength == Length andalso MessageSeq == (CurrentReadSeq - 1) -> -    {retransmit, HsState}; - -reassemble_dtls_fragment(_SeqNo, _Type, _Length, MessageSeq, _FragmentOffset, _FragmentLength, - 			 _Body, - 			 HsState = #dtls_hs_state{current_read_seq = CurrentReadSeq}) -  when MessageSeq < CurrentReadSeq -> -    HsState; - -reassemble_dtls_fragment(SeqNo, Type, Length, MessageSeq, - 			 FragmentOffset, FragmentLength, - 			 Body, - 			 HsState = #dtls_hs_state{fragments = Fragments0}) -> -    case gb_trees:lookup(MessageSeq, Fragments0) of - 	{value, Fragment} -> - 	    dtls_fragment_reassemble(SeqNo, Type, Length, MessageSeq, - 				     FragmentOffset, FragmentLength, - 				     Body, Fragment, HsState); - 	none -> - 	    dtls_fragment_start(SeqNo, Type, Length, MessageSeq, - 				FragmentOffset, FragmentLength, - 				Body, HsState) -    end. -dtls_fragment_start(SeqNo, Type, Length, MessageSeq, -		    FragmentOffset, FragmentLength, -		    Body, HsState = #dtls_hs_state{fragments = Fragments0}) -> -    Fragment = {SeqNo, Type, Length, MessageSeq, - 		dtls_fragment_init(Length, FragmentOffset, FragmentLength, Body)}, -     Fragments1 = gb_trees:insert(MessageSeq, Fragment, Fragments0), -    HsState#dtls_hs_state{fragments = Fragments1}. - -dtls_fragment_reassemble(SeqNo, Type, Length, MessageSeq, -			 FragmentOffset, FragmentLength, - 			 Body, - 			 {LastSeqNo, Type, Length, MessageSeq, FragBuffer0}, - 			 HsState = #dtls_hs_state{fragments = Fragments0}) -> -    FragBuffer1 = dtls_fragment_add(FragBuffer0, FragmentOffset, FragmentLength, Body), -    Fragment = {erlang:max(SeqNo, LastSeqNo), Type, Length, MessageSeq, FragBuffer1}, -    Fragments1 = gb_trees:enter(MessageSeq, Fragment, Fragments0), -    HsState#dtls_hs_state{fragments = Fragments1}; - -%% Type, Length or Seq mismatch, drop everything... -%% Note: the RFC is not clear on how to handle this... -dtls_fragment_reassemble(_SeqNo, _Type, _Length, MessageSeq, - 			 _FragmentOffset, _FragmentLength, _Body, _Fragment, - 			 HsState = #dtls_hs_state{fragments = Fragments0}) -> -    Fragments1 = gb_trees:delete_any(MessageSeq, Fragments0), -    HsState#dtls_hs_state{fragments = Fragments1}. - -dtls_fragment_add({Length, FragmentList0, Bin0}, FragmentOffset, FragmentLength, Body) -> -    Bin1 = dtls_fragment_bin_add(FragmentOffset, FragmentLength, Body, Bin0), -    FragmentList1 = add_fragment(FragmentList0, {FragmentOffset, FragmentLength}), -    {Length, FragmentList1, Bin1}. - -dtls_fragment_init(Length, 0, Length, Body) -> -    {Length, [{0, Length}], Body}; -dtls_fragment_init(Length, FragmentOffset, FragmentLength, Body) -> -    Bin = dtls_fragment_bin_add(FragmentOffset, FragmentLength, Body, <<0:(Length*8)>>), -    {Length, [{FragmentOffset, FragmentOffset + FragmentLength}], Bin}. - -dtls_fragment_bin_add(FragmentOffset, FragmentLength, Add, Buffer) -> -    <<First:FragmentOffset/bytes, _:FragmentLength/bytes, Rest/binary>> = Buffer, -    <<First/binary, Add/binary, Rest/binary>>. - -merge_fragment_list([], Fragment, Acc) -> -    lists:reverse([Fragment|Acc]); - -merge_fragment_list([H = {_, HEnd}|Rest], Frag = {FStart, _}, Acc) -  when FStart > HEnd -> -    merge_fragment_list(Rest, Frag, [H|Acc]); - -merge_fragment_list(Rest = [{HStart, _HEnd}|_], Frag = {_FStart, FEnd}, Acc) -  when FEnd < HStart -> -    lists:reverse(Acc) ++ [Frag|Rest]; - -merge_fragment_list([{HStart, HEnd}|Rest], _Frag = {FStart, FEnd}, Acc) -   when -      FStart =< HEnd orelse FEnd >= HStart -> -    Start = erlang:min(HStart, FStart), -    End = erlang:max(HEnd, FEnd), -    NewFrag = {Start, End}, -    merge_fragment_list(Rest, NewFrag, Acc). - -add_fragment(List, {FragmentOffset, FragmentLength}) -> -    merge_fragment_list(List, {FragmentOffset, FragmentOffset + FragmentLength}, []). +%%%%%%%  Encodeing   %%%%%%%%%%%%%  enc_handshake(#hello_verify_request{protocol_version = {Major, Minor},   				       cookie = Cookie}, _Version) -> -     CookieLength = byte_size(Cookie), +    CookieLength = byte_size(Cookie),      {?HELLO_VERIFY_REQUEST, <<?BYTE(Major), ?BYTE(Minor),   			      ?BYTE(CookieLength), - 			      Cookie/binary>>}; + 			      Cookie:CookieLength/binary>>};  enc_handshake(#hello_request{}, _Version) ->      {?HELLO_REQUEST, <<>>}; @@ -459,38 +242,246 @@ enc_handshake(#client_hello{client_version = {Major, Minor},  		      ?BYTE(CookieLength), Cookie/binary,  		      ?UINT16(CsLength), BinCipherSuites/binary,   		      ?BYTE(CmLength), BinCompMethods/binary, ExtensionsBin/binary>>}; + +enc_handshake(#server_hello{} = HandshakeMsg, Version) -> +    {Type, <<?BYTE(Major), ?BYTE(Minor), Rest/binary>>} =  +	ssl_handshake:encode_handshake(HandshakeMsg, Version), +    {DTLSMajor, DTLSMinor} = dtls_v1:corresponding_dtls_version({Major, Minor}), +    {Type,  <<?BYTE(DTLSMajor), ?BYTE(DTLSMinor), Rest/binary>>}; +  enc_handshake(HandshakeMsg, Version) ->      ssl_handshake:encode_handshake(HandshakeMsg, Version). -decode_handshake(_Version, ?CLIENT_HELLO, <<?BYTE(Major), ?BYTE(Minor), Random:32/binary, +bin_fragments(Bin, Size) -> +     bin_fragments(Bin, size(Bin), Size, 0, []). + +bin_fragments(Bin, BinSize,  FragSize, Offset, Fragments) -> +    case (BinSize - Offset - FragSize)  > 0 of +	true -> +	    Frag = binary:part(Bin, {Offset, FragSize}), +	    bin_fragments(Bin, BinSize, FragSize, Offset + FragSize, [{Frag, Offset} | Fragments]); +	false -> +	    Frag = binary:part(Bin, {Offset, BinSize-Offset}), +	    lists:reverse([{Frag, Offset} | Fragments]) +    end. + +handshake_fragments(_, _, _, [], Acc) -> +    lists:reverse(Acc); +handshake_fragments(MsgType, Seq, Len, [{Bin, Offset} | Bins], Acc) -> +    FragLen = size(Bin), +    handshake_fragments(MsgType, Seq, Len, Bins,  +      [<<?BYTE(MsgType), Len/binary, Seq/binary, ?UINT24(Offset), +	 ?UINT24(FragLen), Bin/binary>> | Acc]). + +address_to_bin({A,B,C,D}, Port) -> +    <<0:80,16#ffff:16,A,B,C,D,Port:16>>; +address_to_bin({A,B,C,D,E,F,G,H}, Port) -> +    <<A:16,B:16,C:16,D:16,E:16,F:16,G:16,H:16,Port:16>>. + +%%%%%%%  Decodeing   %%%%%%%%%%%%% + +handle_fragments(Version, FragmentData, Buffers0, Acc) -> +    Fragments = decode_handshake_fragments(FragmentData), +    do_handle_fragments(Version, Fragments, Buffers0, Acc). + +do_handle_fragments(_, [], Buffers, Acc) -> +    {lists:reverse(Acc), Buffers}; +do_handle_fragments(Version, [Fragment | Fragments], Buffers0, Acc) -> +    case reassemble(Version, Fragment, Buffers0) of +	{more_data, _} = More when Acc == []-> +	    More; +	{more_data, Buffers} when Fragments == [] -> +	    {lists:reverse(Acc), Buffers}; +	{more_data, Buffers} -> +	    do_handle_fragments(Version, Fragments, Buffers, Acc); +	{HsPacket, Buffers} -> +	    do_handle_fragments(Version, Fragments, Buffers, [HsPacket | Acc]) +    end. + +decode_handshake(Version, <<?BYTE(Type), Bin/binary>>) -> +    decode_handshake(Version, Type, Bin). + +decode_handshake(_, ?HELLO_REQUEST, <<>>) -> +    #hello_request{}; +decode_handshake(_Version, ?CLIENT_HELLO, <<?UINT24(_), ?UINT16(_), +					    ?UINT24(_),  ?UINT24(_),  +					    ?BYTE(Major), ?BYTE(Minor), Random:32/binary,  					    ?BYTE(SID_length), Session_ID:SID_length/binary, -					    ?BYTE(Cookie_length), Cookie:Cookie_length/binary, +					    ?BYTE(CookieLength), Cookie:CookieLength/binary,  					    ?UINT16(Cs_length), CipherSuites:Cs_length/binary,  					    ?BYTE(Cm_length), Comp_methods:Cm_length/binary,  					    Extensions/binary>>) -> -    DecodedExtensions = ssl_handshake:decode_hello_extensions(Extensions), -     +    DecodedExtensions = ssl_handshake:decode_hello_extensions({client, Extensions}), +      #client_hello{         client_version = {Major,Minor},         random = Random, -        session_id = Session_ID,         cookie = Cookie, +       session_id = Session_ID,         cipher_suites = ssl_handshake:decode_suites('2_bytes', CipherSuites),         compression_methods = Comp_methods,         extensions = DecodedExtensions -       }; +      }; + +decode_handshake(_Version, ?HELLO_VERIFY_REQUEST, <<?UINT24(_), ?UINT16(_), +						    ?UINT24(_),  ?UINT24(_), +						    ?BYTE(Major), ?BYTE(Minor), +						    ?BYTE(CookieLength), +						    Cookie:CookieLength/binary>>) -> +    #hello_verify_request{protocol_version = {Major, Minor}, +			  cookie = Cookie}; + +decode_handshake(Version, Tag,  <<?UINT24(_), ?UINT16(_), +				  ?UINT24(_),  ?UINT24(_), Msg/binary>>) ->  +    %% DTLS specifics stripped +    decode_tls_thandshake(Version, Tag, Msg). + +decode_tls_thandshake(Version, Tag, Msg) -> +    TLSVersion = dtls_v1:corresponding_tls_version(Version), +    ssl_handshake:decode_handshake(TLSVersion, Tag, Msg). + +decode_handshake_fragments(<<>>) -> +    [<<>>]; +decode_handshake_fragments(<<?BYTE(Type), ?UINT24(Length), +			     ?UINT16(MessageSeq), +			     ?UINT24(FragmentOffset), ?UINT24(FragmentLength), +			    Fragment:FragmentLength/binary, Rest/binary>>) -> +    [#handshake_fragment{type = Type,  +			length = Length, +			message_seq = MessageSeq, +			fragment_offset = FragmentOffset, +			fragment_length = FragmentLength, +			fragment = Fragment} | decode_handshake_fragments(Rest)]. + +reassemble(Version,  #handshake_fragment{message_seq = Seq} = Fragment,  +	   #protocol_buffers{dtls_handshake_next_seq = Seq, +			     dtls_handshake_next_fragments = Fragments0, +			     dtls_handshake_later_fragments = LaterFragments0} =  +	       Buffers0)->  +    case reassemble_fragments(Fragment, Fragments0) of +	{more_data, Fragments} -> +	    {more_data,  Buffers0#protocol_buffers{dtls_handshake_next_fragments = Fragments}}; +	{raw, RawHandshake} -> +	    Handshake = decode_handshake(Version, RawHandshake), +	    {NextFragments, LaterFragments} = next_fragments(LaterFragments0), +	    {{Handshake, RawHandshake}, Buffers0#protocol_buffers{dtls_handshake_next_seq = Seq + 1, +						  dtls_handshake_next_fragments = NextFragments, +						  dtls_handshake_later_fragments = LaterFragments}} +    end; +reassemble(_,  #handshake_fragment{message_seq = FragSeq} = Fragment,  +	   #protocol_buffers{dtls_handshake_next_seq = Seq, +			     dtls_handshake_later_fragments = LaterFragments} = Buffers0) when FragSeq > Seq->  +     {more_data, +      Buffers0#protocol_buffers{dtls_handshake_later_fragments = [Fragment | LaterFragments]}}; +reassemble(_, _, Buffers) ->  +    %% Disregard fragments FragSeq < Seq +    {more_data, Buffers}. + +reassemble_fragments(Current, Fragments0) -> +    [Frag1 | Frags] = lists:keysort(#handshake_fragment.fragment_offset, [Current | Fragments0]), +    [Fragment | _] = Fragments = merge_fragment(Frag1, Frags), +    case is_complete_handshake(Fragment) of +	true -> +	    {raw, handshake_bin(Fragment)}; +	false -> +	    {more_data, Fragments} +    end. -decode_handshake(_Version, ?HELLO_VERIFY_REQUEST, <<?BYTE(Major), ?BYTE(Minor), -						    ?BYTE(CookieLength), Cookie:CookieLength/binary>>) -> +merge_fragment(Frag0, []) -> +    [Frag0]; +merge_fragment(Frag0, [Frag1 | Rest]) -> +    case merge_fragments(Frag0, Frag1) of +	[_|_] = Frags -> +	    Frags ++ Rest; +	Frag -> +	    merge_fragment(Frag, Rest) +    end. -    #hello_verify_request{ -       protocol_version = {Major,Minor}, -       cookie = Cookie}; -decode_handshake(Version, Tag, Msg) -> -    ssl_handshake:decode_handshake(Version, Tag, Msg). +is_complete_handshake(#handshake_fragment{length = Length, fragment_length = Length}) -> +    true; +is_complete_handshake(_) -> +    false. + +next_fragments(LaterFragments) -> +    case lists:keysort(#handshake_fragment.message_seq, LaterFragments) of +	[] -> +	    {[], []};  +	[#handshake_fragment{message_seq = Seq} | _] = Fragments -> +	    split_frags(Fragments, Seq, []) +    end. -%% address_to_bin({A,B,C,D}, Port) -> -%%     <<0:80,16#ffff:16,A,B,C,D,Port:16>>; -%% address_to_bin({A,B,C,D,E,F,G,H}, Port) -> -%%     <<A:16,B:16,C:16,D:16,E:16,F:16,G:16,H:16,Port:16>>. +split_frags([#handshake_fragment{message_seq = Seq} = Frag | Rest], Seq, Acc) -> +    split_frags(Rest, Seq, [Frag | Acc]); +split_frags(Frags, _, Acc) -> +    {lists:reverse(Acc), Frags}. + + +%% Duplicate +merge_fragments(#handshake_fragment{ +		   fragment_offset = PreviousOffSet,  +		   fragment_length = PreviousLen, +		   fragment = PreviousData +		  } = Previous,  +		#handshake_fragment{ +		   fragment_offset = PreviousOffSet, +		   fragment_length = PreviousLen, +		   fragment = PreviousData}) -> +    Previous; + +%% Lager fragment save new data +merge_fragments(#handshake_fragment{ +		   fragment_offset = PreviousOffSet,  +		   fragment_length = PreviousLen, +		   fragment = PreviousData +		  } = Previous,  +		#handshake_fragment{ +		   fragment_offset = PreviousOffSet, +		   fragment_length = CurrentLen, +		   fragment = CurrentData}) when CurrentLen > PreviousLen -> +    NewLength = CurrentLen - PreviousLen, +    <<_:PreviousLen/binary, NewData/binary>> = CurrentData,  +    Previous#handshake_fragment{ +      fragment_length = PreviousLen + NewLength, +      fragment = <<PreviousData/binary, NewData/binary>> +     }; + +%% Smaller fragment +merge_fragments(#handshake_fragment{ +		   fragment_offset = PreviousOffSet,  +		   fragment_length = PreviousLen +		  } = Previous,  +		#handshake_fragment{ +		   fragment_offset = PreviousOffSet, +		   fragment_length = CurrentLen}) when CurrentLen < PreviousLen -> +    Previous; +%% Next fragment +merge_fragments(#handshake_fragment{ +		   fragment_offset = PreviousOffSet,  +		   fragment_length = PreviousLen, +		   fragment = PreviousData +		  } = Previous,  +		#handshake_fragment{ +		   fragment_offset = CurrentOffSet, +		   fragment_length = CurrentLen, +		   fragment = CurrentData}) when PreviousOffSet + PreviousLen == CurrentOffSet-> +	    Previous#handshake_fragment{ +	      fragment_length =  PreviousLen + CurrentLen, +	      fragment = <<PreviousData/binary, CurrentData/binary>>}; +%% No merge there is a gap +merge_fragments(Previous, Current) -> +    [Previous, Current]. +	     +handshake_bin(#handshake_fragment{ +		 type = Type, +		 length = Len,  +		 message_seq = Seq, +		 fragment_length = Len, +		 fragment_offset = 0, +		 fragment = Fragment}) ->	     +    handshake_bin(Type, Len, Seq, Fragment). + +handshake_bin(Type, Length, Seq, FragmentData) ->  +    <<?BYTE(Type), ?UINT24(Length), +      ?UINT16(Seq), ?UINT24(0), ?UINT24(Length), +      FragmentData:Length/binary>>.   diff --git a/lib/ssl/src/dtls_handshake.hrl b/lib/ssl/src/dtls_handshake.hrl index 0298fd3105..0a980c5f31 100644 --- a/lib/ssl/src/dtls_handshake.hrl +++ b/lib/ssl/src/dtls_handshake.hrl @@ -46,12 +46,13 @@  	  cookie  	 }). --record(dtls_hs_state, -	{current_read_seq, -	 starting_read_seq, -	 highest_record_seq, -	 fragments, -	 completed -	}). +-record(handshake_fragment, { +	  type, +	  length, +	  message_seq,                +	  fragment_offset,            +	  fragment_length, +	  fragment +	 }).  -endif. % -ifdef(dtls_handshake). diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl index 8a6e2d315c..2b42ddf9b9 100644 --- a/lib/ssl/src/dtls_record.erl +++ b/lib/ssl/src/dtls_record.erl @@ -36,7 +36,9 @@  -export([decode_cipher_text/2]).  %% Encoding --export([encode_plain_text/4, encode_tls_cipher_text/5, encode_change_cipher_spec/2]). +-export([encode_handshake/4, encode_alert_record/3, +	 encode_change_cipher_spec/3, encode_data/3]). +-export([encode_plain_text/5]).  %% Protocol version handling  -export([protocol_version/1, lowest_protocol_version/1, lowest_protocol_version/2, @@ -44,9 +46,9 @@  	 is_higher/2, supported_protocol_versions/0,  	 is_acceptable_version/2]). -%% DTLS Epoch handling --export([init_connection_state_seq/2, current_connection_state_epoch/2, -	 set_connection_state_by_epoch/3, connection_state_by_epoch/3]). +-export([save_current_connection_state/2, next_epoch/2]). + +-export([init_connection_state_seq/2, current_connection_state_epoch/2]).  -export_type([dtls_version/0, dtls_atom_version/0]). @@ -68,16 +70,64 @@  %%--------------------------------------------------------------------  init_connection_states(Role, BeastMitigation) ->      ConnectionEnd = ssl_record:record_protocol_role(Role), -    Current = initial_connection_state(ConnectionEnd, BeastMitigation), -    Pending = ssl_record:empty_connection_state(ConnectionEnd, BeastMitigation), -    #{write_msg_seq => 0,  -      prvious_read  => undefined, +    Initial = initial_connection_state(ConnectionEnd, BeastMitigation), +    Current = Initial#{epoch := 0}, +    InitialPending = ssl_record:empty_connection_state(ConnectionEnd, BeastMitigation), +    Pending = InitialPending#{epoch => undefined}, +    #{saved_read  => Current,        current_read  => Current,        pending_read  => Pending, -      prvious_write => undefined, +      saved_write => Current,        current_write => Current,        pending_write => Pending}. -   + +%%-------------------------------------------------------------------- +-spec save_current_connection_state(ssl_record:connection_states(), read | write) -> +				      ssl_record:connection_states(). +%% +%% Description: Returns the instance of the connection_state map +%% where the current read|write state has been copied to the save state. +%%-------------------------------------------------------------------- +save_current_connection_state(#{current_read := Current} = States, read) -> +    States#{saved_read := Current}; + +save_current_connection_state(#{current_write := Current} = States, write) -> +    States#{saved_write := Current}. + +next_epoch(#{pending_read := Pending, +	     current_read := #{epoch := Epoch}} = States, read) -> +    States#{pending_read := Pending#{epoch := Epoch + 1}}; + +next_epoch(#{pending_write := Pending, +	     current_write := #{epoch := Epoch}} = States, write) -> +    States#{pending_write := Pending#{epoch := Epoch + 1}}. + +get_connection_state_by_epoch(Epoch, #{current_write := #{epoch := Epoch} = Current}, +			      write) -> +    Current; +get_connection_state_by_epoch(Epoch, #{saved_write := #{epoch := Epoch} = Saved}, +			      write) -> +    Saved; +get_connection_state_by_epoch(Epoch, #{current_read := #{epoch := Epoch} = Current}, +			      read) -> +    Current; +get_connection_state_by_epoch(Epoch, #{saved_read := #{epoch := Epoch} = Saved}, +			      read) -> +    Saved. + +set_connection_state_by_epoch(WriteState, Epoch, #{current_write := #{epoch := Epoch}} = States, +			      write) -> +    States#{current_write := WriteState}; +set_connection_state_by_epoch(WriteState, Epoch, #{saved_write := #{epoch := Epoch}} = States, +			      write) -> +    States#{saved_write := WriteState}; +set_connection_state_by_epoch(ReadState, Epoch, #{current_read := #{epoch := Epoch}} = States, +			      read) -> +    States#{current_read := ReadState}; +set_connection_state_by_epoch(ReadState, Epoch, #{saved_read := #{epoch := Epoch}} = States, +			      read) -> +    States#{saved_read := ReadState}. +  %%--------------------------------------------------------------------  -spec get_dtls_records(binary(), binary()) -> {[binary()], binary()} | #alert{}.  %% @@ -140,98 +190,57 @@ get_dtls_records_aux(Data, Acc) ->  	    ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE)      end. -encode_plain_text(Type, Version, Data, -		  #{current_write := -			#{epoch := Epoch, -			  sequence_number := Seq, -			  compression_state := CompS0, -			  security_parameters := -			      #security_parameters{ -				 cipher_type = ?AEAD, -				 compression_algorithm = CompAlg} -			 }= WriteState0} = ConnectionStates) -> -    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), -    WriteState1 = WriteState0#{compression_state => CompS1}, -    AAD = calc_aad(Type, Version, Epoch, Seq), -    {CipherFragment, WriteState} = ssl_record:cipher_aead(dtls_v1:corresponding_tls_version(Version), -							  Comp, WriteState1, AAD), -    CipherText = encode_tls_cipher_text(Type, Version, Epoch, Seq, CipherFragment), -    {CipherText, ConnectionStates#{current_write => WriteState#{sequence_number => Seq +1}}}; - -encode_plain_text(Type, Version, Data, -		  #{current_write :=  -			#{epoch := Epoch, -			  sequence_number := Seq, -			  compression_state := CompS0, -			  security_parameters := -			      #security_parameters{compression_algorithm = CompAlg} -			 }= WriteState0} = ConnectionStates) -> -    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), -    WriteState1 = WriteState0#{compression_state => CompS1}, -    MacHash = calc_mac_hash(WriteState1, Type, Version, Epoch, Seq, Comp), -    {CipherFragment, WriteState} = ssl_record:cipher(dtls_v1:corresponding_tls_version(Version),  -						     Comp, WriteState1, MacHash), -    CipherText = encode_tls_cipher_text(Type, Version, Epoch, Seq, CipherFragment), -    {CipherText, ConnectionStates#{current_write => WriteState#{sequence_number => Seq +1}}}. +%%-------------------------------------------------------------------- +-spec encode_handshake(iolist(), dtls_version(), integer(), ssl_record:connection_states()) -> +			      {iolist(), ssl_record:connection_states()}. +% +%% Description: Encodes a handshake message to send on the ssl-socket. +%%-------------------------------------------------------------------- +encode_handshake(Frag, Version, Epoch, ConnectionStates) -> +    encode_plain_text(?HANDSHAKE, Version, Epoch, Frag, ConnectionStates). -decode_cipher_text(#ssl_tls{type = Type, version = Version, -			    epoch = Epoch, -			    sequence_number = Seq, -			    fragment = CipherFragment} = CipherText, -		   #{current_read := -			 #{compression_state := CompressionS0, -			   security_parameters := -			       #security_parameters{ -				  cipher_type = ?AEAD, -				  compression_algorithm = CompAlg} -			  } = ReadState0} = ConnnectionStates0) -> -    AAD = calc_aad(Type, Version, Epoch, Seq), -    case ssl_record:decipher_aead(dtls_v1:corresponding_tls_version(Version), -				  CipherFragment, ReadState0, AAD) of -	{PlainFragment, ReadState1} -> -	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, -							   PlainFragment, CompressionS0), -	    ConnnectionStates = ConnnectionStates0#{ -				  current_read => ReadState1#{ -						    compression_state => CompressionS1}}, -	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; -	#alert{} = Alert -> -	    Alert -    end; -decode_cipher_text(#ssl_tls{type = Type, version = Version, -			    epoch = Epoch, -			    sequence_number = Seq, -			    fragment = CipherFragment} = CipherText, -		   #{current_read := -			 #{compression_state := CompressionS0, -			   security_parameters := -			       #security_parameters{ -				  compression_algorithm = CompAlg} -			  } = ReadState0}= ConnnectionStates0) -> -    {PlainFragment, Mac, ReadState1} = ssl_record:decipher(dtls_v1:corresponding_tls_version(Version), -							   CipherFragment, ReadState0, true), -    MacHash = calc_mac_hash(ReadState1, Type, Version, Epoch, Seq, PlainFragment), -    case ssl_record:is_correct_mac(Mac, MacHash) of -	true -> -	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, -							   PlainFragment, CompressionS0), -	    ConnnectionStates = ConnnectionStates0#{ -				  current_read => ReadState1#{ -						    compression_state => CompressionS1}}, -	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; -	false -> -	    ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) -    end. +%%-------------------------------------------------------------------- +-spec encode_alert_record(#alert{}, dtls_version(), ssl_record:connection_states()) -> +				 {iolist(), ssl_record:connection_states()}. +%% +%% Description: Encodes an alert message to send on the ssl-socket. +%%-------------------------------------------------------------------- +encode_alert_record(#alert{level = Level, description = Description}, +                    Version, ConnectionStates) -> +    #{epoch := Epoch} = ssl_record:current_connection_state(ConnectionStates, write), +    encode_plain_text(?ALERT, Version, Epoch, <<?BYTE(Level), ?BYTE(Description)>>, +		      ConnectionStates).  %%-------------------------------------------------------------------- --spec encode_change_cipher_spec(dtls_version(), ssl_record:connection_states()) -> +-spec encode_change_cipher_spec(dtls_version(), integer(), ssl_record:connection_states()) ->  				       {iolist(), ssl_record:connection_states()}.  %%  %% Description: Encodes a change_cipher_spec-message to send on the ssl socket.  %%-------------------------------------------------------------------- -encode_change_cipher_spec(Version, ConnectionStates) -> -    encode_plain_text(?CHANGE_CIPHER_SPEC, Version, <<1:8>>, ConnectionStates). +encode_change_cipher_spec(Version, Epoch, ConnectionStates) -> +    encode_plain_text(?CHANGE_CIPHER_SPEC, Version, Epoch, ?byte(?CHANGE_CIPHER_SPEC_PROTO), ConnectionStates). + +%%-------------------------------------------------------------------- +-spec encode_data(binary(), dtls_version(), ssl_record:connection_states()) -> +			 {iolist(),ssl_record:connection_states()}. +%% +%% Description: Encodes data to send on the ssl-socket. +%%-------------------------------------------------------------------- +encode_data(Data, Version, ConnectionStates) -> +    #{epoch := Epoch} = ssl_record:current_connection_state(ConnectionStates, write), +    encode_plain_text(?APPLICATION_DATA, Version, Epoch, Data, ConnectionStates). + +encode_plain_text(Type, Version, Epoch, Data, ConnectionStates) -> +    Write0 = get_connection_state_by_epoch(Epoch, ConnectionStates, write), +    {CipherFragment, Write1} = encode_plain_text(Type, Version, Data, Write0), +    {CipherText, Write} = encode_dtls_cipher_text(Type, Version, CipherFragment, Write1), +    {CipherText, set_connection_state_by_epoch(Write, Epoch, ConnectionStates, write)}. + + +decode_cipher_text(#ssl_tls{epoch = Epoch} = CipherText, ConnnectionStates0) -> +    ReadState = get_connection_state_by_epoch(Epoch, ConnnectionStates0, read), +    decode_cipher_text(CipherText, ReadState, ConnnectionStates0).  %%--------------------------------------------------------------------  -spec protocol_version(dtls_atom_version() | dtls_version()) -> @@ -373,12 +382,11 @@ is_acceptable_version(Version, Versions) ->  %% This is only valid for DTLS in the first client_hello  %%--------------------------------------------------------------------  init_connection_state_seq({254, _}, -			  #{current_read := #{epoch := 0} = Read, -			    current_write := #{epoch := 0} = Write} = CS0) -> -    Seq = maps:get(sequence_number, Read), -    CS0#{current_write => Write#{sequence_number => Seq}}; -init_connection_state_seq(_, CS) -> -    CS. +			  #{current_read := #{epoch := 0, sequence_number := Seq}, +			    current_write := #{epoch := 0} = Write} = ConnnectionStates0) -> +    ConnnectionStates0#{current_write => Write#{sequence_number => Seq}}; +init_connection_state_seq(_, ConnnectionStates) -> +    ConnnectionStates.  %%--------------------------------------------------------  -spec current_connection_state_epoch(ssl_record:connection_states(), read | write) -> @@ -387,49 +395,12 @@ init_connection_state_seq(_, CS) ->  %% Description: Returns the epoch the connection_state record  %% that is currently defined as the current conection state.  %%-------------------------------------------------------------------- -current_connection_state_epoch(#{current_read := Current}, +current_connection_state_epoch(#{current_read := #{epoch := Epoch}},  			       read) -> -    maps:get(epoch, Current); -current_connection_state_epoch(#{current_write := Current}, +    Epoch; +current_connection_state_epoch(#{current_write := #{epoch := Epoch}},  			       write) -> -    maps:get(epoch, Current). - -%%-------------------------------------------------------------------- - --spec connection_state_by_epoch(ssl_record:connection_states(), integer(), read | write) -> -				       ssl_record:connection_state(). -%% -%% Description: Returns the instance of the connection_state record -%% that is defined by the Epoch. -%%-------------------------------------------------------------------- -connection_state_by_epoch(#{current_read := #{epoch := Epoch}} = CS, Epoch, read) -> -    CS; -connection_state_by_epoch(#{pending_read := #{epoch := Epoch}} = CS, Epoch, read) -> -    CS; -connection_state_by_epoch(#{current_write := #{epoch := Epoch}} = CS, Epoch, write) -> -    CS; -connection_state_by_epoch(#{pending_write := #{epoch := Epoch}} = CS, Epoch, write) -> -    CS. -%%-------------------------------------------------------------------- --spec set_connection_state_by_epoch(ssl_record:connection_states(), -				    ssl_record:connection_state(), read | write) -				   -> ssl_record:connection_states(). -%% -%% Description: Returns the instance of the connection_state record -%% that is defined by the Epoch. -%%-------------------------------------------------------------------- -set_connection_state_by_epoch(#{current_read := #{epoch := Epoch}} = ConnectionStates0, -                              NewCS = #{epoch := Epoch}, read) -> -    ConnectionStates0#{current_read => NewCS}; -set_connection_state_by_epoch(#{pending_read := #{epoch := Epoch}} = ConnectionStates0, -			      NewCS = #{epoch := Epoch}, read) -> -    ConnectionStates0#{pending_read => NewCS}; -set_connection_state_by_epoch(#{current_write := #{epoch := Epoch}} = ConnectionStates0, -			      NewCS = #{epoch := Epoch}, write) -> -    ConnectionStates0#{current_write => NewCS}; -set_connection_state_by_epoch(#{pending_write := #{epoch := Epoch}} = ConnectionStates0, -NewCS = #{epoch := Epoch}, write) -> -    ConnectionStates0#{pending_write => NewCS}. +    Epoch.  %%--------------------------------------------------------------------  %%% Internal functions @@ -437,8 +408,8 @@ NewCS = #{epoch := Epoch}, write) ->  initial_connection_state(ConnectionEnd, BeastMitigation) ->      #{security_parameters =>  	  ssl_record:initial_security_params(ConnectionEnd), -      epoch => 0, -      sequence_number => 1, +      epoch => undefined, +      sequence_number => 0,        beast_mitigation => BeastMitigation,        compression_state  => undefined,        cipher_state  => undefined, @@ -458,14 +429,85 @@ highest_list_protocol_version(Ver, []) ->  highest_list_protocol_version(Ver1,  [Ver2 | Rest]) ->      highest_list_protocol_version(highest_protocol_version(Ver1, Ver2), Rest). -encode_tls_cipher_text(Type, {MajVer, MinVer}, Epoch, Seq, Fragment) -> +encode_dtls_cipher_text(Type, {MajVer, MinVer}, Fragment,  +		       #{epoch := Epoch, sequence_number := Seq} = WriteState) ->      Length = erlang:iolist_size(Fragment), -    [<<?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer), ?UINT16(Epoch), -       ?UINT48(Seq), ?UINT16(Length)>>, Fragment]. +    {[<<?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer), ?UINT16(Epoch), +	?UINT48(Seq), ?UINT16(Length)>>, Fragment],  +     WriteState#{sequence_number => Seq + 1}}. + +encode_plain_text(Type, Version, Data, #{compression_state := CompS0, +					 epoch := Epoch, +					 sequence_number := Seq, +					 security_parameters := +					     #security_parameters{ +						cipher_type = ?AEAD, +						compression_algorithm = CompAlg} +					} = WriteState0) -> +    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), +    WriteState1 = WriteState0#{compression_state => CompS1}, +    AAD = calc_aad(Type, Version, Epoch, Seq), +    ssl_record:cipher_aead(dtls_v1:corresponding_tls_version(Version), Comp, WriteState1, AAD); +encode_plain_text(Type, Version, Data, #{compression_state := CompS0, +					 epoch := Epoch, +					 sequence_number := Seq, +					 security_parameters := +					     #security_parameters{compression_algorithm = CompAlg} +					}= WriteState0) -> +    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), +    WriteState1 = WriteState0#{compression_state => CompS1}, +    MacHash = calc_mac_hash(Type, Version, WriteState1, Epoch, Seq, Comp), +    ssl_record:cipher(dtls_v1:corresponding_tls_version(Version), Comp, WriteState1, MacHash). + +decode_cipher_text(#ssl_tls{type = Type, version = Version, +			    epoch = Epoch, +			    sequence_number = Seq, +			    fragment = CipherFragment} = CipherText, +		   #{compression_state := CompressionS0, +		     security_parameters := +			 #security_parameters{ +			    cipher_type = ?AEAD, +			    compression_algorithm = CompAlg}} = ReadState0,  +		   ConnnectionStates0) -> +    AAD = calc_aad(Type, Version, Epoch, Seq), +    case ssl_record:decipher_aead(dtls_v1:corresponding_tls_version(Version), +				  CipherFragment, ReadState0, AAD) of +	{PlainFragment, ReadState1} -> +	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, +							   PlainFragment, CompressionS0), +	    ReadState = ReadState1#{compression_state => CompressionS1}, +	    ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read), +	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; +	  #alert{} = Alert -> +	    Alert +    end; +decode_cipher_text(#ssl_tls{type = Type, version = Version, +			    epoch = Epoch, +			    sequence_number = Seq, +			    fragment = CipherFragment} = CipherText, +		   #{compression_state := CompressionS0, +		     security_parameters := +			 #security_parameters{ +			    compression_algorithm = CompAlg}} = ReadState0, +		   ConnnectionStates0) -> +    {PlainFragment, Mac, ReadState1} = ssl_record:decipher(dtls_v1:corresponding_tls_version(Version), +							   CipherFragment, ReadState0, true), +    MacHash = calc_mac_hash(Type, Version, ReadState1, Epoch, Seq, PlainFragment), +    case ssl_record:is_correct_mac(Mac, MacHash) of +	true -> +	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, +							   PlainFragment, CompressionS0), +	     +	    ReadState = ReadState1#{compression_state => CompressionS1}, +	    ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read), +	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; +	false -> +	    ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) +    end. -calc_mac_hash(#{mac_secret := MacSecret, -		security_parameters := #security_parameters{mac_algorithm = MacAlg}}, -	      Type, Version, Epoch, SeqNo, Fragment) -> +calc_mac_hash(Type, Version, #{mac_secret := MacSecret, +			       security_parameters := #security_parameters{mac_algorithm = MacAlg}}, +	      Epoch, SeqNo, Fragment) ->      Length = erlang:iolist_size(Fragment),      NewSeq = (Epoch bsl 48) + SeqNo,      mac_hash(Version, MacAlg, MacSecret, NewSeq, Type, diff --git a/lib/ssl/src/dtls_record.hrl b/lib/ssl/src/dtls_record.hrl index b9f84cbe7f..373481c3f8 100644 --- a/lib/ssl/src/dtls_record.hrl +++ b/lib/ssl/src/dtls_record.hrl @@ -34,11 +34,10 @@  -record(ssl_tls, {     	  type,  	  version, -	  epoch,            -	  sequence_number,       -	  offset, -	  length, -	  fragment +	  %%length, +	  fragment, +	  epoch,    +	  sequence_number  	 }).  -endif. % -ifdef(dtls_record). diff --git a/lib/ssl/src/dtls_socket.erl b/lib/ssl/src/dtls_socket.erl new file mode 100644 index 0000000000..570b3ae83a --- /dev/null +++ b/lib/ssl/src/dtls_socket.erl @@ -0,0 +1,148 @@ +%% +%% %CopyrightBegin% +%% +%% Copyright Ericsson AB 2016-2016. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%%     http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%% +%% %CopyrightEnd% +%% +-module(dtls_socket). + +-include("ssl_internal.hrl"). +-include("ssl_api.hrl"). + +-export([send/3, listen/3, accept/3, connect/4, socket/4, setopts/3, getopts/3, getstat/3,  +	 peername/2, sockname/2, port/2, close/2]). +-export([emulated_options/0, internal_inet_values/0, default_inet_values/0, default_cb_info/0]). + +send(Transport, {{IP,Port},Socket}, Data) -> +    Transport:send(Socket, IP, Port, Data). + +listen(gen_udp = Transport, Port, #config{transport_info = {Transport, _, _, _}, +					  ssl = SslOpts,  +					  emulated = EmOpts, +					  inet_user = Options} = Config) -> +     +     +    case dtls_udp_sup:start_child([Port, emulated_socket_options(EmOpts, #socket_options{}),  +				   Options ++ internal_inet_values(), SslOpts]) of +	{ok, Pid} -> +	    {ok, #sslsocket{pid = {udp, Config#config{udp_handler = {Pid, Port}}}}}; +	Err = {error, _} -> +	    Err +    end. + +accept(udp, #config{transport_info = {Transport = gen_udp,_,_,_}, +		    connection_cb = ConnectionCb, +		    udp_handler = {Listner, _}}, _Timeout) ->  +    case dtls_udp_listener:accept(Listner, self()) of +	{ok, Pid, Socket} -> +	    {ok, socket(Pid, Transport, {Listner, Socket}, ConnectionCb)}; +	{error, Reason} -> +	    {error, Reason} +    end. + +connect(Address, Port, #config{transport_info = {Transport, _, _, _} = CbInfo, +				connection_cb = ConnectionCb, +				ssl = SslOpts, +				emulated = EmOpts, +				inet_ssl = SocketOpts}, Timeout) -> +    case Transport:open(0, SocketOpts ++ internal_inet_values()) of +	{ok, Socket} -> +	    ssl_connection:connect(ConnectionCb, Address, Port, {{Address, Port},Socket},  +				   {SslOpts,  +				    emulated_socket_options(EmOpts, #socket_options{}), undefined}, +				   self(), CbInfo, Timeout); +	{error, _} = Error->	 +	    Error +    end. + +close(gen_udp, {_Client, _Socket}) -> +    ok. + +socket(Pid, Transport, Socket, ConnectionCb) -> +    #sslsocket{pid = Pid,  +	       %% "The name "fd" is keept for backwards compatibility +	       fd = {Transport, Socket, ConnectionCb}}.	 + +%% Vad göra med emulerade +setopts(gen_udp, #sslsocket{pid = {Socket, _}}, Options) -> +    {SockOpts, _} = tls_socket:split_options(Options), +    inet:setopts(Socket, SockOpts); +setopts(_, #sslsocket{pid = {ListenSocket, #config{transport_info = {Transport,_,_,_}}}}, Options) -> +    {SockOpts, _} = tls_socket:split_options(Options), +    Transport:setopts(ListenSocket, SockOpts); +%%% Following clauses will not be called for emulated options, they are  handled in the connection process +setopts(gen_udp, Socket, Options) -> +    inet:setopts(Socket, Options); +setopts(Transport, Socket, Options) -> +    Transport:setopts(Socket, Options). + +getopts(gen_udp,  #sslsocket{pid = {Socket, #config{emulated = EmOpts}}}, Options) -> +    {SockOptNames, EmulatedOptNames} = tls_socket:split_options(Options), +    EmulatedOpts = get_emulated_opts(EmOpts, EmulatedOptNames), +    SocketOpts = tls_socket:get_socket_opts(Socket, SockOptNames, inet), +    {ok, EmulatedOpts ++ SocketOpts};  +getopts(Transport,  #sslsocket{pid = {ListenSocket, #config{emulated = EmOpts}}}, Options) -> +    {SockOptNames, EmulatedOptNames} = tls_socket:split_options(Options), +    EmulatedOpts = get_emulated_opts(EmOpts, EmulatedOptNames), +    SocketOpts = tls_socket:get_socket_opts(ListenSocket, SockOptNames, Transport), +    {ok, EmulatedOpts ++ SocketOpts};  +%%% Following clauses will not be called for emulated options, they are  handled in the connection process +getopts(gen_udp, {_,Socket}, Options) -> +    inet:getopts(Socket, Options); +getopts(Transport, Socket, Options) -> +    Transport:getopts(Socket, Options). +getstat(gen_udp, {_,Socket}, Options) -> +	inet:getstat(Socket, Options); +getstat(Transport, Socket, Options) -> +	Transport:getstat(Socket, Options). +peername(gen_udp, {_, {Client, _Socket}}) -> +    {ok, Client}; +peername(Transport, Socket) -> +    Transport:peername(Socket). +sockname(gen_udp, {_,Socket}) -> +    inet:sockname(Socket); +sockname(Transport, Socket) -> +    Transport:sockname(Socket). + +port(gen_udp, {_,Socket}) -> +    inet:port(Socket); +port(Transport, Socket) -> +    Transport:port(Socket). + +emulated_options() -> +    [mode, active,  packet, packet_size]. + +internal_inet_values() -> +    [{active, false}, {mode,binary}]. + +default_inet_values() -> +    [{active, true}, {mode, list}]. + +default_cb_info() -> +    {gen_udp, udp, udp_closed, udp_error}. + +get_emulated_opts(EmOpts, EmOptNames) ->  +    lists:map(fun(Name) -> {value, Value} = lists:keysearch(Name, 1, EmOpts), +			   Value end, +	      EmOptNames). + +emulated_socket_options(InetValues, #socket_options{ +				       mode   = Mode, +				       active = Active}) -> +    #socket_options{ +       mode   = proplists:get_value(mode, InetValues, Mode), +       active = proplists:get_value(active, InetValues, Active) +      }. diff --git a/lib/ssl/src/dtls_udp_listener.erl b/lib/ssl/src/dtls_udp_listener.erl new file mode 100644 index 0000000000..b7f115582e --- /dev/null +++ b/lib/ssl/src/dtls_udp_listener.erl @@ -0,0 +1,205 @@ +%% +%% %CopyrightBegin% +%% +%% Copyright Ericsson AB 2016-2016. All Rights Reserved. +%% +%% The contents of this file are subject to the Erlang Public License, +%% Version 1.1, (the "License"); you may not use this file except in +%% compliance with the License. You should have received a copy of the +%% Erlang Public License along with this software. If not, it can be +%% retrieved online at http://www.erlang.org/. +%% +%% Software distributed under the License is distributed on an "AS IS" +%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See +%% the License for the specific language governing rights and limitations +%% under the License. +%% +%% %CopyrightEnd% +%% + +%% + +-module(dtls_udp_listener). + +-behaviour(gen_server). + +%% API +-export([start_link/4, active_once/3, accept/2, sockname/1]). + +%% gen_server callbacks +-export([init/1, handle_call/3, handle_cast/2, handle_info/2, +	 terminate/2, code_change/3]). + +-record(state,  +	{port,  +	 listner, +	 dtls_options, +	 emulated_options, +	 dtls_msq_queues = kv_new(), +	 clients = set_new(), +	 dtls_processes = kv_new(), +	 accepters  = queue:new(), +	 first +	}). + +%%%=================================================================== +%%% API +%%%=================================================================== + +start_link(Port, EmOpts, InetOptions, DTLSOptions) -> +    gen_server:start_link(?MODULE, [Port, EmOpts, InetOptions, DTLSOptions], []). + +active_once(UDPConnection, Client, Pid) -> +    gen_server:cast(UDPConnection, {active_once, Client, Pid}). + +accept(UDPConnection, Accepter) -> +    gen_server:call(UDPConnection, {accept, Accepter}, infinity). + +sockname(UDPConnection) -> +    gen_server:call(UDPConnection, sockname, infinity). + +%%%=================================================================== +%%% gen_server callbacks +%%%=================================================================== + +init([Port, EmOpts, InetOptions, DTLSOptions]) -> +    try  +	{ok, Socket} = gen_udp:open(Port, InetOptions), +	{ok, #state{port = Port, +		    first = true, +		    dtls_options = DTLSOptions, +		    emulated_options = EmOpts, +		    listner = Socket}} +    catch _:_ -> +	    {error, closed} +    end. + +handle_call({accept, Accepter}, From, #state{first = true, +					     accepters = Accepters, +					     listner = Socket} = State0) -> +    next_datagram(Socket), +    State = State0#state{first = false, +			 accepters = queue:in({Accepter, From}, Accepters)}, 		  +    {noreply, State}; + +handle_call({accept, Accepter}, From, #state{accepters = Accepters} = State0) -> +    State = State0#state{accepters = queue:in({Accepter, From}, Accepters)}, 		  +    {noreply, State}; +handle_call(sockname, _, #state{listner = Socket} = State) -> +    Reply = inet:sockname(Socket), +    {reply, Reply, State}. + +handle_cast({active_once, Client, Pid}, State0) -> +    State = handle_active_once(Client, Pid, State0), +    {noreply, State}. + +handle_info({udp, Socket, IP, InPortNo, _} = Msg, #state{listner = Socket} = State0) -> +    State = handle_datagram({IP, InPortNo}, Msg, State0), +    next_datagram(Socket), +    {noreply, State}; + +handle_info({'DOWN', _, process, Pid, _}, #state{clients = Clients, +						 dtls_processes = Processes0} = State) -> +    Client = kv_get(Pid, Processes0), +    Processes = kv_delete(Pid, Processes0), +    {noreply, State#state{clients = set_delete(Client, Clients), +			  dtls_processes = Processes}}. + +terminate(_Reason, _State) -> +    ok. + +code_change(_OldVsn, State, _Extra) -> +    {ok, State}. + +%%%=================================================================== +%%% Internal functions +%%%=================================================================== +handle_datagram(Client, Msg, #state{clients = Clients, +				    accepters = AcceptorsQueue0} = State) -> +    case set_is_member(Client, Clients) of +	false -> +	    case queue:out(AcceptorsQueue0) of +		{{value, {UserPid, From}}, AcceptorsQueue} ->	 +		    setup_new_connection(UserPid, From, Client, Msg,  +					 State#state{accepters = AcceptorsQueue}); +		{empty, _} -> +		    %% Drop packet client will resend +		    State +	    end; +	true ->  +	    dispatch(Client, Msg, State) +    end. + +dispatch(Client, Msg, #state{dtls_msq_queues = MsgQueues} = State) -> +    case kv_lookup(Client, MsgQueues) of +	{value, Queue0} -> +	    case queue:out(Queue0) of +		{{value, Pid}, Queue} when is_pid(Pid) -> +		    Pid ! Msg, +		    State#state{dtls_msq_queues =  +				    kv_update(Client, Queue, MsgQueues)}; +		{{value, _}, Queue} ->  +		    State#state{dtls_msq_queues =  +				    kv_update(Client, queue:in(Msg, Queue), MsgQueues)}; +		{empty, Queue} -> +		    State#state{dtls_msq_queues =  +				    kv_update(Client, queue:in(Msg, Queue), MsgQueues)} +	    end +    end. +next_datagram(Socket) -> +    inet:setopts(Socket, [{active, once}]). + +handle_active_once(Client, Pid, #state{dtls_msq_queues = MsgQueues} = State0) -> +    Queue0 = kv_get(Client, MsgQueues), +    case queue:out(Queue0) of +	{{value, Pid}, _} when is_pid(Pid) -> +	    State0; +	{{value, Msg}, Queue} ->	       +	    Pid ! Msg, +	    State0#state{dtls_msq_queues = kv_update(Client, Queue, MsgQueues)}; +	{empty, Queue0} -> +	    State0#state{dtls_msq_queues = kv_update(Client, queue:in(Pid, Queue0), MsgQueues)} +    end. + +setup_new_connection(User, From, Client, Msg, #state{dtls_processes = Processes, +						     clients = Clients, +						     dtls_msq_queues = MsgQueues, +						     dtls_options = DTLSOpts, +						     port = Port, +						     listner = Socket, +						     emulated_options = EmOpts} = State) -> +    ConnArgs = [server, "localhost", Port, {self(), {Client, Socket}}, +		{DTLSOpts, EmOpts, udp_listner}, User, dtls_socket:default_cb_info()], +    case dtls_connection_sup:start_child(ConnArgs) of +	{ok, Pid} -> +	    erlang:monitor(process, Pid), +	    gen_server:reply(From, {ok, Pid, {Client, Socket}}), +	    Pid ! Msg, +	    State#state{clients = set_insert(Client, Clients),  +			dtls_msq_queues = kv_insert(Client, queue:new(), MsgQueues), +			dtls_processes = kv_insert(Pid, Client, Processes)}; +	{error, Reason} -> +	    gen_server:reply(From, {error, Reason}), +	    State +    end. +kv_update(Key, Value, Store) -> +    gb_trees:update(Key, Value, Store). +kv_lookup(Key, Store) -> +    gb_trees:lookup(Key, Store). +kv_insert(Key, Value, Store) -> +    gb_trees:insert(Key, Value, Store). +kv_get(Key, Store) ->  +    gb_trees:get(Key, Store). +kv_delete(Key, Store) -> +    gb_trees:delete(Key, Store). +kv_new() -> +    gb_trees:empty(). + +set_new() -> +    gb_sets:empty(). +set_insert(Item, Set) -> +    gb_sets:insert(Item, Set). +set_delete(Item, Set) -> +    gb_sets:delete(Item, Set). +set_is_member(Item, Set) -> +    gb_sets:is_member(Item, Set). diff --git a/lib/ssl/src/dtls_udp_sup.erl b/lib/ssl/src/dtls_udp_sup.erl new file mode 100644 index 0000000000..197882e92f --- /dev/null +++ b/lib/ssl/src/dtls_udp_sup.erl @@ -0,0 +1,62 @@ +%% +%% %CopyrightBegin% +%%  +%% Copyright Ericsson AB 2016-2016. All Rights Reserved. +%%  +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%%     http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%%  +%% %CopyrightEnd% +%% + +%% +%%---------------------------------------------------------------------- +%% Purpose: Supervisor for a procsses dispatching upd datagrams to +%% correct DTLS handler  +%%---------------------------------------------------------------------- +-module(dtls_udp_sup). + +-behaviour(supervisor). + +%% API +-export([start_link/0]). +-export([start_child/1]). + +%% Supervisor callback +-export([init/1]). + +%%%========================================================================= +%%%  API +%%%========================================================================= +start_link() -> +    supervisor:start_link({local, ?MODULE}, ?MODULE, []). + +start_child(Args) -> +    supervisor:start_child(?MODULE, Args). +     +%%%========================================================================= +%%%  Supervisor callback +%%%========================================================================= +init(_O) -> +    RestartStrategy = simple_one_for_one, +    MaxR = 0, +    MaxT = 3600, +    +    Name = undefined, % As simple_one_for_one is used. +    StartFunc = {dtls_udp_listener, start_link, []}, +    Restart = temporary, % E.g. should not be restarted +    Shutdown = 4000, +    Modules = [dtls_udp_listener], +    Type = worker, +     +    ChildSpec = {Name, StartFunc, Restart, Shutdown, Type, Modules}, +    {ok, {{RestartStrategy, MaxR, MaxT}, [ChildSpec]}}. diff --git a/lib/ssl/src/dtls_v1.erl b/lib/ssl/src/dtls_v1.erl index 8c03bda513..ffd3e4b833 100644 --- a/lib/ssl/src/dtls_v1.erl +++ b/lib/ssl/src/dtls_v1.erl @@ -21,7 +21,7 @@  -include("ssl_cipher.hrl"). --export([suites/1, mac_hash/7, ecc_curves/1, corresponding_tls_version/1]). +-export([suites/1, mac_hash/7, ecc_curves/1, corresponding_tls_version/1, corresponding_dtls_version/1]).  -spec suites(Minor:: 253|255) -> [ssl_cipher:cipher_suite()]. @@ -29,7 +29,7 @@ suites(Minor) ->     tls_v1:suites(corresponding_minor_tls_version(Minor)).  mac_hash(Version, MacAlg, MacSecret, SeqNo, Type, Length, Fragment) -> -    tls_v1:mac_hash(MacAlg, MacSecret, SeqNo, Type, corresponding_tls_version(Version), +    tls_v1:mac_hash(MacAlg, MacSecret, SeqNo, Type, Version,  		    Length, Fragment).  ecc_curves({_Major, Minor}) -> @@ -42,3 +42,11 @@ corresponding_minor_tls_version(255) ->      2;  corresponding_minor_tls_version(253) ->      3. + +corresponding_dtls_version({3, Minor}) ->  +    {254, corresponding_minor_dtls_version(Minor)}. + +corresponding_minor_dtls_version(2) -> +    255; +corresponding_minor_dtls_version(3) -> +    253. diff --git a/lib/ssl/src/ssl.app.src b/lib/ssl/src/ssl.app.src index 00b0513891..9c5d795848 100644 --- a/lib/ssl/src/ssl.app.src +++ b/lib/ssl/src/ssl.app.src @@ -6,6 +6,7 @@  	       tls_connection,  	       tls_handshake,  	       tls_record, +	       tls_socket,  	       tls_v1,  	       ssl_v3,  	       ssl_v2, @@ -13,7 +14,10 @@  	       dtls_connection,  	       dtls_handshake,  	       dtls_record, +	       dtls_socket,  	       dtls_v1, +	       dtls_udp_listener, +	       dtls_udp_sup,  	       %% API  	       ssl,  %% Main API		    	       tls,  %% TLS specific @@ -27,7 +31,6 @@  	       ssl_cipher,  	       ssl_srp_primes,  	       ssl_alert, -	       ssl_socket,  	       ssl_listen_tracker_sup,	  	       %% Erlang Distribution over SSL/TLS  	       inet_tls_dist, diff --git a/lib/ssl/src/ssl.erl b/lib/ssl/src/ssl.erl index aa62ab8865..c72ee44a95 100644 --- a/lib/ssl/src/ssl.erl +++ b/lib/ssl/src/ssl.erl @@ -101,33 +101,27 @@ connect(Socket, SslOptions0, Timeout) when is_port(Socket),  					    (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->      {Transport,_,_,_} = proplists:get_value(cb_info, SslOptions0,  					      {gen_tcp, tcp, tcp_closed, tcp_error}), -    EmulatedOptions = ssl_socket:emulated_options(), -    {ok, SocketValues} = ssl_socket:getopts(Transport, Socket, EmulatedOptions), +    EmulatedOptions = tls_socket:emulated_options(), +    {ok, SocketValues} = tls_socket:getopts(Transport, Socket, EmulatedOptions),      try handle_options(SslOptions0 ++ SocketValues, client) of -	{ok, #config{transport_info = CbInfo, ssl = SslOptions, emulated = EmOpts, -		     connection_cb = ConnectionCb}} -> - -	    ok = ssl_socket:setopts(Transport, Socket, ssl_socket:internal_inet_values()), -	    case ssl_socket:peername(Transport, Socket) of -		{ok, {Address, Port}} -> -		    ssl_connection:connect(ConnectionCb, Address, Port, Socket, -					   {SslOptions, emulated_socket_options(EmOpts, #socket_options{}), undefined}, -					   self(), CbInfo, Timeout); -		{error, Error} -> -		    {error, Error} -	    end +	{ok, Config} -> +	    tls_socket:upgrade(Socket, Config, Timeout)      catch  	_:{error, Reason} ->              {error, Reason}      end; -  connect(Host, Port, Options) ->      connect(Host, Port, Options, infinity).  connect(Host, Port, Options, Timeout) when (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) -> -    try handle_options(Options, client) of -	{ok, Config} -> -	    do_connect(Host,Port,Config,Timeout) +    try +	{ok, Config} = handle_options(Options, client), +	case Config#config.connection_cb of +	    tls_connection -> +		tls_socket:connect(Host,Port,Config,Timeout); +	    dtls_connection -> +		dtls_socket:connect(Host,Port,Config,Timeout) +	end      catch  	throw:Error ->  	    Error @@ -144,17 +138,7 @@ listen(_Port, []) ->  listen(Port, Options0) ->      try  	{ok, Config} = handle_options(Options0, server), -	ConnectionCb = connection_cb(Options0), -	#config{transport_info = {Transport, _, _, _}, inet_user = Options, connection_cb = ConnectionCb, -		ssl = SslOpts, emulated = EmOpts} = Config, -	case Transport:listen(Port, Options) of -	    {ok, ListenSocket} -> -		ok = ssl_socket:setopts(Transport, ListenSocket, ssl_socket:internal_inet_values()), -		{ok, Tracker} = ssl_socket:inherit_tracker(ListenSocket, EmOpts, SslOpts), -		{ok, #sslsocket{pid = {ListenSocket, Config#config{emulated = Tracker}}}}; -	    Err = {error, _} -> -		Err -	end +	do_listen(Port, Config, connection_cb(Options0))      catch  	Error = {error, _} ->  	    Error @@ -171,27 +155,15 @@ transport_accept(ListenSocket) ->      transport_accept(ListenSocket, infinity).  transport_accept(#sslsocket{pid = {ListenSocket, -				   #config{transport_info =  {Transport,_,_, _} =CbInfo, -					   connection_cb = ConnectionCb, -					   ssl = SslOpts, -					   emulated = Tracker}}}, Timeout) when (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) -> -    case Transport:accept(ListenSocket, Timeout) of -	{ok, Socket} -> -	    {ok, EmOpts} = ssl_socket:get_emulated_opts(Tracker), -	    {ok, Port} = ssl_socket:port(Transport, Socket), -	    ConnArgs = [server, "localhost", Port, Socket, -			{SslOpts, emulated_socket_options(EmOpts, #socket_options{}), Tracker}, self(), CbInfo], -	    ConnectionSup = connection_sup(ConnectionCb), -	    case ConnectionSup:start_child(ConnArgs) of -		{ok, Pid} -> -		    ssl_connection:socket_control(ConnectionCb, Socket, Pid, Transport, Tracker); -		{error, Reason} -> -		    {error, Reason} -	    end; -	{error, Reason} -> -	    {error, Reason} +				   #config{connection_cb = ConnectionCb} = Config}}, Timeout)  +  when (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) -> +    case ConnectionCb of +	tls_connection -> +	    tls_socket:accept(ListenSocket, Config, Timeout); +	dtls_connection -> +	    dtls_socket:accept(ListenSocket, Config, Timeout)      end. - +    %%--------------------------------------------------------------------  -spec ssl_accept(#sslsocket{}) -> ok | {error, reason()}.  -spec ssl_accept(#sslsocket{} | port(), timeout()| [ssl_option() @@ -214,13 +186,14 @@ ssl_accept(ListenSocket, SslOptions)  when is_port(ListenSocket) ->      ssl_accept(ListenSocket, SslOptions, infinity).  ssl_accept(#sslsocket{} = Socket, [], Timeout) when (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity)-> -    ssl_accept(#sslsocket{} = Socket, Timeout); +    ssl_accept(Socket, Timeout);  ssl_accept(#sslsocket{fd = {_, _, _, Tracker}} = Socket, SslOpts0, Timeout) when        (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity)->      try -	{ok, EmOpts, InheritedSslOpts} = ssl_socket:get_all_opts(Tracker), +	{ok, EmOpts, InheritedSslOpts} = tls_socket:get_all_opts(Tracker),  	SslOpts = handle_options(SslOpts0, InheritedSslOpts), -	ssl_connection:handshake(Socket, {SslOpts, emulated_socket_options(EmOpts, #socket_options{})}, Timeout) +	ssl_connection:handshake(Socket, {SslOpts,  +					  tls_socket:emulated_socket_options(EmOpts, #socket_options{})}, Timeout)      catch  	Error = {error, _Reason} -> Error      end; @@ -228,15 +201,16 @@ ssl_accept(Socket, SslOptions, Timeout) when is_port(Socket),  					     (is_integer(Timeout) andalso Timeout >= 0) or (Timeout == infinity) ->      {Transport,_,_,_} =  	proplists:get_value(cb_info, SslOptions, {gen_tcp, tcp, tcp_closed, tcp_error}), -    EmulatedOptions = ssl_socket:emulated_options(), -    {ok, SocketValues} = ssl_socket:getopts(Transport, Socket, EmulatedOptions), +    EmulatedOptions = tls_socket:emulated_options(), +    {ok, SocketValues} = tls_socket:getopts(Transport, Socket, EmulatedOptions),      ConnetionCb = connection_cb(SslOptions),      try handle_options(SslOptions ++ SocketValues, server) of  	{ok, #config{transport_info = CbInfo, ssl = SslOpts, emulated = EmOpts}} -> -	    ok = ssl_socket:setopts(Transport, Socket, ssl_socket:internal_inet_values()), -	    {ok, Port} = ssl_socket:port(Transport, Socket), +	    ok = tls_socket:setopts(Transport, Socket, tls_socket:internal_inet_values()), +	    {ok, Port} = tls_socket:port(Transport, Socket),  	    ssl_connection:ssl_accept(ConnetionCb, Port, Socket, -				      {SslOpts, emulated_socket_options(EmOpts, #socket_options{}), undefined}, +				      {SslOpts,  +				       tls_socket:emulated_socket_options(EmOpts, #socket_options{}), undefined},  				      self(), CbInfo, Timeout)      catch  	Error = {error, _Reason} -> Error @@ -275,6 +249,8 @@ close(#sslsocket{pid = {ListenSocket, #config{transport_info={Transport,_, _, _}  %%--------------------------------------------------------------------  send(#sslsocket{pid = Pid}, Data) when is_pid(Pid) ->      ssl_connection:send(Pid, Data); +send(#sslsocket{pid = {_, #config{transport_info={gen_udp, _, _, _}}}}, _) -> +    {error,enotconn}; %% Emulate connection behaviour  send(#sslsocket{pid = {ListenSocket, #config{transport_info={Transport, _, _, _}}}}, Data) ->      Transport:send(ListenSocket, Data). %% {error,enotconn} @@ -358,9 +334,9 @@ connection_info(#sslsocket{} = SSLSocket) ->  %% Description: same as inet:peername/1.  %%--------------------------------------------------------------------  peername(#sslsocket{pid = Pid, fd = {Transport, Socket, _, _}}) when is_pid(Pid)-> -    ssl_socket:peername(Transport, Socket); +    tls_socket:peername(Transport, Socket);  peername(#sslsocket{pid = {ListenSocket,  #config{transport_info = {Transport,_,_,_}}}}) -> -    ssl_socket:peername(Transport, ListenSocket). %% Will return {error, enotconn} +    tls_socket:peername(Transport, ListenSocket). %% Will return {error, enotconn}  %%--------------------------------------------------------------------  -spec peercert(#sslsocket{}) ->{ok, DerCert::binary()} | {error, reason()}. @@ -456,7 +432,7 @@ getopts(#sslsocket{pid = Pid}, OptionTags) when is_pid(Pid), is_list(OptionTags)      ssl_connection:get_opts(Pid, OptionTags);  getopts(#sslsocket{pid = {_,  #config{transport_info = {Transport,_,_,_}}}} = ListenSocket,  	OptionTags) when is_list(OptionTags) -> -    try ssl_socket:getopts(Transport, ListenSocket, OptionTags) of +    try tls_socket:getopts(Transport, ListenSocket, OptionTags) of  	{ok, _} = Result ->  	    Result;  	{error, InetError} -> @@ -484,7 +460,7 @@ setopts(#sslsocket{pid = Pid}, Options0) when is_pid(Pid), is_list(Options0)  ->      end;  setopts(#sslsocket{pid = {_, #config{transport_info = {Transport,_,_,_}}}} = ListenSocket, Options) when is_list(Options) -> -    try ssl_socket:setopts(Transport, ListenSocket, Options) of +    try tls_socket:setopts(Transport, ListenSocket, Options) of  	ok ->  	    ok;  	{error, InetError} -> @@ -517,10 +493,10 @@ getstat(Socket) ->  %% Description: Get one or more statistic options for a socket.  %%--------------------------------------------------------------------  getstat(#sslsocket{pid = {Listen,  #config{transport_info = {Transport, _, _, _}}}}, Options) when is_port(Listen), is_list(Options) -> -    ssl_socket:getstat(Transport, Listen, Options); +    tls_socket:getstat(Transport, Listen, Options);  getstat(#sslsocket{pid = Pid, fd = {Transport, Socket, _, _}}, Options) when is_pid(Pid), is_list(Options) -> -    ssl_socket:getstat(Transport, Socket, Options). +    tls_socket:getstat(Transport, Socket, Options).  %%---------------------------------------------------------------  -spec shutdown(#sslsocket{}, read | write | read_write) ->  ok | {error, reason()}. @@ -539,10 +515,13 @@ shutdown(#sslsocket{pid = Pid}, How) ->  %% Description: Same as inet:sockname/1  %%--------------------------------------------------------------------  sockname(#sslsocket{pid = {Listen,  #config{transport_info = {Transport, _, _, _}}}}) when is_port(Listen) -> -    ssl_socket:sockname(Transport, Listen); - +    tls_socket:sockname(Transport, Listen); +sockname(#sslsocket{pid = {udp, #config{udp_handler = {Pid, _}}}}) -> +    dtls_udp_listener:sockname(Pid); +sockname(#sslsocket{pid = Pid, fd = {gen_udp= Transport, Socket, _, _}}) when is_pid(Pid) -> +    dtls_socket:sockname(Transport, Socket);  sockname(#sslsocket{pid = Pid, fd = {Transport, Socket, _, _}}) when is_pid(Pid) -> -    ssl_socket:sockname(Transport, Socket). +    tls_socket:sockname(Transport, Socket).  %%---------------------------------------------------------------  -spec session_info(#sslsocket{}) -> {ok, list()} | {error, reason()}. @@ -652,27 +631,12 @@ available_suites(all) ->      Version = tls_record:highest_protocol_version([]),			        ssl_cipher:filter_suites(ssl_cipher:all_suites(Version)). -do_connect(Address, Port, -	   #config{transport_info = CbInfo, inet_user = UserOpts, ssl = SslOpts, -		   emulated = EmOpts, inet_ssl = SocketOpts, connection_cb = ConnetionCb}, -	   Timeout) -> -    {Transport, _, _, _} = CbInfo, -    try Transport:connect(Address, Port,  SocketOpts, Timeout) of -	{ok, Socket} -> -	    ssl_connection:connect(ConnetionCb, Address, Port, Socket,  -				   {SslOpts, emulated_socket_options(EmOpts, #socket_options{}), undefined}, -				   self(), CbInfo, Timeout); -	{error, Reason} -> -	    {error, Reason} -    catch -	exit:{function_clause, _} -> -	    {error, {options, {cb_info, CbInfo}}}; -	exit:badarg -> -	    {error, {options, {socket_options, UserOpts}}}; -	exit:{badarg, _} -> -	    {error, {options, {socket_options, UserOpts}}} -    end. +do_listen(Port, #config{transport_info = {Transport, _, _, _}} = Config, tls_connection) -> +    tls_socket:listen(Transport, Port, Config); +do_listen(Port,  #config{transport_info = {Transport, _, _, _}} = Config, dtls_connection) -> +    dtls_socket:listen(Transport, Port, Config). +	  %% Handle extra ssl options given to ssl_accept  -spec handle_options([any()], #ssl_options{}) -> #ssl_options{}        ;             ([any()], client | server) -> {ok, #config{}}. @@ -732,6 +696,8 @@ handle_options(Opts0, Role) ->  		       [RecordCb:protocol_version(Vsn) || Vsn <- Vsns]  	       end, +    Protocol = proplists:get_value(protocol, Opts, tls), +      SSLOptions = #ssl_options{  		    versions   = Versions,  		    verify     = validate_option(verify, Verify), @@ -759,7 +725,7 @@ handle_options(Opts0, Role) ->  		    signature_algs = handle_hashsigns_option(proplists:get_value(signature_algs, Opts,   									     default_option_role(server,   												 tls_v1:default_signature_algs(Versions), Role)), -							 RecordCb:highest_protocol_version(Versions)),  +							 tls_version(RecordCb:highest_protocol_version(Versions))),   		    %% Server side option  		    reuse_session = handle_option(reuse_session, Opts, ReuseSessionFun),  		    reuse_sessions = handle_option(reuse_sessions, Opts, true), @@ -789,7 +755,7 @@ handle_options(Opts0, Role) ->  		    honor_ecc_order = handle_option(honor_ecc_order, Opts,  						       default_option_role(server, false, Role),  						       server, Role), -		    protocol = proplists:get_value(protocol, Opts, tls), +		    protocol = Protocol,   		    padding_check =  proplists:get_value(padding_check, Opts, true),  		    beast_mitigation = handle_option(beast_mitigation, Opts, one_n_minus_one),  		    fallback = handle_option(fallback, Opts, @@ -802,7 +768,7 @@ handle_options(Opts0, Role) ->  		    v2_hello_compatible = handle_option(v2_hello_compatible, Opts, false)  		   }, -    CbInfo  = proplists:get_value(cb_info, Opts, {gen_tcp, tcp, tcp_closed, tcp_error}), +    CbInfo  = proplists:get_value(cb_info, Opts, default_cb_info(Protocol)),      SslOptions = [protocol, versions, verify, verify_fun, partial_chain,  		  fail_if_no_peer_cert, verify_client_once,  		  depth, cert, certfile, key, keyfile, @@ -820,7 +786,7 @@ handle_options(Opts0, Role) ->  				   proplists:delete(Key, PropList)  			   end, Opts, SslOptions), -    {Sock, Emulated} = emulated_options(SockOpts), +    {Sock, Emulated} = emulated_options(Protocol, SockOpts),      ConnetionCb = connection_cb(Opts),      {ok, #config{ssl = SSLOptions, emulated = Emulated, inet_ssl = Sock, @@ -1139,8 +1105,13 @@ ca_cert_default(verify_peer, {Fun,_}, _) when is_function(Fun) ->  %% some trusted certs.  ca_cert_default(verify_peer, undefined, _) ->      "". -emulated_options(Opts) -> -    emulated_options(Opts, ssl_socket:internal_inet_values(), ssl_socket:default_inet_values()). +emulated_options(Protocol, Opts) -> +    case Protocol of +	tls -> +	    emulated_options(Opts, tls_socket:internal_inet_values(), tls_socket:default_inet_values()); +	dtls -> +	    emulated_options(Opts, dtls_socket:internal_inet_values(), dtls_socket:default_inet_values()) +    end.  emulated_options([{mode, Value} = Opt |Opts], Inet, Emulated) ->      validate_inet_option(mode, Value), @@ -1281,11 +1252,6 @@ record_cb(dtls) ->  record_cb(Opts) ->      record_cb(proplists:get_value(protocol, Opts, tls)). -connection_sup(tls_connection) -> -    tls_connection_sup; -connection_sup(dtls_connection) -> -    dtls_connection_sup. -  binary_filename(FileName) ->      Enc = file:native_name_encoding(),      unicode:characters_to_binary(FileName, unicode, Enc). @@ -1304,20 +1270,6 @@ assert_proplist([inet6 | Rest]) ->  assert_proplist([Value | _]) ->      throw({option_not_a_key_value_tuple, Value}). -emulated_socket_options(InetValues, #socket_options{ -				       mode   = Mode, -				       header = Header, -				       active = Active, -				       packet = Packet, -				       packet_size = Size}) -> -    #socket_options{ -       mode   = proplists:get_value(mode, InetValues, Mode), -       header = proplists:get_value(header, InetValues, Header), -       active = proplists:get_value(active, InetValues, Active), -       packet = proplists:get_value(packet, InetValues, Packet), -       packet_size = proplists:get_value(packet_size, InetValues, Size) -      }. -  new_ssl_options([], #ssl_options{} = Opts, _) ->       Opts;  new_ssl_options([{verify_client_once, Value} | Rest], #ssl_options{} = Opts, RecordCB) ->  @@ -1390,7 +1342,7 @@ new_ssl_options([{signature_algs, Value} | Rest], #ssl_options{} = Opts, RecordC      new_ssl_options(Rest,   		    Opts#ssl_options{signature_algs =   					 handle_hashsigns_option(Value,  -								 RecordCB:highest_protocol_version())},  +								 tls_version(RecordCB:highest_protocol_version()))},   		    RecordCB);  new_ssl_options([{Key, Value} | _Rest], #ssl_options{}, _) ->  @@ -1454,3 +1406,8 @@ default_option_role(Role, Value, Role) ->      Value;  default_option_role(_,_,_) ->      undefined. + +default_cb_info(tls) -> +    {gen_tcp, tcp, tcp_closed, tcp_error}; +default_cb_info(dtls) -> +    {gen_udp, udp, udp_closed, udp_error}. diff --git a/lib/ssl/src/ssl_alert.erl b/lib/ssl/src/ssl_alert.erl index 05dfb4c1b3..7b1603df6e 100644 --- a/lib/ssl/src/ssl_alert.erl +++ b/lib/ssl/src/ssl_alert.erl @@ -32,22 +32,13 @@  -include("ssl_record.hrl").  -include("ssl_internal.hrl"). --export([encode/3, decode/1, alert_txt/1, reason_code/2]). +-export([decode/1, alert_txt/1, reason_code/2]).  %%====================================================================  %% Internal application API  %%====================================================================  %%-------------------------------------------------------------------- --spec encode(#alert{}, ssl_record:ssl_version(), ssl_record:connection_states()) ->  -		    {iolist(), ssl_record:connection_states()}. -%% -%% Description: Encodes an alert -%%-------------------------------------------------------------------- -encode(#alert{} = Alert, Version, ConnectionStates) -> -    ssl_record:encode_alert_record(Alert, Version, ConnectionStates). - -%%--------------------------------------------------------------------  -spec decode(binary()) -> [#alert{}] | #alert{}.  %%  %% Description: Decode alert(s), will return a singel own alert if peer diff --git a/lib/ssl/src/ssl_cipher.erl b/lib/ssl/src/ssl_cipher.erl index 605bbd859a..32fec03b8e 100644 --- a/lib/ssl/src/ssl_cipher.erl +++ b/lib/ssl/src/ssl_cipher.erl @@ -40,7 +40,7 @@  	 ec_keyed_suites/0, anonymous_suites/1, psk_suites/1, srp_suites/0,  	 rc4_suites/1, des_suites/1, openssl_suite/1, openssl_suite_name/1, filter/2, filter_suites/1,  	 hash_algorithm/1, sign_algorithm/1, is_acceptable_hash/2, is_fallback/1, -	 random_bytes/1]). +	 random_bytes/1, calc_aad/3, calc_mac_hash/4]).  -export_type([cipher_suite/0,  	      erl_cipher_suite/0, openssl_cipher_suite/0, @@ -311,7 +311,9 @@ aead_decipher(Type, #cipher_state{key = Key, iv = IV} = CipherState,  suites({3, 0}) ->      ssl_v3:suites();  suites({3, N}) -> -    tls_v1:suites(N). +    tls_v1:suites(N); +suites(Version) -> +    suites(dtls_v1:corresponding_tls_version(Version)).  all_suites(Version) ->      suites(Version) @@ -1525,9 +1527,32 @@ is_fallback(CipherSuites)->  random_bytes(N) ->      crypto:strong_rand_bytes(N). +calc_aad(Type, {MajVer, MinVer}, +	 #{sequence_number := SeqNo}) -> +    <<SeqNo:64/integer, ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. + +calc_mac_hash(Type, Version, +	      PlainFragment, #{sequence_number := SeqNo, +			       mac_secret := MacSecret, +			       security_parameters:= +				   SecPars}) -> +    Length = erlang:iolist_size(PlainFragment), +    mac_hash(Version, SecPars#security_parameters.mac_algorithm, +	     MacSecret, SeqNo, Type, +	     Length, PlainFragment). +  %%--------------------------------------------------------------------  %%% Internal functions  %%-------------------------------------------------------------------- +mac_hash({_,_}, ?NULL, _MacSecret, _SeqNo, _Type, +	 _Length, _Fragment) -> +    <<>>; +mac_hash({3, 0}, MacAlg, MacSecret, SeqNo, Type, Length, Fragment) -> +    ssl_v3:mac_hash(MacAlg, MacSecret, SeqNo, Type, Length, Fragment); +mac_hash({3, N} = Version, MacAlg, MacSecret, SeqNo, Type, Length, Fragment)   +  when N =:= 1; N =:= 2; N =:= 3 -> +    tls_v1:mac_hash(MacAlg, MacSecret, SeqNo, Type, Version, +		      Length, Fragment).  bulk_cipher_algorithm(null) ->      ?NULL; diff --git a/lib/ssl/src/ssl_connection.erl b/lib/ssl/src/ssl_connection.erl index b6e4d5b433..6e7c8c5ddd 100644 --- a/lib/ssl/src/ssl_connection.erl +++ b/lib/ssl/src/ssl_connection.erl @@ -44,7 +44,7 @@  -export([send/2, recv/3, close/2, shutdown/2,  	 new_user/2, get_opts/2, set_opts/2, session_info/1,   	 peer_certificate/1, renegotiation/1, negotiated_protocol/1, prf/5, -	 connection_information/1 +	 connection_information/1, handle_common_event/5  	]).  %% General gen_statem state functions with extra callback argument  @@ -71,7 +71,8 @@  %%====================================================================	       %%--------------------------------------------------------------------  -spec connect(tls_connection | dtls_connection, -	      host(), inet:port_number(), port(), +	      host(), inet:port_number(),  +	      port() | {tuple(), port()}, %% TLS | DTLS    	      {#ssl_options{}, #socket_options{},  	       %% Tracker only needed on server side  	       undefined}, @@ -145,14 +146,24 @@ socket_control(Connection, Socket, Pid, Transport) ->  -spec socket_control(tls_connection | dtls_connection, port(), pid(), atom(), pid()| undefined) ->       {ok, #sslsocket{}} | {error, reason()}.    %%--------------------------------------------------------------------	     -socket_control(Connection, Socket, Pid, Transport, ListenTracker) -> +socket_control(Connection, Socket, Pid, Transport, udp_listner) -> +    %% dtls listner process must have the socket control +    {ok, dtls_socket:socket(Pid, Transport, Socket, Connection)}; + +socket_control(tls_connection = Connection, Socket, Pid, Transport, ListenTracker) ->      case Transport:controlling_process(Socket, Pid) of  	ok -> -	    {ok, ssl_socket:socket(Pid, Transport, Socket, Connection, ListenTracker)}; +	    {ok, tls_socket:socket(Pid, Transport, Socket, Connection, ListenTracker)}; +	{error, Reason}	-> +	    {error, Reason} +    end; +socket_control(dtls_connection = Connection, {_, Socket}, Pid, Transport, ListenTracker) -> +    case Transport:controlling_process(Socket, Pid) of +	ok -> +	    {ok, tls_socket:socket(Pid, Transport, Socket, Connection, ListenTracker)};  	{error, Reason}	->  	    {error, Reason}      end. -  %%--------------------------------------------------------------------  -spec send(pid(), iodata()) -> ok | {error, reason()}.  %% @@ -461,7 +472,7 @@ certify(internal, #certificate{asn1_certificates = []},  	#state{role = server, negotiated_version = Version,  	       ssl_options = #ssl_options{verify = verify_peer,  					  fail_if_no_peer_cert = true}} = -	    State, _Connection) -> +	    State, _) ->      Alert =  ?ALERT_REC(?FATAL,?HANDSHAKE_FAILURE),      handle_own_alert(Alert, Version, certify, State); @@ -478,7 +489,7 @@ certify(internal, #certificate{},  	#state{role = server,  	       negotiated_version = Version,  	       ssl_options = #ssl_options{verify = verify_none}} = -	    State, _Connection) -> +	    State, _) ->      Alert =  ?ALERT_REC(?FATAL,?UNEXPECTED_MESSAGE, unrequested_certificate),      handle_own_alert(Alert, Version, certify, State); @@ -788,7 +799,7 @@ connection(Type, Msg, State, Connection) ->  downgrade(internal, #alert{description = ?CLOSE_NOTIFY},  	  #state{transport_cb = Transport, socket = Socket,  		 downgrade = {Pid, From}} = State, _) -> -    ssl_socket:setopts(Transport, Socket, [{active, false}, {packet, 0}, {mode, binary}]), +    tls_socket:setopts(Transport, Socket, [{active, false}, {packet, 0}, {mode, binary}]),      Transport:controlling_process(Socket, Pid),      gen_statem:reply(From, {ok, Socket}),      {stop, normal, State}; @@ -819,7 +830,7 @@ handle_common_event(internal, {handshake, {Handshake, Raw}}, StateName,      %% a client_hello, which needs to be determined by the connection callback.      %% In other cases this is a noop      State = handle_sni_extension(PossibleSNI, State0), -    HsHist = ssl_handshake:update_handshake_history(Hs0, Raw, V2HComp), +    HsHist = ssl_handshake:update_handshake_history(Hs0, iolist_to_binary(Raw), V2HComp),      {next_state, StateName, State#state{tls_handshake_history = HsHist},        [{next_event, internal, Handshake}]};  handle_common_event(internal, {protocol_record, TLSorDTLSRecord}, StateName, State, Connection) ->  @@ -864,13 +875,13 @@ handle_call({shutdown, How0}, From, _,  	    #state{transport_cb = Transport,  		   negotiated_version = Version,  		   connection_states = ConnectionStates, -		   socket = Socket}, _) -> +		   socket = Socket}, Connection) ->      case How0 of  	How when How == write; How == both ->	      	    Alert = ?ALERT_REC(?WARNING, ?CLOSE_NOTIFY),  	    {BinMsg, _} = -		ssl_alert:encode(Alert, Version, ConnectionStates), -	    Transport:send(Socket, BinMsg); +		Connection:encode_alert(Alert, Version, ConnectionStates), +	    Connection:send(Transport, Socket, BinMsg);  	_ ->  	    ok      end, @@ -1025,8 +1036,8 @@ terminate(Reason, connection, #state{negotiated_version = Version,  				     transport_cb = Transport, socket = Socket  				    } = State) ->      handle_trusted_certs_db(State), -    {BinAlert, ConnectionStates} = terminate_alert(Reason, Version, ConnectionStates0), -    Transport:send(Socket, BinAlert), +    {BinAlert, ConnectionStates} = terminate_alert(Reason, Version, ConnectionStates0, Connection), +    Connection:send(Transport, Socket, BinAlert),      Connection:close(Reason, Socket, Transport, ConnectionStates, Check);  terminate(Reason, _StateName, #state{transport_cb = Transport, protocol_cb = Connection, @@ -1079,8 +1090,8 @@ write_application_data(Data0, From,  	    Connection:renegotiate(State#state{renegotiation = {true, internal}},   			[{next_event, {call, From}, {application_data, Data0}}]);  	false -> -	    {Msgs, ConnectionStates} = ssl_record:encode_data(Data, Version, ConnectionStates0), -	    Result = Transport:send(Socket, Msgs), +	    {Msgs, ConnectionStates} = Connection:encode_data(Data, Version, ConnectionStates0), +	    Result = Connection:send(Transport, Socket, Msgs),  	        ssl_connection:hibernate_after(connection, State#state{connection_states = ConnectionStates},   					       [{reply, From, Result}])      end. @@ -1913,7 +1924,7 @@ get_socket_opts(Transport, Socket, [active | Tags], SockOpts, Acc) ->      get_socket_opts(Transport, Socket, Tags, SockOpts,   		    [{active, SockOpts#socket_options.active} | Acc]);  get_socket_opts(Transport, Socket, [Tag | Tags], SockOpts, Acc) -> -    try ssl_socket:getopts(Transport, Socket, [Tag]) of +    try tls_socket:getopts(Transport, Socket, [Tag]) of  	{ok, [Opt]} ->  	    get_socket_opts(Transport, Socket, Tags, SockOpts, [Opt | Acc]);  	{error, Error} -> @@ -1929,7 +1940,7 @@ set_socket_opts(_,_, [], SockOpts, []) ->      {ok, SockOpts};  set_socket_opts(Transport, Socket, [], SockOpts, Other) ->      %% Set non emulated options  -    try ssl_socket:setopts(Transport, Socket, Other) of +    try tls_socket:setopts(Transport, Socket, Other) of  	ok ->  	    {ok, SockOpts};  	{error, InetError} -> @@ -1995,17 +2006,17 @@ hibernate_after(connection = StateName,  hibernate_after(StateName, State, Actions) ->      {next_state, StateName, State, Actions}. -terminate_alert(normal, Version, ConnectionStates)  -> -    ssl_alert:encode(?ALERT_REC(?WARNING, ?CLOSE_NOTIFY), +terminate_alert(normal, Version, ConnectionStates, Connection)  -> +    Connection:encode_alert(?ALERT_REC(?WARNING, ?CLOSE_NOTIFY),  		     Version, ConnectionStates); -terminate_alert({Reason, _}, Version, ConnectionStates) when Reason == close; -							     Reason == shutdown -> -    ssl_alert:encode(?ALERT_REC(?WARNING, ?CLOSE_NOTIFY), +terminate_alert({Reason, _}, Version, ConnectionStates, Connection) when Reason == close; +									 Reason == shutdown -> +    Connection:encode_alert(?ALERT_REC(?WARNING, ?CLOSE_NOTIFY),  		     Version, ConnectionStates); -terminate_alert(_, Version, ConnectionStates) -> -    {BinAlert, _} = ssl_alert:encode(?ALERT_REC(?FATAL, ?INTERNAL_ERROR), -				 Version, ConnectionStates), +terminate_alert(_, Version, ConnectionStates, Connection) -> +    {BinAlert, _} = Connection:encode_alert(?ALERT_REC(?FATAL, ?INTERNAL_ERROR), +					    Version, ConnectionStates),      BinAlert.  handle_trusted_certs_db(#state{ssl_options =  @@ -2285,7 +2296,7 @@ format_reply(_, _,#socket_options{active = false, mode = Mode, packet = Packet,      {ok, do_format_reply(Mode, Packet, Header, Data)};  format_reply(Transport, Socket, #socket_options{active = _, mode = Mode, packet = Packet,  						header = Header}, Data, Tracker, Connection) -> -    {ssl, ssl_socket:socket(self(), Transport, Socket, Connection, Tracker),  +    {ssl, tls_socket:socket(self(), Transport, Socket, Connection, Tracker),        do_format_reply(Mode, Packet, Header, Data)}.  deliver_packet_error(Transport, Socket, SO= #socket_options{active = Active}, Data, Pid, From, Tracker, Connection) -> @@ -2294,7 +2305,7 @@ deliver_packet_error(Transport, Socket, SO= #socket_options{active = Active}, Da  format_packet_error(_, _,#socket_options{active = false, mode = Mode}, Data, _, _) ->      {error, {invalid_packet, do_format_reply(Mode, raw, 0, Data)}};  format_packet_error(Transport, Socket, #socket_options{active = _, mode = Mode}, Data, Tracker, Connection) -> -    {ssl_error, ssl_socket:socket(self(), Transport, Socket, Connection, Tracker),  +    {ssl_error, tls_socket:socket(self(), Transport, Socket, Connection, Tracker),        {invalid_packet, do_format_reply(Mode, raw, 0, Data)}}.  do_format_reply(binary, _, N, Data) when N > 0 ->  % Header mode @@ -2349,11 +2360,11 @@ alert_user(Transport, Tracker, Socket, Active, Pid, From, Alert, Role, Connectio      case ssl_alert:reason_code(Alert, Role) of  	closed ->  	    send_or_reply(Active, Pid, From, -			  {ssl_closed, ssl_socket:socket(self(),  +			  {ssl_closed, tls_socket:socket(self(),   							 Transport, Socket, Connection, Tracker)});  	ReasonCode ->  	    send_or_reply(Active, Pid, From, -			  {ssl_error, ssl_socket:socket(self(),  +			  {ssl_error, tls_socket:socket(self(),   							Transport, Socket, Connection, Tracker), ReasonCode})      end. @@ -2366,12 +2377,13 @@ log_alert(false, _, _) ->  handle_own_alert(Alert, Version, StateName,   		 #state{transport_cb = Transport,  			socket = Socket, +			protocol_cb = Connection,  			connection_states = ConnectionStates,  			ssl_options = SslOpts} = State) ->      try %% Try to tell the other side  	{BinMsg, _} = -	ssl_alert:encode(Alert, Version, ConnectionStates), -	Transport:send(Socket, BinMsg) +	Connection:encode_alert(Alert, Version, ConnectionStates), +	Connection:send(Transport, Socket, BinMsg)      catch _:_ ->  %% Can crash if we are in a uninitialized state  	    ignore      end, diff --git a/lib/ssl/src/ssl_connection.hrl b/lib/ssl/src/ssl_connection.hrl index fca3e11894..2027652a7f 100644 --- a/lib/ssl/src/ssl_connection.hrl +++ b/lib/ssl/src/ssl_connection.hrl @@ -43,7 +43,7 @@  	  error_tag             :: atom(),   % ex tcp_error            host                  :: string() | inet:ip_address(),            port                  :: integer(), -          socket                :: port(), +          socket                :: port() | tuple(), %% TODO: dtls socket            ssl_options           :: #ssl_options{},            socket_options        :: #socket_options{},            connection_states     :: ssl_record:connection_states() | secret_printout(), @@ -81,17 +81,18 @@  	  allow_renegotiate = true                    ::boolean(),            expecting_next_protocol_negotiation = false ::boolean(),  	  expecting_finished =                  false ::boolean(), -          negotiated_protocol = undefined             :: undefined | binary(), +          next_protocol = undefined                   :: undefined | binary(), +	  negotiated_protocol,  	  tracker              :: pid() | 'undefined', %% Tracker process for listen socket  	  sni_hostname = undefined,  	  downgrade, -	  flight_buffer = []   :: list()  %% Buffer of TLS/DTLS records, used during the TLS handshake -				          %% to when possible pack more than on TLS record into the  -                                          %% underlaying packet format. Introduced by DTLS - RFC 4347. -				          %% The mecahnism is also usefull in TLS although we do not -				          %% need to worry about packet loss in TLS. +	  flight_buffer = []   :: list() | map(),  %% Buffer of TLS/DTLS records, used during the TLS handshake +				   %% to when possible pack more than on TLS record into the  +				   %% underlaying packet format. Introduced by DTLS - RFC 4347. +				   %% The mecahnism is also usefull in TLS although we do not +				   %% need to worry about packet loss in TLS. In DTLS we need to track DTLS handshake seqnr +	 flight_state = reliable  %% reliable | {retransmit, integer()}| {waiting, ref(), integer()} - last two is used in DTLS over udp.     	 }). -  -define(DEFAULT_DIFFIE_HELLMAN_PARAMS,  	#'DHParameter'{prime = ?DEFAULT_DIFFIE_HELLMAN_PRIME,  		       base = ?DEFAULT_DIFFIE_HELLMAN_GENERATOR}). diff --git a/lib/ssl/src/ssl_dist_sup.erl b/lib/ssl/src/ssl_dist_sup.erl index a6eb1be1f6..d47cd76bf5 100644 --- a/lib/ssl/src/ssl_dist_sup.erl +++ b/lib/ssl/src/ssl_dist_sup.erl @@ -85,10 +85,10 @@ proxy_server_child_spec() ->      {Name, StartFunc, Restart, Shutdown, Type, Modules}.  listen_options_tracker_child_spec() -> -    Name = ssl_socket_dist,   +    Name = tls_socket_dist,        StartFunc = {ssl_listen_tracker_sup, start_link_dist, []},      Restart = permanent,       Shutdown = 4000, -    Modules = [ssl_socket], +    Modules = [tls_socket],      Type = supervisor,      {Name, StartFunc, Restart, Shutdown, Type, Modules}. diff --git a/lib/ssl/src/ssl_internal.hrl b/lib/ssl/src/ssl_internal.hrl index 487d1fa096..cbfcaa46a0 100644 --- a/lib/ssl/src/ssl_internal.hrl +++ b/lib/ssl/src/ssl_internal.hrl @@ -156,7 +156,8 @@  -record(config, {ssl,               %% SSL parameters  		 inet_user,         %% User set inet options -		 emulated,          %% Emulated option list or "inherit_tracker" pid  +		 emulated,          %% Emulated option list or "inherit_tracker" pid +		 udp_handler,  		 inet_ssl,          %% inet options for internal ssl socket  		 transport_info,                 %% Callback info  		 connection_cb diff --git a/lib/ssl/src/ssl_listen_tracker_sup.erl b/lib/ssl/src/ssl_listen_tracker_sup.erl index 7f685a2ead..f7e97bcb76 100644 --- a/lib/ssl/src/ssl_listen_tracker_sup.erl +++ b/lib/ssl/src/ssl_listen_tracker_sup.erl @@ -57,10 +57,10 @@ init(_O) ->      MaxT = 3600,      Name = undefined, % As simple_one_for_one is used. -    StartFunc = {ssl_socket, start_link, []}, +    StartFunc = {tls_socket, start_link, []},      Restart = temporary, % E.g. should not be restarted      Shutdown = 4000, -    Modules = [ssl_socket], +    Modules = [tls_socket],      Type = worker,      ChildSpec = {Name, StartFunc, Restart, Shutdown, Type, Modules}, diff --git a/lib/ssl/src/ssl_record.erl b/lib/ssl/src/ssl_record.erl index 71cd0279f3..b10069c3cb 100644 --- a/lib/ssl/src/ssl_record.erl +++ b/lib/ssl/src/ssl_record.erl @@ -41,10 +41,6 @@  	 set_server_verify_data/3,  	 empty_connection_state/2, initial_connection_state/2, record_protocol_role/1]). -%% Encoding records --export([encode_handshake/3, encode_alert_record/3, -	 encode_change_cipher_spec/2, encode_data/3]). -  %% Compression  -export([compress/3, uncompress/3, compressions/0]). @@ -52,6 +48,9 @@  -export([cipher/4, decipher/4, is_correct_mac/2,  	 cipher_aead/4, decipher_aead/4]). +%% Encoding +-export([encode_plain_text/4]). +  -export_type([ssl_version/0, ssl_atom_version/0, connection_states/0, connection_state/0]).  -type ssl_version()       :: {integer(), integer()}. @@ -272,70 +271,26 @@ set_pending_cipher_state(#{pending_read := Read,        pending_read => Read#{cipher_state => ServerState},        pending_write => Write#{cipher_state => ClientState}}. - -%%-------------------------------------------------------------------- --spec encode_handshake(iolist(), ssl_version(), connection_states()) -> -			      {iolist(), connection_states()}. -%% -%% Description: Encodes a handshake message to send on the ssl-socket. -%%-------------------------------------------------------------------- -encode_handshake(Frag, Version,  -		 #{current_write := -		       #{beast_mitigation := BeastMitigation, -			  security_parameters := -			     #security_parameters{bulk_cipher_algorithm = BCA}}} =  -		     ConnectionStates) -  when is_list(Frag) -> -    case iolist_size(Frag) of -	N  when N > ?MAX_PLAIN_TEXT_LENGTH -> -	    Data = split_bin(iolist_to_binary(Frag), ?MAX_PLAIN_TEXT_LENGTH, Version, BCA, BeastMitigation), -	    encode_iolist(?HANDSHAKE, Data, Version, ConnectionStates); -	_  -> -	    encode_plain_text(?HANDSHAKE, Version, Frag, ConnectionStates) -    end; -%% TODO: this is a workarround for DTLS -%% -%% DTLS need to select the connection write state based on Epoch it wants to -%% send this fragment in. That Epoch does not nessarily has to be the same -%% as the current_write epoch. -%% The right solution might be to pass the WriteState instead of the ConnectionStates, -%% however, this will require substantion API changes. -encode_handshake(Frag, Version, ConnectionStates) -> -    encode_plain_text(?HANDSHAKE, Version, Frag, ConnectionStates). - -%%-------------------------------------------------------------------- --spec encode_alert_record(#alert{}, ssl_version(), connection_states()) -> -				 {iolist(), connection_states()}. -%% -%% Description: Encodes an alert message to send on the ssl-socket. -%%-------------------------------------------------------------------- -encode_alert_record(#alert{level = Level, description = Description}, -                    Version, ConnectionStates) -> -    encode_plain_text(?ALERT, Version, <<?BYTE(Level), ?BYTE(Description)>>, -		      ConnectionStates). - -%%-------------------------------------------------------------------- --spec encode_change_cipher_spec(ssl_version(), connection_states()) -> -				       {iolist(), connection_states()}. -%% -%% Description: Encodes a change_cipher_spec-message to send on the ssl socket. -%%-------------------------------------------------------------------- -encode_change_cipher_spec(Version, ConnectionStates) -> -    encode_plain_text(?CHANGE_CIPHER_SPEC, Version, <<1:8>>, ConnectionStates). - -%%-------------------------------------------------------------------- --spec encode_data(binary(), ssl_version(), connection_states()) -> -			 {iolist(), connection_states()}. -%% -%% Description: Encodes data to send on the ssl-socket. -%%-------------------------------------------------------------------- -encode_data(Frag, Version, -	    #{current_write := #{beast_mitigation := BeastMitigation, -				 security_parameters := -				     #security_parameters{bulk_cipher_algorithm = BCA}}} = -		ConnectionStates) -> -    Data = split_bin(Frag, ?MAX_PLAIN_TEXT_LENGTH, Version, BCA, BeastMitigation), -    encode_iolist(?APPLICATION_DATA, Data, Version, ConnectionStates). +encode_plain_text(Type, Version, Data, #{compression_state := CompS0, +					 security_parameters := +					     #security_parameters{ +						cipher_type = ?AEAD, +						compression_algorithm = CompAlg} +					} = WriteState0) -> +    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), +    WriteState1 = WriteState0#{compression_state => CompS1}, +    AAD = ssl_cipher:calc_aad(Type, Version, WriteState1), +    ssl_record:cipher_aead(Version, Comp, WriteState1, AAD); +encode_plain_text(Type, Version, Data, #{compression_state := CompS0, +					 security_parameters := +					     #security_parameters{compression_algorithm = CompAlg} +					}= WriteState0) -> +    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), +    WriteState1 = WriteState0#{compression_state => CompS1}, +    MacHash = ssl_cipher:calc_mac_hash(Type, Version, Comp, WriteState1), +    ssl_record:cipher(Version, Comp, WriteState1, MacHash); +encode_plain_text(_,_,_,CS) -> +    exit({cs, CS}).  uncompress(?NULL, Data, CS) ->      {Data, CS}. @@ -451,11 +406,6 @@ random() ->      Random_28_bytes = ssl_cipher:random_bytes(28),      <<?UINT32(Secs_since_1970), Random_28_bytes/binary>>. -%% dtls_next_epoch(#connection_state{epoch = undefined}) -> %% SSL/TLS -%%     undefined; -%% dtls_next_epoch(#connection_state{epoch = Epoch}) -> %% DTLS -%%     Epoch + 1. -  is_correct_mac(Mac, Mac) ->      true;  is_correct_mac(_M,_H) -> @@ -484,47 +434,3 @@ initial_security_params(ConnectionEnd) ->  				     compression_algorithm = ?NULL},      ssl_cipher:security_parameters(?TLS_NULL_WITH_NULL_NULL, SecParams). - -encode_plain_text(Type, Version, Data, ConnectionStates) -> -    RecordCB = protocol_module(Version), -    RecordCB:encode_plain_text(Type, Version, Data, ConnectionStates). - -encode_iolist(Type, Data, Version, ConnectionStates0) -> -    RecordCB = protocol_module(Version), -    {ConnectionStates, EncodedMsg} = -        lists:foldl(fun(Text, {CS0, Encoded}) -> -			    {Enc, CS1} = -				RecordCB:encode_plain_text(Type, Version, Text, CS0), -			    {CS1, [Enc | Encoded]} -		    end, {ConnectionStates0, []}, Data), -    {lists:reverse(EncodedMsg), ConnectionStates}. - -%% 1/n-1 splitting countermeasure Rizzo/Duong-Beast, RC4 chiphers are -%% not vulnerable to this attack. -split_bin(<<FirstByte:8, Rest/binary>>, ChunkSize, Version, BCA, one_n_minus_one) when -      BCA =/= ?RC4 andalso ({3, 1} == Version orelse -			    {3, 0} == Version) -> -    do_split_bin(Rest, ChunkSize, [[FirstByte]]); -%% 0/n splitting countermeasure for clients that are incompatible with 1/n-1 -%% splitting. -split_bin(Bin, ChunkSize, Version, BCA, zero_n) when -      BCA =/= ?RC4 andalso ({3, 1} == Version orelse -			    {3, 0} == Version) -> -    do_split_bin(Bin, ChunkSize, [[<<>>]]); -split_bin(Bin, ChunkSize, _, _, _) -> -    do_split_bin(Bin, ChunkSize, []). - -do_split_bin(<<>>, _, Acc) -> -    lists:reverse(Acc); -do_split_bin(Bin, ChunkSize, Acc) -> -    case Bin of -        <<Chunk:ChunkSize/binary, Rest/binary>> -> -            do_split_bin(Rest, ChunkSize, [Chunk | Acc]); -        _ -> -            lists:reverse(Acc, [Bin]) -    end. - -protocol_module({3, _}) -> -    tls_record; -protocol_module({254, _}) -> -    dtls_record. diff --git a/lib/ssl/src/ssl_sup.erl b/lib/ssl/src/ssl_sup.erl index ba20f65f44..8245801139 100644 --- a/lib/ssl/src/ssl_sup.erl +++ b/lib/ssl/src/ssl_sup.erl @@ -46,14 +46,17 @@ start_link() ->  init([]) ->          SessionCertManager = session_and_cert_manager_child_spec(),      TLSConnetionManager = tls_connection_manager_child_spec(), -    %% Not supported yet -    %%DTLSConnetionManager = dtls_connection_manager_child_spec(), -    %% Handles emulated options so that they inherited by the accept socket, even when setopts is performed on  -    %% the listen socket +    %% Handles emulated options so that they inherited by the accept +    %% socket, even when setopts is performed on the listen socket      ListenOptionsTracker = listen_options_tracker_child_spec(),  +     +    DTLSConnetionManager = dtls_connection_manager_child_spec(), +    DTLSUdpListeners = dtls_udp_listeners_spec(), +      {ok, {{one_for_all, 10, 3600}, [SessionCertManager, TLSConnetionManager,  -				    %%DTLSConnetionManager,  -				    ListenOptionsTracker]}}. +				    ListenOptionsTracker, +				    DTLSConnetionManager, DTLSUdpListeners +				   ]}}.  manager_opts() -> @@ -94,24 +97,32 @@ tls_connection_manager_child_spec() ->      Type = supervisor,      {Name, StartFunc, Restart, Shutdown, Type, Modules}. -%% dtls_connection_manager_child_spec() -> -%%     Name = dtls_connection, -%%     StartFunc = {dtls_connection_sup, start_link, []}, -%%     Restart = permanent, -%%     Shutdown = 4000, -%%     Modules = [dtls_connection, ssl_connection], -%%     Type = supervisor, -%%     {Name, StartFunc, Restart, Shutdown, Type, Modules}. +dtls_connection_manager_child_spec() -> +    Name = dtls_connection, +    StartFunc = {dtls_connection_sup, start_link, []}, +    Restart = permanent, +    Shutdown = 4000, +    Modules = [dtls_connection_sup], +    Type = supervisor, +    {Name, StartFunc, Restart, Shutdown, Type, Modules}.  listen_options_tracker_child_spec() -> -    Name = ssl_socket,   +    Name = tls_socket,        StartFunc = {ssl_listen_tracker_sup, start_link, []},      Restart = permanent,       Shutdown = 4000, -    Modules = [ssl_socket], +    Modules = [tls_socket],      Type = supervisor,      {Name, StartFunc, Restart, Shutdown, Type, Modules}. +dtls_udp_listeners_spec() -> +    Name = dtls_udp_listener,   +    StartFunc = {dtls_udp_sup, start_link, []}, +    Restart = permanent,  +    Shutdown = 4000, +    Modules = [], +    Type = supervisor, +    {Name, StartFunc, Restart, Shutdown, Type, Modules}.  session_cb_init_args() ->      case application:get_env(ssl, session_cb_init_args) of diff --git a/lib/ssl/src/tls_connection.erl b/lib/ssl/src/tls_connection.erl index 932bb139c1..32991d3079 100644 --- a/lib/ssl/src/tls_connection.erl +++ b/lib/ssl/src/tls_connection.erl @@ -45,6 +45,8 @@  %% Setup  -export([start_fsm/8, start_link/7, init/1]). +-export([encode_data/3, encode_alert/3]). +  %% State transition handling	   -export([next_record/1, next_event/3]). @@ -57,7 +59,7 @@  -export([send_alert/2, close/5]).  %% Data handling --export([passive_receive/2, next_record_if_active/1, handle_common_event/4]). +-export([passive_receive/2, next_record_if_active/1, handle_common_event/4, send/3]).  %% gen_statem state functions  -export([init/3, error/3, downgrade/3, %% Initiation and take down states @@ -114,7 +116,7 @@ queue_handshake(Handshake, #state{negotiated_version = Version,  send_handshake_flight(#state{socket = Socket,  			     transport_cb = Transport,  			     flight_buffer = Flight} = State0) -> -    Transport:send(Socket, Flight), +    send(Transport, Socket, Flight),      State0#state{flight_buffer = []}.  queue_change_cipher(Msg, #state{negotiated_version = Version, @@ -130,8 +132,8 @@ send_alert(Alert, #state{negotiated_version = Version,  			 transport_cb = Transport,  			 connection_states = ConnectionStates0} = State0) ->      {BinMsg, ConnectionStates} = -	ssl_alert:encode(Alert, Version, ConnectionStates0), -    Transport:send(Socket, BinMsg), +	encode_alert(Alert, Version, ConnectionStates0), +    send(Transport, Socket, BinMsg),      State0#state{connection_states = ConnectionStates}.  reinit_handshake_data(State) -> @@ -149,6 +151,18 @@ select_sni_extension(#client_hello{extensions = HelloExtensions}) ->  select_sni_extension(_) ->      undefined. +encode_data(Data, Version, ConnectionStates0)-> +    tls_record:encode_data(Data, Version, ConnectionStates0). + +%%-------------------------------------------------------------------- +-spec encode_alert(#alert{}, ssl_record:ssl_version(), ssl_record:connection_states()) ->  +		    {iolist(), ssl_record:connection_states()}. +%% +%% Description: Encodes an alert +%%-------------------------------------------------------------------- +encode_alert(#alert{} = Alert, Version, ConnectionStates) -> +    tls_record:encode_alert_record(Alert, Version, ConnectionStates). +  %%====================================================================  %% tls_connection_sup API  %%==================================================================== @@ -205,7 +219,7 @@ init({call, From}, {start, Timeout},      Handshake0 = ssl_handshake:init_handshake_history(),      {BinMsg, ConnectionStates, Handshake} =          encode_handshake(Hello,  HelloVersion, ConnectionStates0, Handshake0, V2HComp), -    Transport:send(Socket, BinMsg), +    send(Transport, Socket, BinMsg),      State1 = State0#state{connection_states = ConnectionStates,  			  negotiated_version = Version, %% Requested version  			  session = @@ -450,6 +464,9 @@ handle_common_event(internal, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, Sta  handle_common_event(internal, #ssl_tls{type = _Unknown}, StateName, State) ->      {next_state, StateName, State}. +send(Transport, Socket, Data) -> +   tls_socket:send(Transport, Socket, Data). +  %%--------------------------------------------------------------------  %% gen_statem callbacks  %%-------------------------------------------------------------------- @@ -476,11 +493,11 @@ encode_handshake(Handshake, Version, ConnectionStates0, Hist0, V2HComp) ->      Frag = tls_handshake:encode_handshake(Handshake, Version),      Hist = ssl_handshake:update_handshake_history(Hist0, Frag, V2HComp),      {Encoded, ConnectionStates} = -        ssl_record:encode_handshake(Frag, Version, ConnectionStates0), +        tls_record:encode_handshake(Frag, Version, ConnectionStates0),      {Encoded, ConnectionStates, Hist}.  encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> -    ssl_record:encode_change_cipher_spec(Version, ConnectionStates). +    tls_record:encode_change_cipher_spec(Version, ConnectionStates).  decode_alerts(Bin) ->      ssl_alert:decode(Bin). @@ -553,7 +570,7 @@ next_record(#state{protocol_buffers =  next_record(#state{protocol_buffers = #protocol_buffers{tls_packets = [], tls_cipher_texts = []},  		   socket = Socket,  		   transport_cb = Transport} = State) -> -    ssl_socket:setopts(Transport, Socket, [{active,once}]), +    tls_socket:setopts(Transport, Socket, [{active,once}]),      {no_record, State};  next_record(State) ->      {no_record, State}. @@ -622,8 +639,8 @@ renegotiate(#state{role = server,      Frag = tls_handshake:encode_handshake(HelloRequest, Version),      Hs0 = ssl_handshake:init_handshake_history(),      {BinMsg, ConnectionStates} =  -	ssl_record:encode_handshake(Frag, Version, ConnectionStates0), -    Transport:send(Socket, BinMsg), +	tls_record:encode_handshake(Frag, Version, ConnectionStates0), +    send(Transport, Socket, BinMsg),      State1 = State0#state{connection_states =   			     ConnectionStates,  			 tls_handshake_history = Hs0}, @@ -642,7 +659,7 @@ handle_alerts([Alert | Alerts], {next_state, StateName, State, _Actions}) ->  %% User closes or recursive call!  close({close, Timeout}, Socket, Transport = gen_tcp, _,_) -> -    ssl_socket:setopts(Transport, Socket, [{active, false}]), +    tls_socket:setopts(Transport, Socket, [{active, false}]),      Transport:shutdown(Socket, write),      _ = Transport:recv(Socket, 0, Timeout),      ok; @@ -684,7 +701,7 @@ gen_handshake(GenConnection, StateName, Type, Event,  	    Result      catch   	_:_ -> -	    ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,  + 	    ssl_connection:handle_own_alert(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE,   						       malformed_handshake_data),  					    Version, StateName, State)        end. diff --git a/lib/ssl/src/tls_record.erl b/lib/ssl/src/tls_record.erl index 5331dd1303..aa8a2aa334 100644 --- a/lib/ssl/src/tls_record.erl +++ b/lib/ssl/src/tls_record.erl @@ -34,10 +34,9 @@  %% Handling of incoming data  -export([get_tls_records/2, init_connection_states/2]). -%% Decoding --export([decode_cipher_text/3]). - -%% Encoding +%% Encoding TLS records +-export([encode_handshake/3, encode_alert_record/3, +	 encode_change_cipher_spec/2, encode_data/3]).  -export([encode_plain_text/4]).  %% Protocol version handling @@ -46,6 +45,9 @@  	 is_higher/2, supported_protocol_versions/0,  	 is_acceptable_version/1, is_acceptable_version/2]). +%% Decoding +-export([decode_cipher_text/3]). +  -export_type([tls_version/0, tls_atom_version/0]).  -type tls_version()       :: ssl_record:ssl_version(). @@ -85,152 +87,61 @@ get_tls_records(Data, <<>>) ->  get_tls_records(Data, Buffer) ->      get_tls_records_aux(list_to_binary([Buffer, Data]), []). -get_tls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer), -		     ?UINT16(Length), Data:Length/binary, Rest/binary>>,  -		    Acc) -> -    get_tls_records_aux(Rest, [#ssl_tls{type = ?APPLICATION_DATA, -					version = {MajVer, MinVer}, -					fragment = Data} | Acc]); -get_tls_records_aux(<<?BYTE(?HANDSHAKE),?BYTE(MajVer),?BYTE(MinVer), -		     ?UINT16(Length),  -		     Data:Length/binary, Rest/binary>>, Acc) -> -    get_tls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE, -					version = {MajVer, MinVer}, -					fragment = Data} | Acc]); -get_tls_records_aux(<<?BYTE(?ALERT),?BYTE(MajVer),?BYTE(MinVer), -		     ?UINT16(Length), Data:Length/binary,  -		     Rest/binary>>, Acc) -> -    get_tls_records_aux(Rest, [#ssl_tls{type = ?ALERT, -					version = {MajVer, MinVer}, -					fragment = Data} | Acc]); -get_tls_records_aux(<<?BYTE(?CHANGE_CIPHER_SPEC),?BYTE(MajVer),?BYTE(MinVer), -		     ?UINT16(Length), Data:Length/binary, Rest/binary>>,  -		    Acc) -> -    get_tls_records_aux(Rest, [#ssl_tls{type = ?CHANGE_CIPHER_SPEC, -					version = {MajVer, MinVer}, -					fragment = Data} | Acc]); -%% Matches an ssl v2 client hello message. -%% The server must be able to receive such messages, from clients that -%% are willing to use ssl v3 or higher, but have ssl v2 compatibility. -get_tls_records_aux(<<1:1, Length0:15, Data0:Length0/binary, Rest/binary>>, -		    Acc) -> -    case Data0 of -	<<?BYTE(?CLIENT_HELLO), ?BYTE(MajVer), ?BYTE(MinVer), _/binary>> -> -	    Length = Length0-1, -	    <<?BYTE(_), Data1:Length/binary>> = Data0, -	    Data = <<?BYTE(?CLIENT_HELLO), ?UINT24(Length), Data1/binary>>, -	    get_tls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE, -						version = {MajVer, MinVer}, -						fragment = Data} | Acc]); -	_ -> -	    ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE) -	     -    end; - -get_tls_records_aux(<<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer), -                     ?UINT16(Length), _/binary>>, -                    _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH -> -    ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); - -get_tls_records_aux(<<1:1, Length0:15, _/binary>>,_Acc)  -  when Length0 > ?MAX_CIPHER_TEXT_LENGTH -> -    ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); +%%-------------------------------------------------------------------- +-spec encode_handshake(iolist(), tls_version(), ssl_record:connection_states()) -> +			      {iolist(), ssl_record:connection_states()}. +% +%% Description: Encodes a handshake message to send on the ssl-socket. +%%-------------------------------------------------------------------- +encode_handshake(Frag, Version,  +		 #{current_write := +		       #{beast_mitigation := BeastMitigation, +			  security_parameters := +			     #security_parameters{bulk_cipher_algorithm = BCA}}} =  +		     ConnectionStates) -> +    case iolist_size(Frag) of +	N  when N > ?MAX_PLAIN_TEXT_LENGTH -> +	    Data = split_bin(iolist_to_binary(Frag), ?MAX_PLAIN_TEXT_LENGTH, Version, BCA, BeastMitigation), +	    encode_iolist(?HANDSHAKE, Data, Version, ConnectionStates); +	_  -> +	    encode_plain_text(?HANDSHAKE, Version, Frag, ConnectionStates) +    end. -get_tls_records_aux(Data, Acc) -> -    case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of -	true -> -	    {lists:reverse(Acc), Data}; -	false -> -	    ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE) -	end. +%%-------------------------------------------------------------------- +-spec encode_alert_record(#alert{}, tls_version(), ssl_record:connection_states()) -> +				 {iolist(), ssl_record:connection_states()}. +%% +%% Description: Encodes an alert message to send on the ssl-socket. +%%-------------------------------------------------------------------- +encode_alert_record(#alert{level = Level, description = Description}, +                    Version, ConnectionStates) -> +    encode_plain_text(?ALERT, Version, <<?BYTE(Level), ?BYTE(Description)>>, +		      ConnectionStates). -encode_plain_text(Type, Version, Data, -		  #{current_write := -			#{sequence_number := Seq, -			  compression_state := CompS0, -			  security_parameters := -			      #security_parameters{ -				 cipher_type = ?AEAD, -				 compression_algorithm = CompAlg} -			 }= WriteState0} = ConnectionStates) -> -    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), -    WriteState1 = WriteState0#{compression_state => CompS1}, -    AAD = calc_aad(Type, Version, WriteState1), -    {CipherFragment, WriteState} = ssl_record:cipher_aead(Version, Comp, WriteState1, AAD), -    CipherText = encode_tls_cipher_text(Type, Version, CipherFragment), -    {CipherText, ConnectionStates#{current_write => WriteState#{sequence_number => Seq +1}}}; - -encode_plain_text(Type, Version, Data, -		  #{current_write := -			#{sequence_number := Seq, -			  compression_state := CompS0, -			  security_parameters := -			      #security_parameters{compression_algorithm = CompAlg} -			 }= WriteState0} = ConnectionStates) -> -    {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), -    WriteState1 = WriteState0#{compression_state => CompS1}, -    MacHash = calc_mac_hash(Type, Version, Comp, WriteState1), -    {CipherFragment, WriteState} = ssl_record:cipher(Version, Comp, WriteState1, MacHash), -    CipherText = encode_tls_cipher_text(Type, Version, CipherFragment), -    {CipherText, ConnectionStates#{current_write => WriteState#{sequence_number => Seq +1}}}; -encode_plain_text(_,_,_, CS) -> -    exit({cs, CS}). +%%-------------------------------------------------------------------- +-spec encode_change_cipher_spec(tls_version(), ssl_record:connection_states()) -> +				       {iolist(), ssl_record:connection_states()}. +%% +%% Description: Encodes a change_cipher_spec-message to send on the ssl socket. +%%-------------------------------------------------------------------- +encode_change_cipher_spec(Version, ConnectionStates) -> +    encode_plain_text(?CHANGE_CIPHER_SPEC, Version, ?byte(?CHANGE_CIPHER_SPEC_PROTO), ConnectionStates).  %%-------------------------------------------------------------------- --spec decode_cipher_text(#ssl_tls{}, ssl_record:connection_states(), boolean()) -> -				{#ssl_tls{}, ssl_record:connection_states()}| #alert{}. +-spec encode_data(binary(), tls_version(), ssl_record:connection_states()) -> +			 {iolist(), ssl_record:connection_states()}.  %% -%% Description: Decode cipher text +%% Description: Encodes data to send on the ssl-socket.  %%-------------------------------------------------------------------- -decode_cipher_text(#ssl_tls{type = Type, version = Version, -			    fragment = CipherFragment} = CipherText, -		   #{current_read := -			 #{compression_state := CompressionS0, -			   sequence_number := Seq, -			   security_parameters := -			       #security_parameters{ -				  cipher_type = ?AEAD, -				  compression_algorithm = CompAlg} -			  } = ReadState0} = ConnnectionStates0, _) -> -    AAD = calc_aad(Type, Version, ReadState0), -    case ssl_record:decipher_aead(Version, CipherFragment, ReadState0, AAD) of -	{PlainFragment, ReadState1} -> -	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, -							   PlainFragment, CompressionS0), -	    ConnnectionStates = ConnnectionStates0#{ -				  current_read => ReadState1#{sequence_number => Seq + 1, -							      compression_state => CompressionS1}}, -	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; -	#alert{} = Alert -> -	    Alert -    end; +encode_data(Frag, Version, +	    #{current_write := #{beast_mitigation := BeastMitigation, +				 security_parameters := +				     #security_parameters{bulk_cipher_algorithm = BCA}}} = +		ConnectionStates) -> +    Data = split_bin(Frag, ?MAX_PLAIN_TEXT_LENGTH, Version, BCA, BeastMitigation), +    encode_iolist(?APPLICATION_DATA, Data, Version, ConnectionStates). + -decode_cipher_text(#ssl_tls{type = Type, version = Version, -			    fragment = CipherFragment} = CipherText, -		   #{current_read := -			 #{compression_state := CompressionS0, -			   sequence_number := Seq, -			   security_parameters := -			       #security_parameters{compression_algorithm = CompAlg} -			  } = ReadState0} = ConnnectionStates0, PaddingCheck) -> -    case ssl_record:decipher(Version, CipherFragment, ReadState0, PaddingCheck) of -	{PlainFragment, Mac, ReadState1} -> -	    MacHash = calc_mac_hash(Type, Version, PlainFragment, ReadState1), -	    case ssl_record:is_correct_mac(Mac, MacHash) of -		true -> -		    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, -								   PlainFragment, CompressionS0), -		    ConnnectionStates = ConnnectionStates0#{ -					  current_read => ReadState1#{ -							    sequence_number => Seq + 1, -							    compression_state => CompressionS1}}, -		    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; -		false -> -			?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) -	    end; -	    #alert{} = Alert -> -	    Alert -    end.   %%--------------------------------------------------------------------  -spec protocol_version(tls_atom_version() | tls_version()) ->   			      tls_version() | tls_atom_version().		       @@ -401,6 +312,70 @@ initial_connection_state(ConnectionEnd, BeastMitigation) ->        server_verify_data => undefined       }. +get_tls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer), +		     ?UINT16(Length), Data:Length/binary, Rest/binary>>,  +		    Acc) -> +    get_tls_records_aux(Rest, [#ssl_tls{type = ?APPLICATION_DATA, +					version = {MajVer, MinVer}, +					fragment = Data} | Acc]); +get_tls_records_aux(<<?BYTE(?HANDSHAKE),?BYTE(MajVer),?BYTE(MinVer), +		     ?UINT16(Length),  +		     Data:Length/binary, Rest/binary>>, Acc) -> +    get_tls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE, +					version = {MajVer, MinVer}, +					fragment = Data} | Acc]); +get_tls_records_aux(<<?BYTE(?ALERT),?BYTE(MajVer),?BYTE(MinVer), +		     ?UINT16(Length), Data:Length/binary,  +		     Rest/binary>>, Acc) -> +    get_tls_records_aux(Rest, [#ssl_tls{type = ?ALERT, +					version = {MajVer, MinVer}, +					fragment = Data} | Acc]); +get_tls_records_aux(<<?BYTE(?CHANGE_CIPHER_SPEC),?BYTE(MajVer),?BYTE(MinVer), +		     ?UINT16(Length), Data:Length/binary, Rest/binary>>,  +		    Acc) -> +    get_tls_records_aux(Rest, [#ssl_tls{type = ?CHANGE_CIPHER_SPEC, +					version = {MajVer, MinVer}, +					fragment = Data} | Acc]); +%% Matches an ssl v2 client hello message. +%% The server must be able to receive such messages, from clients that +%% are willing to use ssl v3 or higher, but have ssl v2 compatibility. +get_tls_records_aux(<<1:1, Length0:15, Data0:Length0/binary, Rest/binary>>, +		    Acc) -> +    case Data0 of +	<<?BYTE(?CLIENT_HELLO), ?BYTE(MajVer), ?BYTE(MinVer), _/binary>> -> +	    Length = Length0-1, +	    <<?BYTE(_), Data1:Length/binary>> = Data0, +	    Data = <<?BYTE(?CLIENT_HELLO), ?UINT24(Length), Data1/binary>>, +	    get_tls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE, +						version = {MajVer, MinVer}, +						fragment = Data} | Acc]); +	_ -> +	    ?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE) +	     +    end; + +get_tls_records_aux(<<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer), +                     ?UINT16(Length), _/binary>>, +                    _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH -> +    ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); + +get_tls_records_aux(<<1:1, Length0:15, _/binary>>,_Acc)  +  when Length0 > ?MAX_CIPHER_TEXT_LENGTH -> +    ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); + +get_tls_records_aux(Data, Acc) -> +    case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of +	true -> +	    {lists:reverse(Acc), Data}; +	false -> +	    ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE) +	end. + +encode_plain_text(Type, Version, Data, #{current_write := Write0} = ConnectionStates) -> +    {CipherFragment, Write1} = ssl_record:encode_plain_text(Type, Version, Data, Write0), +    {CipherText, Write} = encode_tls_cipher_text(Type, Version, CipherFragment, Write1), +    {CipherText, ConnectionStates#{current_write => Write}}. +  lowest_list_protocol_version(Ver, []) ->      Ver;  lowest_list_protocol_version(Ver1,  [Ver2 | Rest]) -> @@ -411,20 +386,10 @@ highest_list_protocol_version(Ver, []) ->  highest_list_protocol_version(Ver1,  [Ver2 | Rest]) ->      highest_list_protocol_version(highest_protocol_version(Ver1, Ver2), Rest). -encode_tls_cipher_text(Type, {MajVer, MinVer}, Fragment) -> +encode_tls_cipher_text(Type, {MajVer, MinVer}, Fragment, #{sequence_number := Seq} = Write) ->      Length = erlang:iolist_size(Fragment), -    [<<?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer), ?UINT16(Length)>>, Fragment]. - - -mac_hash({_,_}, ?NULL, _MacSecret, _SeqNo, _Type, -	 _Length, _Fragment) -> -    <<>>; -mac_hash({3, 0}, MacAlg, MacSecret, SeqNo, Type, Length, Fragment) -> -    ssl_v3:mac_hash(MacAlg, MacSecret, SeqNo, Type, Length, Fragment); -mac_hash({3, N} = Version, MacAlg, MacSecret, SeqNo, Type, Length, Fragment)   -  when N =:= 1; N =:= 2; N =:= 3 -> -    tls_v1:mac_hash(MacAlg, MacSecret, SeqNo, Type, Version, -		      Length, Fragment). +    {[<<?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer), ?UINT16(Length)>>, Fragment], +     Write#{sequence_number => Seq +1}}.  highest_protocol_version() ->      highest_protocol_version(supported_protocol_versions()). @@ -432,21 +397,96 @@ highest_protocol_version() ->  lowest_protocol_version() ->      lowest_protocol_version(supported_protocol_versions()). -  sufficient_tlsv1_2_crypto_support() ->      CryptoSupport = crypto:supports(),      proplists:get_bool(sha256, proplists:get_value(hashs, CryptoSupport)). -calc_mac_hash(Type, Version, -	      PlainFragment, #{sequence_number := SeqNo, -			       mac_secret := MacSecret, -			       security_parameters:= -				   SecPars}) -> -    Length = erlang:iolist_size(PlainFragment), -    mac_hash(Version, SecPars#security_parameters.mac_algorithm, -	     MacSecret, SeqNo, Type, -	     Length, PlainFragment). - -calc_aad(Type, {MajVer, MinVer}, -	 #{sequence_number := SeqNo}) -> -    <<SeqNo:64/integer, ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. +encode_iolist(Type, Data, Version, ConnectionStates0) -> +    {ConnectionStates, EncodedMsg} = +        lists:foldl(fun(Text, {CS0, Encoded}) -> +			    {Enc, CS1} = +				encode_plain_text(Type, Version, Text, CS0), +			    {CS1, [Enc | Encoded]} +		    end, {ConnectionStates0, []}, Data), +    {lists:reverse(EncodedMsg), ConnectionStates}. + +%% 1/n-1 splitting countermeasure Rizzo/Duong-Beast, RC4 chiphers are +%% not vulnerable to this attack. +split_bin(<<FirstByte:8, Rest/binary>>, ChunkSize, Version, BCA, one_n_minus_one) when +      BCA =/= ?RC4 andalso ({3, 1} == Version orelse +			    {3, 0} == Version) -> +    do_split_bin(Rest, ChunkSize, [[FirstByte]]); +%% 0/n splitting countermeasure for clients that are incompatible with 1/n-1 +%% splitting. +split_bin(Bin, ChunkSize, Version, BCA, zero_n) when +      BCA =/= ?RC4 andalso ({3, 1} == Version orelse +			    {3, 0} == Version) -> +    do_split_bin(Bin, ChunkSize, [[<<>>]]); +split_bin(Bin, ChunkSize, _, _, _) -> +    do_split_bin(Bin, ChunkSize, []). + +do_split_bin(<<>>, _, Acc) -> +    lists:reverse(Acc); +do_split_bin(Bin, ChunkSize, Acc) -> +    case Bin of +        <<Chunk:ChunkSize/binary, Rest/binary>> -> +            do_split_bin(Rest, ChunkSize, [Chunk | Acc]); +        _ -> +            lists:reverse(Acc, [Bin]) +    end. + +%%-------------------------------------------------------------------- +-spec decode_cipher_text(#ssl_tls{}, ssl_record:connection_states(), boolean()) -> +				{#ssl_tls{}, ssl_record:connection_states()}| #alert{}. +%% +%% Description: Decode cipher text +%%-------------------------------------------------------------------- +decode_cipher_text(#ssl_tls{type = Type, version = Version, +			    fragment = CipherFragment} = CipherText, +		   #{current_read := +			 #{compression_state := CompressionS0, +			   sequence_number := Seq, +			   security_parameters := +			       #security_parameters{ +				  cipher_type = ?AEAD, +				  compression_algorithm = CompAlg} +			  } = ReadState0} = ConnnectionStates0, _) -> +    AAD = ssl_cipher:calc_aad(Type, Version, ReadState0), +    case ssl_record:decipher_aead(Version, CipherFragment, ReadState0, AAD) of +	{PlainFragment, ReadState1} -> +	    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, +							   PlainFragment, CompressionS0), +	    ConnnectionStates = ConnnectionStates0#{ +				  current_read => ReadState1#{sequence_number => Seq + 1, +							      compression_state => CompressionS1}}, +	    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; +	#alert{} = Alert -> +	    Alert +    end; + +decode_cipher_text(#ssl_tls{type = Type, version = Version, +			    fragment = CipherFragment} = CipherText, +		   #{current_read := +			 #{compression_state := CompressionS0, +			   sequence_number := Seq, +			   security_parameters := +			       #security_parameters{compression_algorithm = CompAlg} +			  } = ReadState0} = ConnnectionStates0, PaddingCheck) -> +    case ssl_record:decipher(Version, CipherFragment, ReadState0, PaddingCheck) of +	{PlainFragment, Mac, ReadState1} -> +	    MacHash = ssl_cipher:calc_mac_hash(Type, Version, PlainFragment, ReadState1), +	    case ssl_record:is_correct_mac(Mac, MacHash) of +		true -> +		    {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, +								   PlainFragment, CompressionS0), +		    ConnnectionStates = ConnnectionStates0#{ +					  current_read => ReadState1#{ +							    sequence_number => Seq + 1, +							    compression_state => CompressionS1}}, +		    {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; +		false -> +			?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) +	    end; +	    #alert{} = Alert -> +	    Alert +    end.  diff --git a/lib/ssl/src/ssl_socket.erl b/lib/ssl/src/tls_socket.erl index b2aea2ba9c..e76d9c100a 100644 --- a/lib/ssl/src/ssl_socket.erl +++ b/lib/ssl/src/tls_socket.erl @@ -17,16 +17,19 @@  %%  %% %CopyrightEnd%  %% --module(ssl_socket). +-module(tls_socket).  -behaviour(gen_server).  -include("ssl_internal.hrl").  -include("ssl_api.hrl"). --export([socket/5, setopts/3, getopts/3, getstat/3, peername/2, sockname/2, port/2]). +-export([send/3, listen/3, accept/3, socket/5, connect/4, upgrade/3, +	 setopts/3, getopts/3, getstat/3, peername/2, sockname/2, port/2]). +-export([split_options/1, get_socket_opts/3]).  -export([emulated_options/0, internal_inet_values/0, default_inet_values/0, -	 init/1, start_link/3, terminate/2, inherit_tracker/3, get_emulated_opts/1,  +	 init/1, start_link/3, terminate/2, inherit_tracker/3,  +	 emulated_socket_options/2, get_emulated_opts/1,   	 set_emulated_opts/2, get_all_opts/1, handle_call/3, handle_cast/2,  	 handle_info/2, code_change/3]). @@ -39,6 +42,76 @@  %%--------------------------------------------------------------------  %%% Internal API  %%-------------------------------------------------------------------- +send(Transport, Socket, Data) -> +    Transport:send(Socket, Data). + +listen(Transport, Port, #config{transport_info = {Transport, _, _, _},  +				inet_user = Options,  +				ssl = SslOpts, emulated = EmOpts} = Config) -> +    case Transport:listen(Port, Options ++ internal_inet_values()) of +	{ok, ListenSocket} -> +	    {ok, Tracker} = inherit_tracker(ListenSocket, EmOpts, SslOpts), +	    {ok, #sslsocket{pid = {ListenSocket, Config#config{emulated = Tracker}}}}; +	Err = {error, _} -> +	    Err +    end. + +accept(ListenSocket, #config{transport_info = {Transport,_,_,_} = CbInfo, +			     connection_cb = ConnectionCb, +			     ssl = SslOpts, +			     emulated = Tracker}, Timeout) ->  +    case Transport:accept(ListenSocket, Timeout) of +	{ok, Socket} -> +	    {ok, EmOpts} = get_emulated_opts(Tracker), +	    {ok, Port} = tls_socket:port(Transport, Socket), +	    ConnArgs = [server, "localhost", Port, Socket, +			{SslOpts, emulated_socket_options(EmOpts, #socket_options{}), Tracker}, self(), CbInfo], +	    case tls_connection_sup:start_child(ConnArgs) of +		{ok, Pid} -> +		    ssl_connection:socket_control(ConnectionCb, Socket, Pid, Transport, Tracker); +		{error, Reason} -> +		    {error, Reason} +	    end; +	{error, Reason} -> +	    {error, Reason} +    end. + +upgrade(Socket, #config{transport_info = {Transport,_,_,_}= CbInfo, +			ssl = SslOptions, +			emulated = EmOpts, connection_cb = ConnectionCb}, Timeout) -> +    ok = setopts(Transport, Socket, tls_socket:internal_inet_values()), +    case peername(Transport, Socket) of +	{ok, {Address, Port}} -> +	    ssl_connection:connect(ConnectionCb, Address, Port, Socket, +				   {SslOptions,  +				    emulated_socket_options(EmOpts, #socket_options{}), undefined}, +				   self(), CbInfo, Timeout); +	{error, Error} -> +	    {error, Error} +    end. + +connect(Address, Port, +	#config{transport_info = CbInfo, inet_user = UserOpts, ssl = SslOpts, +		emulated = EmOpts, inet_ssl = SocketOpts, connection_cb = ConnetionCb}, +	Timeout) -> +    {Transport, _, _, _} = CbInfo, +    try Transport:connect(Address, Port,  SocketOpts, Timeout) of +	{ok, Socket} -> +	    ssl_connection:connect(ConnetionCb, Address, Port, Socket,  +				   {SslOpts,  +				    emulated_socket_options(EmOpts, #socket_options{}), undefined}, +				   self(), CbInfo, Timeout); +	{error, Reason} -> +	    {error, Reason} +    catch +	exit:{function_clause, _} -> +	    {error, {options, {cb_info, CbInfo}}}; +	exit:badarg -> +	    {error, {options, {socket_options, UserOpts}}}; +	exit:{badarg, _} -> +	    {error, {options, {socket_options, UserOpts}}} +    end. +  socket(Pid, Transport, Socket, ConnectionCb, Tracker) ->      #sslsocket{pid = Pid,   	       %% "The name "fd" is keept for backwards compatibility @@ -241,3 +314,17 @@ get_emulated_opts(TrackerPid, EmOptNames) ->      lists:map(fun(Name) -> {value, Value} = lists:keysearch(Name, 1, EmOpts),  			   Value end,  	      EmOptNames). + +emulated_socket_options(InetValues, #socket_options{ +				       mode   = Mode, +				       header = Header, +				       active = Active, +				       packet = Packet, +				       packet_size = Size}) -> +    #socket_options{ +       mode   = proplists:get_value(mode, InetValues, Mode), +       header = proplists:get_value(header, InetValues, Header), +       active = proplists:get_value(active, InetValues, Active), +       packet = proplists:get_value(packet, InetValues, Packet), +       packet_size = proplists:get_value(packet_size, InetValues, Size) +      }. | 
