diff options
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 665 |
1 files changed, 383 insertions, 282 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index e490de7eeb..b8be686b99 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -21,7 +21,7 @@ %% Internal application API --behaviour(gen_fsm). +-behaviour(gen_statem). -include("dtls_connection.hrl"). -include("dtls_handshake.hrl"). @@ -36,48 +36,49 @@ %% Internal application API %% Setup --export([start_fsm/8]). +-export([start_fsm/8, start_link/7, init/1]). %% State transition handling --export([next_record/1, next_state/4%, - %%next_state_connection/2 - ]). +-export([next_record/1, next_event/3]). %% Handshake handling --export([%%renegotiate/1, - send_handshake/2, send_change_cipher/2]). +-export([%%renegotiate/2, + send_handshake/2, queue_handshake/2, queue_change_cipher/2]). %% Alert and close handling --export([send_alert/2, handle_own_alert/4, %%handle_close_alert/3, - handle_normal_shutdown/3 - %%handle_unexpected_message/3, - %%alert_user/5, alert_user/8 +-export([%%send_alert/2, handle_own_alert/4, handle_close_alert/3, + handle_normal_shutdown/3 %%, close/5 + %%alert_user/6, alert_user/9 ]). %% Data handling + -export([%%write_application_data/3, - read_application_data/2%%, -%% passive_receive/2, next_record_if_active/1 + read_application_data/2, + passive_receive/2, next_record_if_active/1%, + %%handle_common_event/4, + %handle_packet/3 ]). -%% Called by tls_connection_sup --export([start_link/7]). +%% gen_statem state functions +-export([init/3, error/3, downgrade/3, %% Initiation and take down states + hello/3, certify/3, cipher/3, abbreviated/3, %% Handshake states + connection/3]). +%% gen_statem callbacks +-export([terminate/3, code_change/4, format_status/2]). -%% gen_fsm callbacks --export([init/1, hello/2, certify/2, cipher/2, - abbreviated/2, connection/2, handle_event/3, - handle_sync_event/4, handle_info/3, terminate/3, code_change/4]). +-define(GEN_STATEM_CB_MODE, state_functions). %%==================================================================== %% Internal application API %%==================================================================== -start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_} = Opts, +start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_, Tracker} = Opts, User, {CbModule, _,_, _} = CbInfo, Timeout) -> try {ok, Pid} = dtls_connection_sup:start_child([Role, Host, Port, Socket, Opts, User, CbInfo]), - {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule), + {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker), ok = ssl_connection:handshake(SslSocket, Timeout), {ok, SslSocket} catch @@ -85,13 +86,13 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_} = Opts, Error end; -start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_} = Opts, +start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} = Opts, User, {CbModule, _,_, _} = CbInfo, Timeout) -> try {ok, Pid} = dtls_connection_sup:start_child_dist([Role, Host, Port, Socket, Opts, User, CbInfo]), - {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule), + {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker), ok = ssl_connection:handshake(SslSocket, Timeout), {ok, SslSocket} catch @@ -99,14 +100,37 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_} = Opts, Error end. -send_handshake(Handshake, #state{negotiated_version = Version, - tls_handshake_history = Hist0, - connection_states = ConnectionStates0} = State0) -> - {BinHandshake, ConnectionStates, Hist} = +send_handshake(Handshake, State) -> + send_handshake_flight(queue_handshake(Handshake, State)). + +queue_flight_buffer(Msg, #state{negotiated_version = Version, + connection_states = #connection_states{ + current_write = + #connection_state{epoch = Epoch}}, + flight_buffer = Flight} = State) -> + State#state{flight_buffer = Flight ++ [{Version, Epoch, Msg}]}. + +queue_handshake(Handshake, #state{negotiated_version = Version, + tls_handshake_history = Hist0, + connection_states = ConnectionStates0} = State0) -> + {Frag, ConnectionStates, Hist} = encode_handshake(Handshake, Version, ConnectionStates0, Hist0), - send_flight(BinHandshake, State0#state{connection_states = ConnectionStates, - tls_handshake_history = Hist - }). + queue_flight_buffer(Frag, State0#state{connection_states = ConnectionStates, + tls_handshake_history = Hist}). + +send_handshake_flight(#state{socket = Socket, + transport_cb = Transport, + flight_buffer = Flight, + connection_states = ConnectionStates0} = State0) -> + + {Encoded, ConnectionStates} = + encode_handshake_flight(Flight, ConnectionStates0), + + Transport:send(Socket, Encoded), + State0#state{flight_buffer = [], connection_states = ConnectionStates}. + +queue_change_cipher(Msg, State) -> + queue_flight_buffer(Msg, State). send_alert(Alert, #state{negotiated_version = Version, socket = Socket, @@ -117,15 +141,6 @@ send_alert(Alert, #state{negotiated_version = Version, Transport:send(Socket, BinMsg), State0#state{connection_states = ConnectionStates}. -send_change_cipher(Msg, #state{connection_states = ConnectionStates0, - socket = Socket, - negotiated_version = Version, - transport_cb = Transport} = State0) -> - {BinChangeCipher, ConnectionStates} = - encode_change_cipher(Msg, Version, ConnectionStates0), - Transport:send(Socket, BinChangeCipher), - State0#state{connection_states = ConnectionStates}. - %%==================================================================== %% tls_connection_sup API %%==================================================================== @@ -141,83 +156,98 @@ send_change_cipher(Msg, #state{connection_states = ConnectionStates0, start_link(Role, Host, Port, Socket, Options, User, CbInfo) -> {ok, proc_lib:spawn_link(?MODULE, init, [[Role, Host, Port, Socket, Options, User, CbInfo]])}. -init([Role, Host, Port, Socket, {SSLOpts0, _} = Options, User, CbInfo]) -> +init([Role, Host, Port, Socket, Options, User, CbInfo]) -> process_flag(trap_exit, true), State0 = initial_state(Role, Host, Port, Socket, Options, User, CbInfo), - Handshake = ssl_handshake:init_handshake_history(), - TimeStamp = erlang:monotonic_time(), - try ssl_config:init(SSLOpts0, Role) of - {ok, Ref, CertDbHandle, FileRefHandle, CacheHandle, CRLDbInfo, OwnCert, Key, DHParams} -> - Session = State0#state.session, - State = State0#state{ - tls_handshake_history = Handshake, - session = Session#session{own_certificate = OwnCert, - time_stamp = TimeStamp}, - file_ref_db = FileRefHandle, - cert_db_ref = Ref, - cert_db = CertDbHandle, - crl_db = CRLDbInfo, - session_cache = CacheHandle, - private_key = Key, - diffie_hellman_params = DHParams}, - gen_fsm:enter_loop(?MODULE, [], hello, State, get_timeout(State)) + try + State = ssl_connection:ssl_config(State0#state.ssl_options, Role, State0), + gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, init, State) catch throw:Error -> - gen_fsm:enter_loop(?MODULE, [], error, {Error,State0}, get_timeout(State0)) + gen_statem:enter_loop(?MODULE, [], ?GEN_STATEM_CB_MODE, error, {Error,State0}) end. %%-------------------------------------------------------------------- -%% Description:There should be one instance of this function for each -%% possible state name. Whenever a gen_fsm receives an event sent -%% using gen_fsm:send_event/2, the instance of this function with the -%% same name as the current state name StateName is called to handle -%% the event. It is also called if a timeout occurs. -%% -hello(start, #state{host = Host, port = Port, role = client, - ssl_options = SslOpts, - session = #session{own_certificate = Cert} = Session0, - session_cache = Cache, session_cache_cb = CacheCb, - transport_cb = Transport, socket = Socket, - connection_states = ConnectionStates0, - renegotiation = {Renegotiation, _}} = State0) -> +%% State functionsconnection/2 +%%-------------------------------------------------------------------- + +init({call, From}, {start, Timeout}, + #state{host = Host, port = Port, role = client, + ssl_options = SslOpts, + session = #session{own_certificate = Cert} = Session0, + transport_cb = Transport, socket = Socket, + connection_states = ConnectionStates0, + renegotiation = {Renegotiation, _}, + session_cache = Cache, + session_cache_cb = CacheCb + } = State0) -> + Timer = ssl_connection:start_or_recv_cancel_timer(Timeout, From), Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, - Cache, CacheCb, Renegotiation, Cert), + Cache, CacheCb, Renegotiation, Cert), Version = Hello#client_hello.client_version, + HelloVersion = dtls_record:lowest_protocol_version(SslOpts#ssl_options.versions), Handshake0 = ssl_handshake:init_handshake_history(), {BinMsg, ConnectionStates, Handshake} = - encode_handshake(Hello, Version, ConnectionStates0, Handshake0), + encode_handshake(Hello, HelloVersion, ConnectionStates0, Handshake0), Transport:send(Socket, BinMsg), State1 = State0#state{connection_states = ConnectionStates, negotiated_version = Version, %% Requested version session = Session0#session{session_id = Hello#client_hello.session_id}, - tls_handshake_history = Handshake}, + tls_handshake_history = Handshake, + start_or_recv_from = From, + timer = Timer}, {Record, State} = next_record(State1), - next_state(hello, hello, Record, State); + next_event(hello, Record, State); +init(Type, Event, State) -> + ssl_connection:init(Type, Event, State, ?MODULE). + +error({call, From}, {start, _Timeout}, {Error, State}) -> + {stop_and_reply, normal, {reply, From, {error, Error}}, State}; +error({call, From}, Msg, State) -> + handle_call(Msg, From, error, State); +error(_, _, _) -> + {keep_state_and_data, [postpone]}. -hello(Hello = #client_hello{client_version = ClientVersion}, +%%-------------------------------------------------------------------- +-spec hello(gen_statem:event_type(), + #hello_request{} | #client_hello{} | #server_hello{} | term(), + #state{}) -> + gen_statem:state_function_result(). +%%-------------------------------------------------------------------- +hello(internal, #client_hello{client_version = ClientVersion, + extensions = #hello_extensions{ec_point_formats = EcPointFormats, + elliptic_curves = EllipticCurves}} = Hello, State = #state{connection_states = ConnectionStates0, port = Port, session = #session{own_certificate = Cert} = Session0, renegotiation = {Renegotiation, _}, session_cache = Cache, session_cache_cb = CacheCb, + negotiated_protocol = CurrentProtocol, + key_algorithm = KeyExAlg, ssl_options = SslOpts}) -> + case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb, - ConnectionStates0, Cert}, Renegotiation) of - {Version, {Type, Session}, - ConnectionStates, - #hello_extensions{ec_point_formats = EcPointFormats, - elliptic_curves = EllipticCurves} = ServerHelloExt, HashSign} -> - ssl_connection:hello({common_client_hello, Type, ServerHelloExt, HashSign}, - State#state{connection_states = ConnectionStates, + ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of + #alert{} = Alert -> + handle_own_alert(Alert, ClientVersion, hello, State); + {Version, {Type, Session}, + ConnectionStates, Protocol0, ServerHelloExt, HashSign} -> + Protocol = case Protocol0 of + undefined -> CurrentProtocol; + _ -> Protocol0 + end, + + ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt}, + State#state{connection_states = ConnectionStates, negotiated_version = Version, + hashsign_algorithm = HashSign, session = Session, - client_ecc = {EllipticCurves, EcPointFormats}}, ?MODULE); - #alert{} = Alert -> - handle_own_alert(Alert, ClientVersion, hello, State) + client_ecc = {EllipticCurves, EcPointFormats}, + negotiated_protocol = Protocol}, ?MODULE) end; -hello(Hello, +hello(internal, #server_hello{} = Hello, #state{connection_states = ConnectionStates0, negotiated_version = ReqVersion, role = client, @@ -230,20 +260,30 @@ hello(Hello, ssl_connection:handle_session(Hello, Version, NewId, ConnectionStates, ProtoExt, Protocol, State) end; - -hello(Msg, State) -> - ssl_connection:hello(Msg, State, ?MODULE). - -abbreviated(Msg, State) -> - ssl_connection:abbreviated(Msg, State, ?MODULE). - -certify(Msg, State) -> - ssl_connection:certify(Msg, State, ?MODULE). - -cipher(Msg, State) -> - ssl_connection:cipher(Msg, State, ?MODULE). - -connection(#hello_request{}, #state{host = Host, port = Port, +hello(info, Event, State) -> + handle_info(Event, hello, State); + +hello(Type, Event, State) -> + ssl_connection:hello(Type, Event, State, ?MODULE). + +abbreviated(info, Event, State) -> + handle_info(Event, abbreviated, State); +abbreviated(Type, Event, State) -> + ssl_connection:abbreviated(Type, Event, State, ?MODULE). + +certify(info, Event, State) -> + handle_info(Event, certify, State); +certify(Type, Event, State) -> + ssl_connection:certify(Type, Event, State, ?MODULE). + +cipher(info, Event, State) -> + handle_info(Event, cipher, State); +cipher(Type, Event, State) -> + ssl_connection:cipher(Type, Event, State, ?MODULE). + +connection(info, Event, State) -> + handle_info(Event, connection, State); +connection(internal, #hello_request{}, #state{host = Host, port = Port, session = #session{own_certificate = Cert} = Session0, session_cache = Cache, session_cache_cb = CacheCb, ssl_options = SslOpts, @@ -251,46 +291,35 @@ connection(#hello_request{}, #state{host = Host, port = Port, renegotiation = {Renegotiation, _}} = State0) -> Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, Cache, CacheCb, Renegotiation, Cert), - %% TODO DTLS version State1 = send_handshake(Hello, State0), - State1 = State0, + State1 = send_handshake(Hello, State0), {Record, State} = next_record( State1#state{session = Session0#session{session_id = Hello#client_hello.session_id}}), - next_state(connection, hello, Record, State); + next_event(hello, Record, State); -connection(#client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) -> +connection(internal, #client_hello{} = Hello, #state{role = server, allow_renegotiate = true} = State) -> %% Mitigate Computational DoS attack %% http://www.educatedguesswork.org/2011/10/ssltls_and_computational_dos.html %% http://www.thc.org/thc-ssl-dos/ Rather than disabling client %% initiated renegotiation we will disallow many client initiated %% renegotiations immediately after each other. erlang:send_after(?WAIT_TO_ALLOW_RENEGOTIATION, self(), allow_renegotiate), - hello(Hello, State#state{allow_renegotiate = false}); + {next_state, hello, State#state{allow_renegotiate = false}, [{next_event, internal, Hello}]}; + -connection(#client_hello{}, #state{role = server, allow_renegotiate = false} = State0) -> +connection(internal, #client_hello{}, #state{role = server, allow_renegotiate = false} = State0) -> Alert = ?ALERT_REC(?WARNING, ?NO_RENEGOTIATION), - State = send_alert(Alert, State0), - next_state_connection(connection, State); + State1 = send_alert(Alert, State0), + {Record, State} = ssl_connection:prepare_connection(State1, ?MODULE), + next_event(connection, Record, State); -connection(Msg, State) -> - ssl_connection:connection(Msg, State, tls_connection). +connection(Type, Event, State) -> + ssl_connection:connection(Type, Event, State, ?MODULE). -%%-------------------------------------------------------------------- -%% Description: Whenever a gen_fsm receives an event sent using -%% gen_fsm:send_all_state_event/2, this function is called to handle -%% the event. Not currently used! -%%-------------------------------------------------------------------- -handle_event(_Event, StateName, State) -> - {next_state, StateName, State, get_timeout(State)}. +downgrade(Type, Event, State) -> + ssl_connection:downgrade(Type, Event, State, ?MODULE). -%%-------------------------------------------------------------------- -%% Description: Whenever a gen_fsm receives an event sent using -%% gen_fsm:sync_send_all_state_event/2,3, this function is called to handle -%% the event. -%%-------------------------------------------------------------------- -handle_sync_event(Event, From, StateName, State) -> - ssl_connection:handle_sync_event(Event, From, StateName, State). %%-------------------------------------------------------------------- %% Description: This function is called by a gen_fsm when it receives any @@ -301,26 +330,38 @@ handle_sync_event(Event, From, StateName, State) -> %% raw data from socket, unpack records handle_info({Protocol, _, Data}, StateName, #state{data_tag = Protocol} = State0) -> - %% Simplify for now to avoid dialzer warnings before implementation is compleate - %% case next_tls_record(Data, State0) of - %% {Record, State} -> - %% next_state(StateName, StateName, Record, State); - %% #alert{} = Alert -> - %% handle_normal_shutdown(Alert, StateName, State0), - %% {stop, {shutdown, own_alert}, State0} - %% end; - {Record, State} = next_tls_record(Data, State0), - next_state(StateName, StateName, Record, State); - + case next_tls_record(Data, State0) of + {Record, State} -> + next_event(StateName, Record, State); + #alert{} = Alert -> + handle_normal_shutdown(Alert, StateName, State0), + {stop, {shutdown, own_alert}} + end; handle_info({CloseTag, Socket}, StateName, - #state{socket = Socket, close_tag = CloseTag, - negotiated_version = _Version} = State) -> + #state{socket = Socket, close_tag = CloseTag, + negotiated_version = Version} = State) -> + %% Note that as of DTLS 1.2 (TLS 1.1), + %% failure to properly close a connection no longer requires that a + %% session not be resumed. This is a change from DTLS 1.0 to conform + %% with widespread implementation practice. + case Version of + {254, N} when N =< 253 -> + ok; + _ -> + %% As invalidate_sessions here causes performance issues, + %% we will conform to the widespread implementation + %% practice and go aginst the spec + %%invalidate_session(Role, Host, Port, Session) + ok + end, handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State), - {stop, {shutdown, transport_closed}, State}; - + {stop, {shutdown, transport_closed}}; handle_info(Msg, StateName, State) -> ssl_connection:handle_info(Msg, StateName, State). +handle_call(Event, From, StateName, State) -> + ssl_connection:handle_call(Event, From, StateName, State, ?MODULE). + %%-------------------------------------------------------------------- %% Description:This function is called by a gen_fsm when it is about %% to terminate. It should be the opposite of Module:init/1 and do any @@ -335,154 +376,96 @@ terminate(Reason, StateName, State) -> %% Description: Convert process state when code is changed %%-------------------------------------------------------------------- code_change(_OldVsn, StateName, State, _Extra) -> - {ok, StateName, State}. + {?GEN_STATEM_CB_MODE, StateName, State}. + +format_status(Type, Data) -> + ssl_connection:format_status(Type, Data). %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- encode_handshake(Handshake, Version, ConnectionStates0, Hist0) -> - Seq = sequence(ConnectionStates0), - {EncHandshake, FragmentedHandshake} = dtls_handshake:encode_handshake(Handshake, Version, - Seq), + {Seq, ConnectionStates} = sequence(ConnectionStates0), + {EncHandshake, Frag} = dtls_handshake:encode_handshake(Handshake, Version, Seq), Hist = ssl_handshake:update_handshake_history(Hist0, EncHandshake), - {Encoded, ConnectionStates} = - dtls_record:encode_handshake(FragmentedHandshake, - Version, ConnectionStates0), - {Encoded, ConnectionStates, Hist}. + {Frag, ConnectionStates, Hist}. -next_record(#state{%%flight = #flight{state = finished}, - protocol_buffers = - #protocol_buffers{dtls_packets = [], dtls_cipher_texts = [CT | Rest]} - = Buffers, - connection_states = ConnStates0} = State) -> - case dtls_record:decode_cipher_text(CT, ConnStates0) of - {Plain, ConnStates} -> - {Plain, State#state{protocol_buffers = - Buffers#protocol_buffers{dtls_cipher_texts = Rest}, - connection_states = ConnStates}}; - #alert{} = Alert -> - {Alert, State} - end; -next_record(#state{socket = Socket, - transport_cb = Transport} = State) -> %% when FlightState =/= finished - ssl_socket:setopts(Transport, Socket, [{active,once}]), - {no_record, State}; +encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> + dtls_record:encode_change_cipher_spec(Version, ConnectionStates). +encode_handshake_flight(Flight, ConnectionStates) -> + MSS = 1400, + encode_handshake_records(Flight, ConnectionStates, MSS, init_pack_records()). -next_record(State) -> - {no_record, State}. +encode_handshake_records([], CS, _MSS, Recs) -> + {finish_pack_records(Recs), CS}; -next_state(Current,_, #alert{} = Alert, #state{negotiated_version = Version} = State) -> - handle_own_alert(Alert, Version, Current, State); - -next_state(_,Next, no_record, State) -> - {next_state, Next, State, get_timeout(State)}; - -%% next_state(_,Next, #ssl_tls{type = ?ALERT, fragment = EncAlerts}, State) -> -%% Alerts = decode_alerts(EncAlerts), -%% handle_alerts(Alerts, {next_state, Next, State, get_timeout(State)}); - -next_state(Current, Next, #ssl_tls{type = ?HANDSHAKE, fragment = Data}, - State0 = #state{protocol_buffers = - #protocol_buffers{dtls_handshake_buffer = Buf0} = Buffers, - negotiated_version = Version}) -> - Handle = - fun({#hello_request{} = Packet, _}, {next_state, connection = SName, State}) -> - %% This message should not be included in handshake - %% message hashes. Starts new handshake (renegotiation) - Hs0 = ssl_handshake:init_handshake_history(), - ?MODULE:SName(Packet, State#state{tls_handshake_history=Hs0, - renegotiation = {true, peer}}); - ({#hello_request{} = Packet, _}, {next_state, SName, State}) -> - %% This message should not be included in handshake - %% message hashes. Already in negotiation so it will be ignored! - ?MODULE:SName(Packet, State); - ({#client_hello{} = Packet, Raw}, {next_state, connection = SName, State}) -> - Version = Packet#client_hello.client_version, - Hs0 = ssl_handshake:init_handshake_history(), - Hs1 = ssl_handshake:update_handshake_history(Hs0, Raw), - ?MODULE:SName(Packet, State#state{tls_handshake_history=Hs1, - renegotiation = {true, peer}}); - ({Packet, Raw}, {next_state, SName, State = #state{tls_handshake_history=Hs0}}) -> - Hs1 = ssl_handshake:update_handshake_history(Hs0, Raw), - ?MODULE:SName(Packet, State#state{tls_handshake_history=Hs1}); - (_, StopState) -> StopState - end, - try - {Packets, Buf} = tls_handshake:get_tls_handshake(Version,Data,Buf0), - State = State0#state{protocol_buffers = - Buffers#protocol_buffers{dtls_packets = Packets, - dtls_handshake_buffer = Buf}}, - handle_dtls_handshake(Handle, Next, State) - catch throw:#alert{} = Alert -> - handle_own_alert(Alert, Version, Current, State0) +encode_handshake_records([{Version, _Epoch, Frag = #change_cipher_spec{}}|Tail], ConnectionStates0, MSS, Recs0) -> + {Encoded, ConnectionStates} = + encode_change_cipher(Frag, Version, ConnectionStates0), + Recs = append_pack_records([Encoded], MSS, Recs0), + encode_handshake_records(Tail, ConnectionStates, MSS, Recs); + +encode_handshake_records([{Version, Epoch, {MsgType, MsgSeq, Bin}}|Tail], CS0, MSS, Recs0 = {Buf0, _}) -> + Space = MSS - iolist_size(Buf0), + Len = byte_size(Bin), + {Encoded, CS} = + encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, 0, MSS, [], CS0), + Recs = append_pack_records(Encoded, MSS, Recs0), + encode_handshake_records(Tail, CS, MSS, Recs). + +%% TODO: move to dtls_handshake???? +encode_handshake_record(_Version, _Epoch, _Space, _MsgType, _MsgSeq, _Len, <<>>, _Offset, _MRS, Encoded, CS) + when length(Encoded) > 0 -> + %% make sure we encode at least one segment (for empty messages like Server Hello Done + {lists:reverse(Encoded), CS}; + +encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, + Offset, MRS, Encoded0, CS0) -> + MaxFragmentLen = Space - 25, + case Bin of + <<BinFragment:MaxFragmentLen/bytes, Rest/binary>> -> + ok; + _ -> + BinFragment = Bin, + Rest = <<>> + end, + FragLength = byte_size(BinFragment), + Frag = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(Offset), ?uint24(FragLength), BinFragment], + %% TODO Real solution, now avoid dialyzer error {Encoded, CS} = ssl_record:encode_handshake({Epoch, Frag}, Version, CS0), + {Encoded, CS} = ssl_record:encode_handshake(Frag, Version, CS0), + encode_handshake_record(Version, Epoch, MRS, MsgType, MsgSeq, Len, Rest, Offset + FragLength, MRS, [Encoded|Encoded0], CS). + +init_pack_records() -> + {[], []}. + +append_pack_records([], MSS, Recs = {Buf0, Acc0}) -> + Remaining = MSS - iolist_size(Buf0), + if Remaining < 12 -> + {[], [lists:reverse(Buf0)|Acc0]}; + true -> + Recs end; - -next_state(_, StateName, #ssl_tls{type = ?APPLICATION_DATA, fragment = Data}, State0) -> - %% Simplify for now to avoid dialzer warnings before implementation is compleate - %% case read_application_data(Data, State0) of - %% Stop = {stop,_,_} -> - %% Stop; - %% {Record, State} -> - %% next_state(StateName, StateName, Record, State) - %% end; - {Record, State} = read_application_data(Data, State0), - next_state(StateName, StateName, Record, State); - -next_state(Current, Next, #ssl_tls{type = ?CHANGE_CIPHER_SPEC, fragment = <<1>>} = - _ChangeCipher, - #state{connection_states = ConnectionStates0} = State0) -> - ConnectionStates1 = - ssl_record:activate_pending_connection_state(ConnectionStates0, read), - {Record, State} = next_record(State0#state{connection_states = ConnectionStates1}), - next_state(Current, Next, Record, State); -next_state(Current, Next, #ssl_tls{type = _Unknown}, State0) -> - %% Ignore unknown type - {Record, State} = next_record(State0), - next_state(Current, Next, Record, State). - -handle_dtls_handshake(Handle, StateName, - #state{protocol_buffers = - #protocol_buffers{dtls_packets = [Packet]} = Buffers} = State) -> - FsmReturn = {next_state, StateName, State#state{protocol_buffers = - Buffers#protocol_buffers{dtls_packets = []}}}, - Handle(Packet, FsmReturn); - -handle_dtls_handshake(Handle, StateName, - #state{protocol_buffers = - #protocol_buffers{dtls_packets = [Packet | Packets]} = Buffers} = - State0) -> - FsmReturn = {next_state, StateName, State0#state{protocol_buffers = - Buffers#protocol_buffers{dtls_packets = - Packets}}}, - case Handle(Packet, FsmReturn) of - {next_state, NextStateName, State, _Timeout} -> - handle_dtls_handshake(Handle, NextStateName, State); - {stop, _,_} = Stop -> - Stop +append_pack_records([Head|Tail], MSS, {Buf0, Acc0}) -> + TotLen = iolist_size(Buf0) + iolist_size(Head), + if TotLen > MSS -> + append_pack_records(Tail, MSS, {[Head], [lists:reverse(Buf0)|Acc0]}); + true -> + append_pack_records(Tail, MSS, {[Head|Buf0], Acc0}) end. +finish_pack_records({[], Acc}) -> + lists:reverse(Acc); +finish_pack_records({Buf, Acc}) -> + lists:reverse([lists:reverse(Buf)|Acc]). -send_flight(Fragments, #state{transport_cb = Transport, socket = Socket, - protocol_buffers = _PBuffers} = State) -> - Transport:send(Socket, Fragments), - %% Start retransmission - %% State#state{protocol_buffers = - %% (PBuffers#protocol_buffers){ #flight{state = waiting}}}}. - State. - -handle_own_alert(_,_,_, State) -> %% Place holder - {stop, {shutdown, own_alert}, State}. - -handle_normal_shutdown(_, _, _State) -> %% Place holder - ok. - -encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> - dtls_record:encode_change_cipher_spec(Version, ConnectionStates). +%% decode_alerts(Bin) -> +%% ssl_alert:decode(Bin). initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User, {CbModule, DataTag, CloseTag, ErrorTag}) -> - ConnectionStates = ssl_record:init_connection_states(Role), + #ssl_options{beast_mitigation = BeastMitigation} = SSLOptions, + ConnectionStates = ssl_record:init_connection_states(Role, BeastMitigation), SessionCacheCb = case application:get_env(ssl, session_cb) of {ok, Cb} when is_atom(Cb) -> @@ -514,21 +497,139 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User, renegotiation = {false, first}, allow_renegotiate = SSLOptions#ssl_options.client_renegotiation, start_or_recv_from = undefined, - send_queue = queue:new(), protocol_cb = ?MODULE }. + +next_tls_record(Data, #state{protocol_buffers = #protocol_buffers{ + dtls_record_buffer = Buf0, + dtls_cipher_texts = CT0} = Buffers} = State0) -> + case dtls_record:get_dtls_records(Data, Buf0) of + {Records, Buf1} -> + CT1 = CT0 ++ Records, + next_record(State0#state{protocol_buffers = + Buffers#protocol_buffers{dtls_record_buffer = Buf1, + dtls_cipher_texts = CT1}}); + + #alert{} = Alert -> + Alert + end. + +next_record(#state{%%flight = #flight{state = finished}, + protocol_buffers = + #protocol_buffers{dtls_packets = [], dtls_cipher_texts = [CT | Rest]} + = Buffers, + connection_states = ConnStates0} = State) -> + case dtls_record:decode_cipher_text(CT, ConnStates0) of + {Plain, ConnStates} -> + {Plain, State#state{protocol_buffers = + Buffers#protocol_buffers{dtls_cipher_texts = Rest}, + connection_states = ConnStates}}; + #alert{} = Alert -> + {Alert, State} + end; +next_record(#state{socket = Socket, + transport_cb = Transport} = State) -> %% when FlightState =/= finished + ssl_socket:setopts(Transport, Socket, [{active,once}]), + {no_record, State}; +next_record(State) -> + {no_record, State}. + +next_record_if_active(State = + #state{socket_options = + #socket_options{active = false}}) -> + {no_record ,State}; + +next_record_if_active(State) -> + next_record(State). + +passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName) -> + case Buffer of + <<>> -> + {Record, State} = next_record(State0), + next_event(StateName, Record, State); + _ -> + {Record, State} = read_application_data(<<>>, State0), + next_event(StateName, Record, State) + end. + +next_event(StateName, Record, State) -> + next_event(StateName, Record, State, []). + +next_event(connection = StateName, no_record, State0, Actions) -> + case next_record_if_active(State0) of + {no_record, State} -> + ssl_connection:hibernate_after(StateName, State, Actions); + {#ssl_tls{} = Record, State} -> + {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; + {#alert{} = Alert, State} -> + {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} + end; +next_event(StateName, Record, State, Actions) -> + case Record of + no_record -> + {next_state, StateName, State, Actions}; + #ssl_tls{} = Record -> + {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; + #alert{} = Alert -> + {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} + end. + read_application_data(_,State) -> {#ssl_tls{fragment = <<"place holder">>}, State}. - -next_tls_record(_, State) -> - {#ssl_tls{fragment = <<"place holder">>}, State}. -get_timeout(_) -> %% Place holder - infinity. +handle_own_alert(_,_,_, State) -> %% Place holder + {stop, {shutdown, own_alert}, State}. -next_state_connection(_, State) -> %% Place holder - {next_state, connection, State, get_timeout(State)}. +handle_normal_shutdown(_, _, _State) -> %% Place holder + ok. -sequence(_) -> - %%TODO real imp - 1. +%% TODO This generates dialyzer warnings, has to be handled differently. +%% handle_packet(Address, Port, Packet) -> +%% try dtls_record:get_dtls_records(Packet, <<>>) of +%% %% expect client hello +%% {[#ssl_tls{type = ?HANDSHAKE, version = {254, _}} = Record], <<>>} -> +%% handle_dtls_client_hello(Address, Port, Record); +%% _Other -> +%% {error, not_dtls} +%% catch +%% _Class:_Error -> +%% {error, not_dtls} +%% end. + +%% handle_dtls_client_hello(Address, Port, +%% #ssl_tls{epoch = Epoch, sequence_number = Seq, +%% version = Version} = Record) -> +%% {[{Hello, _}], _} = +%% dtls_handshake:get_dtls_handshake(Record, +%% dtls_handshake:dtls_handshake_new_flight(undefined)), +%% #client_hello{client_version = {Major, Minor}, +%% random = Random, +%% session_id = SessionId, +%% cipher_suites = CipherSuites, +%% compression_methods = CompressionMethods} = Hello, +%% CookieData = [address_to_bin(Address, Port), +%% <<?BYTE(Major), ?BYTE(Minor)>>, +%% Random, SessionId, CipherSuites, CompressionMethods], +%% Cookie = crypto:hmac(sha, <<"secret">>, CookieData), + +%% case Hello of +%% #client_hello{cookie = Cookie} -> +%% accept; + +%% _ -> +%% %% generate HelloVerifyRequest +%% {RequestFragment, _} = dtls_handshake:encode_handshake( +%% dtls_handshake:hello_verify_request(Cookie), +%% Version, 0), +%% HelloVerifyRequest = +%% dtls_record:encode_tls_cipher_text(?HANDSHAKE, Version, Epoch, Seq, RequestFragment), +%% {reply, HelloVerifyRequest} +%% end. + +%% address_to_bin({A,B,C,D}, Port) -> +%% <<0:80,16#ffff:16,A,B,C,D,Port:16>>; +%% address_to_bin({A,B,C,D,E,F,G,H}, Port) -> +%% <<A:16,B:16,C:16,D:16,E:16,F:16,G:16,H:16,Port:16>>. + +sequence(#connection_states{dtls_write_msg_seq = Seq} = CS) -> + {Seq, CS#connection_states{dtls_write_msg_seq = Seq + 1}}. |