diff options
Diffstat (limited to 'lib/ssl/src/dtls_connection.erl')
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 408 |
1 files changed, 286 insertions, 122 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index 60a61bc901..b8be686b99 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -42,9 +42,8 @@ -export([next_record/1, next_event/3]). %% Handshake handling --export([%%renegotiate/2, - send_handshake/2, send_change_cipher/2]). - +-export([%%renegotiate/2, + send_handshake/2, queue_handshake/2, queue_change_cipher/2]). %% Alert and close handling -export([%%send_alert/2, handle_own_alert/4, handle_close_alert/3, @@ -53,11 +52,12 @@ ]). %% Data handling + -export([%%write_application_data/3, read_application_data/2, - %%passive_receive/2, - next_record_if_active/1 %%, - %%handle_common_event/4 + passive_receive/2, next_record_if_active/1%, + %%handle_common_event/4, + %handle_packet/3 ]). %% gen_statem state functions @@ -72,13 +72,13 @@ %%==================================================================== %% Internal application API %%==================================================================== -start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_} = Opts, +start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_, Tracker} = Opts, User, {CbModule, _,_, _} = CbInfo, Timeout) -> try {ok, Pid} = dtls_connection_sup:start_child([Role, Host, Port, Socket, Opts, User, CbInfo]), - {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule), + {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker), ok = ssl_connection:handshake(SslSocket, Timeout), {ok, SslSocket} catch @@ -86,13 +86,13 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = false},_} = Opts, Error end; -start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_} = Opts, +start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} = Opts, User, {CbModule, _,_, _} = CbInfo, Timeout) -> try {ok, Pid} = dtls_connection_sup:start_child_dist([Role, Host, Port, Socket, Opts, User, CbInfo]), - {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule), + {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, Pid, CbModule, Tracker), ok = ssl_connection:handshake(SslSocket, Timeout), {ok, SslSocket} catch @@ -100,14 +100,37 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_} = Opts, Error end. -send_handshake(Handshake, #state{negotiated_version = Version, - tls_handshake_history = Hist0, - connection_states = ConnectionStates0} = State0) -> - {BinHandshake, ConnectionStates, Hist} = +send_handshake(Handshake, State) -> + send_handshake_flight(queue_handshake(Handshake, State)). + +queue_flight_buffer(Msg, #state{negotiated_version = Version, + connection_states = #connection_states{ + current_write = + #connection_state{epoch = Epoch}}, + flight_buffer = Flight} = State) -> + State#state{flight_buffer = Flight ++ [{Version, Epoch, Msg}]}. + +queue_handshake(Handshake, #state{negotiated_version = Version, + tls_handshake_history = Hist0, + connection_states = ConnectionStates0} = State0) -> + {Frag, ConnectionStates, Hist} = encode_handshake(Handshake, Version, ConnectionStates0, Hist0), - send_flight(BinHandshake, State0#state{connection_states = ConnectionStates, - tls_handshake_history = Hist - }). + queue_flight_buffer(Frag, State0#state{connection_states = ConnectionStates, + tls_handshake_history = Hist}). + +send_handshake_flight(#state{socket = Socket, + transport_cb = Transport, + flight_buffer = Flight, + connection_states = ConnectionStates0} = State0) -> + + {Encoded, ConnectionStates} = + encode_handshake_flight(Flight, ConnectionStates0), + + Transport:send(Socket, Encoded), + State0#state{flight_buffer = [], connection_states = ConnectionStates}. + +queue_change_cipher(Msg, State) -> + queue_flight_buffer(Msg, State). send_alert(Alert, #state{negotiated_version = Version, socket = Socket, @@ -118,15 +141,6 @@ send_alert(Alert, #state{negotiated_version = Version, Transport:send(Socket, BinMsg), State0#state{connection_states = ConnectionStates}. -send_change_cipher(Msg, #state{connection_states = ConnectionStates0, - socket = Socket, - negotiated_version = Version, - transport_cb = Transport} = State0) -> - {BinChangeCipher, ConnectionStates} = - encode_change_cipher(Msg, Version, ConnectionStates0), - Transport:send(Socket, BinChangeCipher), - State0#state{connection_states = ConnectionStates}. - %%==================================================================== %% tls_connection_sup API %%==================================================================== @@ -196,28 +210,44 @@ error({call, From}, Msg, State) -> error(_, _, _) -> {keep_state_and_data, [postpone]}. -hello(internal, #client_hello{client_version = ClientVersion} = Hello, - #state{connection_states = ConnectionStates0, - port = Port, session = #session{own_certificate = Cert} = Session0, - renegotiation = {Renegotiation, _}, - session_cache = Cache, - session_cache_cb = CacheCb, - ssl_options = SslOpts} = State) -> +%%-------------------------------------------------------------------- +-spec hello(gen_statem:event_type(), + #hello_request{} | #client_hello{} | #server_hello{} | term(), + #state{}) -> + gen_statem:state_function_result(). +%%-------------------------------------------------------------------- +hello(internal, #client_hello{client_version = ClientVersion, + extensions = #hello_extensions{ec_point_formats = EcPointFormats, + elliptic_curves = EllipticCurves}} = Hello, + State = #state{connection_states = ConnectionStates0, + port = Port, session = #session{own_certificate = Cert} = Session0, + renegotiation = {Renegotiation, _}, + session_cache = Cache, + session_cache_cb = CacheCb, + negotiated_protocol = CurrentProtocol, + key_algorithm = KeyExAlg, + ssl_options = SslOpts}) -> + case dtls_handshake:hello(Hello, SslOpts, {Port, Session0, Cache, CacheCb, - ConnectionStates0, Cert}, Renegotiation) of - {Version, {Type, Session}, - ConnectionStates, - #hello_extensions{ec_point_formats = EcPointFormats, - elliptic_curves = EllipticCurves} = ServerHelloExt, HashSign} -> - ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt, HashSign}, - State#state{connection_states = ConnectionStates, + ConnectionStates0, Cert, KeyExAlg}, Renegotiation) of + #alert{} = Alert -> + handle_own_alert(Alert, ClientVersion, hello, State); + {Version, {Type, Session}, + ConnectionStates, Protocol0, ServerHelloExt, HashSign} -> + Protocol = case Protocol0 of + undefined -> CurrentProtocol; + _ -> Protocol0 + end, + + ssl_connection:hello(internal, {common_client_hello, Type, ServerHelloExt}, + State#state{connection_states = ConnectionStates, negotiated_version = Version, + hashsign_algorithm = HashSign, session = Session, - client_ecc = {EllipticCurves, EcPointFormats}}, ?MODULE); - #alert{} = Alert -> - handle_own_alert(Alert, ClientVersion, hello, State) + client_ecc = {EllipticCurves, EcPointFormats}, + negotiated_protocol = Protocol}, ?MODULE) end; -hello(internal, Hello, +hello(internal, #server_hello{} = Hello, #state{connection_states = ConnectionStates0, negotiated_version = ReqVersion, role = client, @@ -261,8 +291,7 @@ connection(internal, #hello_request{}, #state{host = Host, port = Port, renegotiation = {Renegotiation, _}} = State0) -> Hello = dtls_handshake:client_hello(Host, Port, ConnectionStates0, SslOpts, Cache, CacheCb, Renegotiation, Cert), - %% TODO DTLS version State1 = send_handshake(Hello, State0), - State1 = State0, + State1 = send_handshake(Hello, State0), {Record, State} = next_record( State1#state{session = Session0#session{session_id @@ -309,11 +338,24 @@ handle_info({Protocol, _, Data}, StateName, {stop, {shutdown, own_alert}} end; handle_info({CloseTag, Socket}, StateName, - #state{socket = Socket, close_tag = CloseTag, - negotiated_version = _Version} = State) -> + #state{socket = Socket, close_tag = CloseTag, + negotiated_version = Version} = State) -> + %% Note that as of DTLS 1.2 (TLS 1.1), + %% failure to properly close a connection no longer requires that a + %% session not be resumed. This is a change from DTLS 1.0 to conform + %% with widespread implementation practice. + case Version of + {254, N} when N =< 253 -> + ok; + _ -> + %% As invalidate_sessions here causes performance issues, + %% we will conform to the widespread implementation + %% practice and go aginst the spec + %%invalidate_session(Role, Host, Port, Session) + ok + end, handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State), {stop, {shutdown, transport_closed}}; - handle_info(Msg, StateName, State) -> ssl_connection:handle_info(Msg, StateName, State). @@ -343,76 +385,82 @@ format_status(Type, Data) -> %%% Internal functions %%-------------------------------------------------------------------- encode_handshake(Handshake, Version, ConnectionStates0, Hist0) -> - Seq = sequence(ConnectionStates0), - {EncHandshake, FragmentedHandshake} = dtls_handshake:encode_handshake(Handshake, Version, - Seq), + {Seq, ConnectionStates} = sequence(ConnectionStates0), + {EncHandshake, Frag} = dtls_handshake:encode_handshake(Handshake, Version, Seq), Hist = ssl_handshake:update_handshake_history(Hist0, EncHandshake), - {Encoded, ConnectionStates} = - dtls_record:encode_handshake(FragmentedHandshake, - Version, ConnectionStates0), - {Encoded, ConnectionStates, Hist}. - -next_record(#state{%%flight = #flight{state = finished}, - protocol_buffers = - #protocol_buffers{dtls_packets = [], dtls_cipher_texts = [CT | Rest]} - = Buffers, - connection_states = ConnStates0} = State) -> - case dtls_record:decode_cipher_text(CT, ConnStates0) of - {Plain, ConnStates} -> - {Plain, State#state{protocol_buffers = - Buffers#protocol_buffers{dtls_cipher_texts = Rest}, - connection_states = ConnStates}}; - #alert{} = Alert -> - {Alert, State} - end; -next_record(#state{socket = Socket, - transport_cb = Transport} = State) -> %% when FlightState =/= finished - ssl_socket:setopts(Transport, Socket, [{active,once}]), - {no_record, State}; - + {Frag, ConnectionStates, Hist}. -next_record(State) -> - {no_record, State}. +encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> + dtls_record:encode_change_cipher_spec(Version, ConnectionStates). +encode_handshake_flight(Flight, ConnectionStates) -> + MSS = 1400, + encode_handshake_records(Flight, ConnectionStates, MSS, init_pack_records()). -next_event(StateName, Record, State) -> - next_event(StateName, Record, State, []). +encode_handshake_records([], CS, _MSS, Recs) -> + {finish_pack_records(Recs), CS}; -next_event(connection = StateName, no_record, State0, Actions) -> - case next_record_if_active(State0) of - {no_record, State} -> - ssl_connection:hibernate_after(StateName, State, Actions); - {#ssl_tls{} = Record, State} -> - {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; - {#alert{} = Alert, State} -> - {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} +encode_handshake_records([{Version, _Epoch, Frag = #change_cipher_spec{}}|Tail], ConnectionStates0, MSS, Recs0) -> + {Encoded, ConnectionStates} = + encode_change_cipher(Frag, Version, ConnectionStates0), + Recs = append_pack_records([Encoded], MSS, Recs0), + encode_handshake_records(Tail, ConnectionStates, MSS, Recs); + +encode_handshake_records([{Version, Epoch, {MsgType, MsgSeq, Bin}}|Tail], CS0, MSS, Recs0 = {Buf0, _}) -> + Space = MSS - iolist_size(Buf0), + Len = byte_size(Bin), + {Encoded, CS} = + encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, 0, MSS, [], CS0), + Recs = append_pack_records(Encoded, MSS, Recs0), + encode_handshake_records(Tail, CS, MSS, Recs). + +%% TODO: move to dtls_handshake???? +encode_handshake_record(_Version, _Epoch, _Space, _MsgType, _MsgSeq, _Len, <<>>, _Offset, _MRS, Encoded, CS) + when length(Encoded) > 0 -> + %% make sure we encode at least one segment (for empty messages like Server Hello Done + {lists:reverse(Encoded), CS}; + +encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, + Offset, MRS, Encoded0, CS0) -> + MaxFragmentLen = Space - 25, + case Bin of + <<BinFragment:MaxFragmentLen/bytes, Rest/binary>> -> + ok; + _ -> + BinFragment = Bin, + Rest = <<>> + end, + FragLength = byte_size(BinFragment), + Frag = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(Offset), ?uint24(FragLength), BinFragment], + %% TODO Real solution, now avoid dialyzer error {Encoded, CS} = ssl_record:encode_handshake({Epoch, Frag}, Version, CS0), + {Encoded, CS} = ssl_record:encode_handshake(Frag, Version, CS0), + encode_handshake_record(Version, Epoch, MRS, MsgType, MsgSeq, Len, Rest, Offset + FragLength, MRS, [Encoded|Encoded0], CS). + +init_pack_records() -> + {[], []}. + +append_pack_records([], MSS, Recs = {Buf0, Acc0}) -> + Remaining = MSS - iolist_size(Buf0), + if Remaining < 12 -> + {[], [lists:reverse(Buf0)|Acc0]}; + true -> + Recs end; -next_event(StateName, Record, State, Actions) -> - case Record of - no_record -> - {next_state, StateName, State, Actions}; - #ssl_tls{} = Record -> - {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; - #alert{} = Alert -> - {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} +append_pack_records([Head|Tail], MSS, {Buf0, Acc0}) -> + TotLen = iolist_size(Buf0) + iolist_size(Head), + if TotLen > MSS -> + append_pack_records(Tail, MSS, {[Head], [lists:reverse(Buf0)|Acc0]}); + true -> + append_pack_records(Tail, MSS, {[Head|Buf0], Acc0}) end. -send_flight(Fragments, #state{transport_cb = Transport, socket = Socket, - protocol_buffers = _PBuffers} = State) -> - Transport:send(Socket, Fragments), - %% Start retransmission - %% State#state{protocol_buffers = - %% (PBuffers#protocol_buffers){ #flight{state = waiting}}}}. - State. +finish_pack_records({[], Acc}) -> + lists:reverse(Acc); +finish_pack_records({Buf, Acc}) -> + lists:reverse([lists:reverse(Buf)|Acc]). -handle_own_alert(_,_,_, State) -> %% Place holder - {stop, {shutdown, own_alert}, State}. - -handle_normal_shutdown(_, _, _State) -> %% Place holder - ok. - -encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> - dtls_record:encode_change_cipher_spec(Version, ConnectionStates). +%% decode_alerts(Bin) -> +%% ssl_alert:decode(Bin). initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User, {CbModule, DataTag, CloseTag, ErrorTag}) -> @@ -451,21 +499,137 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User, start_or_recv_from = undefined, protocol_cb = ?MODULE }. -read_application_data(_,State) -> - {#ssl_tls{fragment = <<"place holder">>}, State}. -next_tls_record(<<>>, _State) -> - #alert{}; %% Place holder -next_tls_record(_, State) -> - {#ssl_tls{fragment = <<"place holder">>}, State}. +next_tls_record(Data, #state{protocol_buffers = #protocol_buffers{ + dtls_record_buffer = Buf0, + dtls_cipher_texts = CT0} = Buffers} = State0) -> + case dtls_record:get_dtls_records(Data, Buf0) of + {Records, Buf1} -> + CT1 = CT0 ++ Records, + next_record(State0#state{protocol_buffers = + Buffers#protocol_buffers{dtls_record_buffer = Buf1, + dtls_cipher_texts = CT1}}); + + #alert{} = Alert -> + Alert + end. + +next_record(#state{%%flight = #flight{state = finished}, + protocol_buffers = + #protocol_buffers{dtls_packets = [], dtls_cipher_texts = [CT | Rest]} + = Buffers, + connection_states = ConnStates0} = State) -> + case dtls_record:decode_cipher_text(CT, ConnStates0) of + {Plain, ConnStates} -> + {Plain, State#state{protocol_buffers = + Buffers#protocol_buffers{dtls_cipher_texts = Rest}, + connection_states = ConnStates}}; + #alert{} = Alert -> + {Alert, State} + end; +next_record(#state{socket = Socket, + transport_cb = Transport} = State) -> %% when FlightState =/= finished + ssl_socket:setopts(Transport, Socket, [{active,once}]), + {no_record, State}; +next_record(State) -> + {no_record, State}. -sequence(_) -> - %%TODO real imp - 1. -next_record_if_active(State = - #state{socket_options = - #socket_options{active = false}}) -> +next_record_if_active(State = + #state{socket_options = + #socket_options{active = false}}) -> {no_record ,State}; next_record_if_active(State) -> next_record(State). + +passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName) -> + case Buffer of + <<>> -> + {Record, State} = next_record(State0), + next_event(StateName, Record, State); + _ -> + {Record, State} = read_application_data(<<>>, State0), + next_event(StateName, Record, State) + end. + +next_event(StateName, Record, State) -> + next_event(StateName, Record, State, []). + +next_event(connection = StateName, no_record, State0, Actions) -> + case next_record_if_active(State0) of + {no_record, State} -> + ssl_connection:hibernate_after(StateName, State, Actions); + {#ssl_tls{} = Record, State} -> + {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; + {#alert{} = Alert, State} -> + {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} + end; +next_event(StateName, Record, State, Actions) -> + case Record of + no_record -> + {next_state, StateName, State, Actions}; + #ssl_tls{} = Record -> + {next_state, StateName, State, [{next_event, internal, {dtls_record, Record}} | Actions]}; + #alert{} = Alert -> + {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} + end. + +read_application_data(_,State) -> + {#ssl_tls{fragment = <<"place holder">>}, State}. + +handle_own_alert(_,_,_, State) -> %% Place holder + {stop, {shutdown, own_alert}, State}. + +handle_normal_shutdown(_, _, _State) -> %% Place holder + ok. + +%% TODO This generates dialyzer warnings, has to be handled differently. +%% handle_packet(Address, Port, Packet) -> +%% try dtls_record:get_dtls_records(Packet, <<>>) of +%% %% expect client hello +%% {[#ssl_tls{type = ?HANDSHAKE, version = {254, _}} = Record], <<>>} -> +%% handle_dtls_client_hello(Address, Port, Record); +%% _Other -> +%% {error, not_dtls} +%% catch +%% _Class:_Error -> +%% {error, not_dtls} +%% end. + +%% handle_dtls_client_hello(Address, Port, +%% #ssl_tls{epoch = Epoch, sequence_number = Seq, +%% version = Version} = Record) -> +%% {[{Hello, _}], _} = +%% dtls_handshake:get_dtls_handshake(Record, +%% dtls_handshake:dtls_handshake_new_flight(undefined)), +%% #client_hello{client_version = {Major, Minor}, +%% random = Random, +%% session_id = SessionId, +%% cipher_suites = CipherSuites, +%% compression_methods = CompressionMethods} = Hello, +%% CookieData = [address_to_bin(Address, Port), +%% <<?BYTE(Major), ?BYTE(Minor)>>, +%% Random, SessionId, CipherSuites, CompressionMethods], +%% Cookie = crypto:hmac(sha, <<"secret">>, CookieData), + +%% case Hello of +%% #client_hello{cookie = Cookie} -> +%% accept; + +%% _ -> +%% %% generate HelloVerifyRequest +%% {RequestFragment, _} = dtls_handshake:encode_handshake( +%% dtls_handshake:hello_verify_request(Cookie), +%% Version, 0), +%% HelloVerifyRequest = +%% dtls_record:encode_tls_cipher_text(?HANDSHAKE, Version, Epoch, Seq, RequestFragment), +%% {reply, HelloVerifyRequest} +%% end. + +%% address_to_bin({A,B,C,D}, Port) -> +%% <<0:80,16#ffff:16,A,B,C,D,Port:16>>; +%% address_to_bin({A,B,C,D,E,F,G,H}, Port) -> +%% <<A:16,B:16,C:16,D:16,E:16,F:16,G:16,H:16,Port:16>>. + +sequence(#connection_states{dtls_write_msg_seq = Seq} = CS) -> + {Seq, CS#connection_states{dtls_write_msg_seq = Seq + 1}}. |