diff options
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 81 | ||||
-rw-r--r-- | lib/ssl/src/dtls_handshake.erl | 17 | ||||
-rw-r--r-- | lib/ssl/src/dtls_record.erl | 10 | ||||
-rw-r--r-- | lib/ssl/src/ssl_record.erl | 15 |
4 files changed, 88 insertions, 35 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index 014e915e12..3ebc340ff8 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -54,7 +54,7 @@ %% Data handling -export([%%write_application_data/3, read_application_data/2, - %%passive_receive/2, + passive_receive/2, next_record_if_active/1 %%, %%handle_common_event/4 ]). @@ -356,20 +356,81 @@ format_status(Type, Data) -> %%% Internal functions %%-------------------------------------------------------------------- encode_handshake(Handshake, Version, ConnectionStates0, Hist0) -> - Seq = sequence(ConnectionStates0), - {EncHandshake, FragmentedHandshake} = dtls_handshake:encode_handshake(Handshake, Version, - Seq), + {Seq, ConnectionStates} = sequence(ConnectionStates0), + {EncHandshake, Frag} = dtls_handshake:encode_handshake(Handshake, Version, Seq), Hist = ssl_handshake:update_handshake_history(Hist0, EncHandshake), - {Encoded, ConnectionStates} = - dtls_record:encode_handshake(FragmentedHandshake, - Version, ConnectionStates0), - {Encoded, ConnectionStates, Hist}. + {Frag, ConnectionStates, Hist}. encode_change_cipher(#change_cipher_spec{}, Version, ConnectionStates) -> dtls_record:encode_change_cipher_spec(Version, ConnectionStates). -decode_alerts(Bin) -> - ssl_alert:decode(Bin). +encode_handshake_flight(Flight, ConnectionStates) -> + MSS = 1400, + encode_handshake_records(Flight, ConnectionStates, MSS, init_pack_records()). + +encode_handshake_records([], CS, _MSS, Recs) -> + {finish_pack_records(Recs), CS}; + +encode_handshake_records([{Version, _Epoch, Frag = #change_cipher_spec{}}|Tail], ConnectionStates0, MSS, Recs0) -> + {Encoded, ConnectionStates} = + encode_change_cipher(Frag, Version, ConnectionStates0), + Recs = append_pack_records([Encoded], MSS, Recs0), + encode_handshake_records(Tail, ConnectionStates, MSS, Recs); + +encode_handshake_records([{Version, Epoch, {MsgType, MsgSeq, Bin}}|Tail], CS0, MSS, Recs0 = {Buf0, _}) -> + Space = MSS - iolist_size(Buf0), + Len = byte_size(Bin), + {Encoded, CS} = + encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, 0, MSS, [], CS0), + Recs = append_pack_records(Encoded, MSS, Recs0), + encode_handshake_records(Tail, CS, MSS, Recs). + +%% TODO: move to dtls_handshake???? +encode_handshake_record(_Version, _Epoch, _Space, _MsgType, _MsgSeq, _Len, <<>>, _Offset, _MRS, Encoded, CS) + when length(Encoded) > 0 -> + %% make sure we encode at least one segment (for empty messages like Server Hello Done + {lists:reverse(Encoded), CS}; + +encode_handshake_record(Version, Epoch, Space, MsgType, MsgSeq, Len, Bin, + Offset, MRS, Encoded0, CS0) -> + MaxFragmentLen = Space - 25, + case Bin of + <<BinFragment:MaxFragmentLen/bytes, Rest/binary>> -> + ok; + _ -> + BinFragment = Bin, + Rest = <<>> + end, + FragLength = byte_size(BinFragment), + Frag = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(Offset), ?uint24(FragLength), BinFragment], + {Encoded, CS} = ssl_record:encode_handshake({Epoch, Frag}, Version, CS0), + encode_handshake_record(Version, Epoch, MRS, MsgType, MsgSeq, Len, Rest, Offset + FragLength, MRS, [Encoded|Encoded0], CS). + +init_pack_records() -> + {[], []}. + +append_pack_records([], MSS, Recs = {Buf0, Acc0}) -> + Remaining = MSS - iolist_size(Buf0), + if Remaining < 12 -> + {[], [lists:reverse(Buf0)|Acc0]}; + true -> + Recs + end; +append_pack_records([Head|Tail], MSS, {Buf0, Acc0}) -> + TotLen = iolist_size(Buf0) + iolist_size(Head), + if TotLen > MSS -> + append_pack_records(Tail, MSS, {[Head], [lists:reverse(Buf0)|Acc0]}); + true -> + append_pack_records(Tail, MSS, {[Head|Buf0], Acc0}) + end. + +finish_pack_records({[], Acc}) -> + lists:reverse(Acc); +finish_pack_records({Buf, Acc}) -> + lists:reverse([lists:reverse(Buf)|Acc]). + +%% decode_alerts(Bin) -> +%% ssl_alert:decode(Bin). initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions}, User, {CbModule, DataTag, CloseTag, ErrorTag}) -> diff --git a/lib/ssl/src/dtls_handshake.erl b/lib/ssl/src/dtls_handshake.erl index 4f48704cac..dbb03096ab 100644 --- a/lib/ssl/src/dtls_handshake.erl +++ b/lib/ssl/src/dtls_handshake.erl @@ -136,9 +136,9 @@ hello(#client_hello{client_version = ClientVersion}, _Options, {_,_,_,_,Connecti encode_handshake(Handshake, Version, MsgSeq) -> {MsgType, Bin} = enc_handshake(Handshake, Version), Len = byte_size(Bin), - EncHandshake = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(0), ?uint24(Len), Bin], - FragmentedHandshake = dtls_fragment(erlang:iolist_size(EncHandshake), MsgType, Len, MsgSeq, Bin, 0, []), - {EncHandshake, FragmentedHandshake}. + Enc = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(0), ?uint24(Len), Bin], + Frag = {MsgType, MsgSeq, Bin}, + {Enc, Frag}. %%-------------------------------------------------------------------- -spec get_dtls_handshake(#ssl_tls{}, #dtls_hs_state{} | binary()) -> @@ -189,17 +189,6 @@ handle_server_hello_extensions(Version, SessionId, Random, CipherSuite, {Version, SessionId, ConnectionStates, ProtoExt, Protocol} end. -dtls_fragment(Mss, MsgType, Len, MsgSeq, Bin, Offset, Acc) - when byte_size(Bin) + 12 < Mss -> - FragmentLen = byte_size(Bin), - BinMsg = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(Offset), ?uint24(FragmentLen), Bin], - lists:reverse([BinMsg|Acc]); -dtls_fragment(Mss, MsgType, Len, MsgSeq, Bin, Offset, Acc) -> - FragmentLen = Mss - 12, - <<Fragment:FragmentLen/bytes, Rest/binary>> = Bin, - BinMsg = [MsgType, ?uint24(Len), ?uint16(MsgSeq), ?uint24(Offset), ?uint24(FragmentLen), Fragment], - dtls_fragment(Mss, MsgType, Len, MsgSeq, Rest, Offset + FragmentLen, [BinMsg|Acc]). - get_dtls_handshake_aux(#ssl_tls{version = Version, sequence_number = SeqNo, fragment = Data}, HsState) -> diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl index e79e1cede0..ed8024d892 100644 --- a/lib/ssl/src/dtls_record.erl +++ b/lib/ssl/src/dtls_record.erl @@ -36,7 +36,7 @@ -export([decode_cipher_text/2]). %% Encoding --export([encode_plain_text/4, encode_handshake/3, encode_change_cipher_spec/2]). +-export([encode_plain_text/4, encode_change_cipher_spec/2]). %% Protocol version handling -export([protocol_version/1, lowest_protocol_version/2, lowest_protocol_version/1, @@ -208,14 +208,6 @@ decode_cipher_text(#ssl_tls{type = Type, version = Version, false -> ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) end. -%%-------------------------------------------------------------------- --spec encode_handshake(iolist(), dtls_version(), #connection_states{}) -> - {iolist(), #connection_states{}}. -%% -%% Description: Encodes a handshake message to send on the ssl-socket. -%%-------------------------------------------------------------------- -encode_handshake(Frag, Version, ConnectionStates) -> - encode_plain_text(?HANDSHAKE, Version, Frag, ConnectionStates). %%-------------------------------------------------------------------- -spec encode_change_cipher_spec(dtls_version(), #connection_states{}) -> diff --git a/lib/ssl/src/ssl_record.erl b/lib/ssl/src/ssl_record.erl index 0a086f5eeb..2bd282c664 100644 --- a/lib/ssl/src/ssl_record.erl +++ b/lib/ssl/src/ssl_record.erl @@ -320,14 +320,25 @@ encode_handshake(Frag, Version, beast_mitigation = BeastMitigation, security_parameters = #security_parameters{bulk_cipher_algorithm = BCA}}} = - ConnectionStates) -> + ConnectionStates) +when is_list(Frag) -> case iolist_size(Frag) of N when N > ?MAX_PLAIN_TEXT_LENGTH -> Data = split_bin(iolist_to_binary(Frag), ?MAX_PLAIN_TEXT_LENGTH, Version, BCA, BeastMitigation), encode_iolist(?HANDSHAKE, Data, Version, ConnectionStates); _ -> encode_plain_text(?HANDSHAKE, Version, Frag, ConnectionStates) - end. + end; +%% TODO: this is a workarround for DTLS +%% +%% DTLS need to select the connection write state based on Epoch it wants to +%% send this fragment in. That Epoch does not nessarily has to be the same +%% as the current_write epoch. +%% The right solution might be to pass the WriteState instead of the ConnectionStates, +%% however, this will require substantion API changes. +encode_handshake(Frag, Version, ConnectionStates) -> + encode_plain_text(?HANDSHAKE, Version, Frag, ConnectionStates). + %%-------------------------------------------------------------------- -spec encode_alert_record(#alert{}, ssl_version(), #connection_states{}) -> {iolist(), #connection_states{}}. |