diff options
Diffstat (limited to 'lib/ssl/src/dtls_record.erl')
-rw-r--r-- | lib/ssl/src/dtls_record.erl | 343 |
1 files changed, 219 insertions, 124 deletions
diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl index f447897d59..9eb0d8e2d7 100644 --- a/lib/ssl/src/dtls_record.erl +++ b/lib/ssl/src/dtls_record.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2013-2016. All Rights Reserved. +%% Copyright Ericsson AB 2013-2018. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -30,35 +30,36 @@ -include("ssl_cipher.hrl"). %% Handling of incoming data --export([get_dtls_records/2, init_connection_states/2]). +-export([get_dtls_records/3, init_connection_states/2, empty_connection_state/1]). -%% Decoding --export([decode_cipher_text/2]). +-export([save_current_connection_state/2, next_epoch/2, get_connection_state_by_epoch/3, replay_detect/2, + init_connection_state_seq/2, current_connection_state_epoch/2]). %% Encoding -export([encode_handshake/4, encode_alert_record/3, - encode_change_cipher_spec/3, encode_data/3]). --export([encode_plain_text/5]). + encode_change_cipher_spec/3, encode_data/3, encode_plain_text/5]). + +%% Decoding +-export([decode_cipher_text/2]). %% Protocol version handling -export([protocol_version/1, lowest_protocol_version/1, lowest_protocol_version/2, highest_protocol_version/1, highest_protocol_version/2, is_higher/2, supported_protocol_versions/0, - is_acceptable_version/2]). - --export([save_current_connection_state/2, next_epoch/2]). + is_acceptable_version/2, hello_version/2]). --export([init_connection_state_seq/2, current_connection_state_epoch/2]). -export_type([dtls_version/0, dtls_atom_version/0]). -type dtls_version() :: ssl_record:ssl_version(). -type dtls_atom_version() :: dtlsv1 | 'dtlsv1.2'. +-define(REPLAY_WINDOW_SIZE, 64). + -compile(inline). %%==================================================================== -%% Internal application API +%% Handling of incoming data %%==================================================================== %%-------------------------------------------------------------------- -spec init_connection_states(client | server, one_n_minus_one | zero_n | disabled) -> @@ -73,7 +74,7 @@ init_connection_states(Role, BeastMitigation) -> Initial = initial_connection_state(ConnectionEnd, BeastMitigation), Current = Initial#{epoch := 0}, InitialPending = ssl_record:empty_connection_state(ConnectionEnd, BeastMitigation), - Pending = InitialPending#{epoch => undefined}, + Pending = empty_connection_state(InitialPending), #{saved_read => Current, current_read => Current, pending_read => Pending, @@ -81,6 +82,9 @@ init_connection_states(Role, BeastMitigation) -> current_write => Current, pending_write => Pending}. +empty_connection_state(Empty) -> + Empty#{epoch => undefined, replay_window => init_replay_window(?REPLAY_WINDOW_SIZE)}. + %%-------------------------------------------------------------------- -spec save_current_connection_state(ssl_record:connection_states(), read | write) -> ssl_record:connection_states(). @@ -96,11 +100,13 @@ save_current_connection_state(#{current_write := Current} = States, write) -> next_epoch(#{pending_read := Pending, current_read := #{epoch := Epoch}} = States, read) -> - States#{pending_read := Pending#{epoch := Epoch + 1}}; + States#{pending_read := Pending#{epoch := Epoch + 1, + replay_window := init_replay_window(?REPLAY_WINDOW_SIZE)}}; next_epoch(#{pending_write := Pending, current_write := #{epoch := Epoch}} = States, write) -> - States#{pending_write := Pending#{epoch := Epoch + 1}}. + States#{pending_write := Pending#{epoch := Epoch + 1, + replay_window := init_replay_window(?REPLAY_WINDOW_SIZE)}}. get_connection_state_by_epoch(Epoch, #{current_write := #{epoch := Epoch} = Current}, write) -> @@ -129,67 +135,58 @@ set_connection_state_by_epoch(ReadState, Epoch, #{saved_read := #{epoch := Epoch States#{saved_read := ReadState}. %%-------------------------------------------------------------------- --spec get_dtls_records(binary(), binary()) -> {[binary()], binary()} | #alert{}. +-spec init_connection_state_seq(dtls_version(), ssl_record:connection_states()) -> + ssl_record:connection_state(). +%% +%% Description: Copy the read sequence number to the write sequence number +%% This is only valid for DTLS in the first client_hello +%%-------------------------------------------------------------------- +init_connection_state_seq({254, _}, + #{current_read := #{epoch := 0, sequence_number := Seq}, + current_write := #{epoch := 0} = Write} = ConnnectionStates0) -> + ConnnectionStates0#{current_write => Write#{sequence_number => Seq}}; +init_connection_state_seq(_, ConnnectionStates) -> + ConnnectionStates. + +%%-------------------------------------------------------- +-spec current_connection_state_epoch(ssl_record:connection_states(), read | write) -> + integer(). +%% +%% Description: Returns the epoch the connection_state record +%% that is currently defined as the current connection state. +%%-------------------------------------------------------------------- +current_connection_state_epoch(#{current_read := #{epoch := Epoch}}, + read) -> + Epoch; +current_connection_state_epoch(#{current_write := #{epoch := Epoch}}, + write) -> + Epoch. + +%%-------------------------------------------------------------------- +-spec get_dtls_records(binary(), [dtls_version()], binary()) -> {[binary()], binary()} | #alert{}. %% %% Description: Given old buffer and new data from UDP/SCTP, packs up a records %% and returns it as a list of tls_compressed binaries also returns leftover %% data %%-------------------------------------------------------------------- -get_dtls_records(Data, <<>>) -> - get_dtls_records_aux(Data, []); -get_dtls_records(Data, Buffer) -> - get_dtls_records_aux(list_to_binary([Buffer, Data]), []). - -get_dtls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer), - ?UINT16(Epoch), ?UINT48(SequenceNumber), - ?UINT16(Length), Data:Length/binary, Rest/binary>>, - Acc) -> - get_dtls_records_aux(Rest, [#ssl_tls{type = ?APPLICATION_DATA, - version = {MajVer, MinVer}, - epoch = Epoch, sequence_number = SequenceNumber, - fragment = Data} | Acc]); -get_dtls_records_aux(<<?BYTE(?HANDSHAKE),?BYTE(MajVer),?BYTE(MinVer), - ?UINT16(Epoch), ?UINT48(SequenceNumber), - ?UINT16(Length), - Data:Length/binary, Rest/binary>>, Acc) when MajVer >= 128 -> - get_dtls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE, - version = {MajVer, MinVer}, - epoch = Epoch, sequence_number = SequenceNumber, - fragment = Data} | Acc]); -get_dtls_records_aux(<<?BYTE(?ALERT),?BYTE(MajVer),?BYTE(MinVer), - ?UINT16(Epoch), ?UINT48(SequenceNumber), - ?UINT16(Length), Data:Length/binary, - Rest/binary>>, Acc) -> - get_dtls_records_aux(Rest, [#ssl_tls{type = ?ALERT, - version = {MajVer, MinVer}, - epoch = Epoch, sequence_number = SequenceNumber, - fragment = Data} | Acc]); -get_dtls_records_aux(<<?BYTE(?CHANGE_CIPHER_SPEC),?BYTE(MajVer),?BYTE(MinVer), - ?UINT16(Epoch), ?UINT48(SequenceNumber), - ?UINT16(Length), Data:Length/binary, Rest/binary>>, - Acc) -> - get_dtls_records_aux(Rest, [#ssl_tls{type = ?CHANGE_CIPHER_SPEC, - version = {MajVer, MinVer}, - epoch = Epoch, sequence_number = SequenceNumber, - fragment = Data} | Acc]); - -get_dtls_records_aux(<<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer), - ?UINT16(Length), _/binary>>, - _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH -> - ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); - -get_dtls_records_aux(<<1:1, Length0:15, _/binary>>,_Acc) - when Length0 > ?MAX_CIPHER_TEXT_LENGTH -> - ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); - -get_dtls_records_aux(Data, Acc) -> - case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of - true -> - {lists:reverse(Acc), Data}; - false -> - ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE) +get_dtls_records(Data, Versions, Buffer) -> + BinData = list_to_binary([Buffer, Data]), + case erlang:byte_size(BinData) of + N when N >= 3 -> + case assert_version(BinData, Versions) of + true -> + get_dtls_records_aux(BinData, []); + false -> + ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) + end; + _ -> + get_dtls_records_aux(BinData, []) end. +%%==================================================================== +%% Encoding DTLS records +%%==================================================================== + %%-------------------------------------------------------------------- -spec encode_handshake(iolist(), dtls_version(), integer(), ssl_record:connection_states()) -> {iolist(), ssl_record:connection_states()}. @@ -237,11 +234,19 @@ encode_plain_text(Type, Version, Epoch, Data, ConnectionStates) -> {CipherText, Write} = encode_dtls_cipher_text(Type, Version, CipherFragment, Write1), {CipherText, set_connection_state_by_epoch(Write, Epoch, ConnectionStates, write)}. +%%==================================================================== +%% Decoding +%%==================================================================== decode_cipher_text(#ssl_tls{epoch = Epoch} = CipherText, ConnnectionStates0) -> ReadState = get_connection_state_by_epoch(Epoch, ConnnectionStates0, read), decode_cipher_text(CipherText, ReadState, ConnnectionStates0). + +%%==================================================================== +%% Protocol version handling +%%==================================================================== + %%-------------------------------------------------------------------- -spec protocol_version(dtls_atom_version() | dtls_version()) -> dtls_version() | dtls_atom_version(). @@ -373,34 +378,15 @@ supported_protocol_versions([_|_] = Vsns) -> is_acceptable_version(Version, Versions) -> lists:member(Version, Versions). +-spec hello_version(dtls_version(), [dtls_version()]) -> dtls_version(). +hello_version(Version, Versions) -> + case dtls_v1:corresponding_tls_version(Version) of + TLSVersion when TLSVersion >= {3, 3} -> + Version; + _ -> + lowest_protocol_version(Versions) + end. -%%-------------------------------------------------------------------- --spec init_connection_state_seq(dtls_version(), ssl_record:connection_states()) -> - ssl_record:connection_state(). -%% -%% Description: Copy the read sequence number to the write sequence number -%% This is only valid for DTLS in the first client_hello -%%-------------------------------------------------------------------- -init_connection_state_seq({254, _}, - #{current_read := #{epoch := 0, sequence_number := Seq}, - current_write := #{epoch := 0} = Write} = ConnnectionStates0) -> - ConnnectionStates0#{current_write => Write#{sequence_number => Seq}}; -init_connection_state_seq(_, ConnnectionStates) -> - ConnnectionStates. - -%%-------------------------------------------------------- --spec current_connection_state_epoch(ssl_record:connection_states(), read | write) -> - integer(). -%% -%% Description: Returns the epoch the connection_state record -%% that is currently defined as the current conection state. -%%-------------------------------------------------------------------- -current_connection_state_epoch(#{current_read := #{epoch := Epoch}}, - read) -> - Epoch; -current_connection_state_epoch(#{current_write := #{epoch := Epoch}}, - write) -> - Epoch. %%-------------------------------------------------------------------- %%% Internal functions @@ -410,6 +396,7 @@ initial_connection_state(ConnectionEnd, BeastMitigation) -> ssl_record:initial_security_params(ConnectionEnd), epoch => undefined, sequence_number => 0, + replay_window => init_replay_window(?REPLAY_WINDOW_SIZE), beast_mitigation => BeastMitigation, compression_state => undefined, cipher_state => undefined, @@ -418,16 +405,91 @@ initial_connection_state(ConnectionEnd, BeastMitigation) -> client_verify_data => undefined, server_verify_data => undefined }. +assert_version(<<?BYTE(_), ?BYTE(MajVer), ?BYTE(MinVer), _/binary>>, Versions) -> + is_acceptable_version({MajVer, MinVer}, Versions). -lowest_list_protocol_version(Ver, []) -> - Ver; -lowest_list_protocol_version(Ver1, [Ver2 | Rest]) -> - lowest_list_protocol_version(lowest_protocol_version(Ver1, Ver2), Rest). +get_dtls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer), + ?UINT16(Epoch), ?UINT48(SequenceNumber), + ?UINT16(Length), Data:Length/binary, Rest/binary>>, + Acc) -> + get_dtls_records_aux(Rest, [#ssl_tls{type = ?APPLICATION_DATA, + version = {MajVer, MinVer}, + epoch = Epoch, sequence_number = SequenceNumber, + fragment = Data} | Acc]); +get_dtls_records_aux(<<?BYTE(?HANDSHAKE),?BYTE(MajVer),?BYTE(MinVer), + ?UINT16(Epoch), ?UINT48(SequenceNumber), + ?UINT16(Length), + Data:Length/binary, Rest/binary>>, Acc) when MajVer >= 128 -> + get_dtls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE, + version = {MajVer, MinVer}, + epoch = Epoch, sequence_number = SequenceNumber, + fragment = Data} | Acc]); +get_dtls_records_aux(<<?BYTE(?ALERT),?BYTE(MajVer),?BYTE(MinVer), + ?UINT16(Epoch), ?UINT48(SequenceNumber), + ?UINT16(Length), Data:Length/binary, + Rest/binary>>, Acc) -> + get_dtls_records_aux(Rest, [#ssl_tls{type = ?ALERT, + version = {MajVer, MinVer}, + epoch = Epoch, sequence_number = SequenceNumber, + fragment = Data} | Acc]); +get_dtls_records_aux(<<?BYTE(?CHANGE_CIPHER_SPEC),?BYTE(MajVer),?BYTE(MinVer), + ?UINT16(Epoch), ?UINT48(SequenceNumber), + ?UINT16(Length), Data:Length/binary, Rest/binary>>, + Acc) -> + get_dtls_records_aux(Rest, [#ssl_tls{type = ?CHANGE_CIPHER_SPEC, + version = {MajVer, MinVer}, + epoch = Epoch, sequence_number = SequenceNumber, + fragment = Data} | Acc]); +get_dtls_records_aux(<<?BYTE(_), ?BYTE(_MajVer), ?BYTE(_MinVer), + ?UINT16(Length), _/binary>>, + _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH -> + ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); -highest_list_protocol_version(Ver, []) -> - Ver; -highest_list_protocol_version(Ver1, [Ver2 | Rest]) -> - highest_list_protocol_version(highest_protocol_version(Ver1, Ver2), Rest). +get_dtls_records_aux(Data, Acc) -> + case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of + true -> + {lists:reverse(Acc), Data}; + false -> + ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE) + end. +%%-------------------------------------------------------------------- + +init_replay_window(Size) -> + #{size => Size, + top => Size, + bottom => 0, + mask => 0 bsl 64 + }. + +replay_detect(#ssl_tls{sequence_number = SequenceNumber}, #{replay_window := Window}) -> + is_replay(SequenceNumber, Window). + + +is_replay(SequenceNumber, #{bottom := Bottom}) when SequenceNumber < Bottom -> + true; +is_replay(SequenceNumber, #{size := Size, + top := Top, + bottom := Bottom, + mask := Mask}) when (SequenceNumber >= Bottom) andalso (SequenceNumber =< Top) -> + Index = (SequenceNumber rem Size), + (Index band Mask) == 1; + +is_replay(_, _) -> + false. + +update_replay_window(SequenceNumber, #{replay_window := #{size := Size, + top := Top, + bottom := Bottom, + mask := Mask0} = Window0} = ConnectionStates) -> + NoNewBits = SequenceNumber - Top, + Index = SequenceNumber rem Size, + Mask = (Mask0 bsl NoNewBits) bor Index, + Window = Window0#{top => SequenceNumber, + bottom => Bottom + NoNewBits, + mask => Mask}, + ConnectionStates#{replay_window := Window}. + +%%-------------------------------------------------------------------- encode_dtls_cipher_text(Type, {MajVer, MinVer}, Fragment, #{epoch := Epoch, sequence_number := Seq} = WriteState) -> @@ -439,43 +501,61 @@ encode_dtls_cipher_text(Type, {MajVer, MinVer}, Fragment, encode_plain_text(Type, Version, Data, #{compression_state := CompS0, epoch := Epoch, sequence_number := Seq, + cipher_state := CipherS0, security_parameters := #security_parameters{ cipher_type = ?AEAD, + bulk_cipher_algorithm = + BulkCipherAlgo, compression_algorithm = CompAlg} } = WriteState0) -> {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), - WriteState1 = WriteState0#{compression_state => CompS1}, AAD = calc_aad(Type, Version, Epoch, Seq), - ssl_record:cipher_aead(dtls_v1:corresponding_tls_version(Version), Comp, WriteState1, AAD); -encode_plain_text(Type, Version, Data, #{compression_state := CompS0, + TLSVersion = dtls_v1:corresponding_tls_version(Version), + {CipherFragment, CipherS1} = + ssl_cipher:cipher_aead(BulkCipherAlgo, CipherS0, Seq, AAD, Comp, TLSVersion), + {CipherFragment, WriteState0#{compression_state => CompS1, + cipher_state => CipherS1}}; +encode_plain_text(Type, Version, Fragment, #{compression_state := CompS0, epoch := Epoch, sequence_number := Seq, + cipher_state := CipherS0, security_parameters := - #security_parameters{compression_algorithm = CompAlg} + #security_parameters{compression_algorithm = CompAlg, + bulk_cipher_algorithm = + BulkCipherAlgo} }= WriteState0) -> - {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), + {Comp, CompS1} = ssl_record:compress(CompAlg, Fragment, CompS0), WriteState1 = WriteState0#{compression_state => CompS1}, - MacHash = calc_mac_hash(Type, Version, WriteState1, Epoch, Seq, Comp), - ssl_record:cipher(dtls_v1:corresponding_tls_version(Version), Comp, WriteState1, MacHash). + MAC = calc_mac_hash(Type, Version, WriteState1, Epoch, Seq, Comp), + TLSVersion = dtls_v1:corresponding_tls_version(Version), + {CipherFragment, CipherS1} = + ssl_cipher:cipher(BulkCipherAlgo, CipherS0, MAC, Fragment, TLSVersion), + {CipherFragment, WriteState0#{cipher_state => CipherS1}}. +%%-------------------------------------------------------------------- decode_cipher_text(#ssl_tls{type = Type, version = Version, epoch = Epoch, sequence_number = Seq, fragment = CipherFragment} = CipherText, #{compression_state := CompressionS0, + cipher_state := CipherS0, security_parameters := #security_parameters{ cipher_type = ?AEAD, + bulk_cipher_algorithm = + BulkCipherAlgo, compression_algorithm = CompAlg}} = ReadState0, ConnnectionStates0) -> AAD = calc_aad(Type, Version, Epoch, Seq), - case ssl_record:decipher_aead(dtls_v1:corresponding_tls_version(Version), - CipherFragment, ReadState0, AAD) of - {PlainFragment, ReadState1} -> + TLSVersion = dtls_v1:corresponding_tls_version(Version), + case ssl_cipher:decipher_aead(BulkCipherAlgo, CipherS0, Seq, AAD, CipherFragment, TLSVersion) of + {PlainFragment, CipherState} -> {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, PlainFragment, CompressionS0), - ReadState = ReadState1#{compression_state => CompressionS1}, + ReadState0 = ReadState0#{compression_state => CompressionS1, + cipher_state => CipherState}, + ReadState = update_replay_window(Seq, ReadState0), ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read), {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; #alert{} = Alert -> @@ -498,21 +578,43 @@ decode_cipher_text(#ssl_tls{type = Type, version = Version, {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, PlainFragment, CompressionS0), - ReadState = ReadState1#{compression_state => CompressionS1}, + ReadState2 = ReadState1#{compression_state => CompressionS1}, + ReadState = update_replay_window(Seq, ReadState2), ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read), {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; false -> ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) end. +%%-------------------------------------------------------------------- calc_mac_hash(Type, Version, #{mac_secret := MacSecret, security_parameters := #security_parameters{mac_algorithm = MacAlg}}, Epoch, SeqNo, Fragment) -> Length = erlang:iolist_size(Fragment), - NewSeq = (Epoch bsl 48) + SeqNo, - mac_hash(Version, MacAlg, MacSecret, NewSeq, Type, + mac_hash(Version, MacAlg, MacSecret, Epoch, SeqNo, Type, Length, Fragment). +mac_hash({Major, Minor}, MacAlg, MacSecret, Epoch, SeqNo, Type, Length, Fragment) -> + Value = [<<?UINT16(Epoch), ?UINT48(SeqNo), ?BYTE(Type), + ?BYTE(Major), ?BYTE(Minor), ?UINT16(Length)>>, + Fragment], + dtls_v1:hmac_hash(MacAlg, MacSecret, Value). + +calc_aad(Type, {MajVer, MinVer}, Epoch, SeqNo) -> + <<?UINT16(Epoch), ?UINT48(SeqNo), ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. + +%%-------------------------------------------------------------------- + +lowest_list_protocol_version(Ver, []) -> + Ver; +lowest_list_protocol_version(Ver1, [Ver2 | Rest]) -> + lowest_list_protocol_version(lowest_protocol_version(Ver1, Ver2), Rest). + +highest_list_protocol_version(Ver, []) -> + Ver; +highest_list_protocol_version(Ver1, [Ver2 | Rest]) -> + highest_list_protocol_version(highest_protocol_version(Ver1, Ver2), Rest). + highest_protocol_version() -> highest_protocol_version(supported_protocol_versions()). @@ -523,10 +625,3 @@ sufficient_dtlsv1_2_crypto_support() -> CryptoSupport = crypto:supports(), proplists:get_bool(sha256, proplists:get_value(hashs, CryptoSupport)). -mac_hash(Version, MacAlg, MacSecret, SeqNo, Type, Length, Fragment) -> - dtls_v1:mac_hash(Version, MacAlg, MacSecret, SeqNo, Type, - Length, Fragment). - -calc_aad(Type, {MajVer, MinVer}, Epoch, SeqNo) -> - NewSeq = (Epoch bsl 48) + SeqNo, - <<NewSeq:64/integer, ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. |