aboutsummaryrefslogtreecommitdiffstats
path: root/lib/ssl/src/dtls_record.erl
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ssl/src/dtls_record.erl')
-rw-r--r--lib/ssl/src/dtls_record.erl349
1 files changed, 204 insertions, 145 deletions
diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl
index daadae0725..f667458a10 100644
--- a/lib/ssl/src/dtls_record.erl
+++ b/lib/ssl/src/dtls_record.erl
@@ -15,104 +15,134 @@
%% under the License.
%%
%% %CopyrightEnd%
+
+%%
+%%----------------------------------------------------------------------
+%% Purpose: Handle DTLS record protocol. (Parts that are not shared with SSL/TLS)
+%%----------------------------------------------------------------------
-module(dtls_record).
-include("dtls_record.hrl").
-include("ssl_internal.hrl").
-include("ssl_alert.hrl").
-
--export([init_connection_state_seq/2, current_connection_state_epoch/2,
- set_connection_state_by_epoch/3, connection_state_by_epoch/3]).
+-include("dtls_handshake.hrl").
+-include("ssl_cipher.hrl").
%% Handling of incoming data
-export([get_dtls_records/2]).
-%% Misc.
+%% Decoding
+-export([decode_cipher_text/2]).
+
+%% Encoding
+-export([encode_plain_text/4]).
+
+%% Protocol version handling
-export([protocol_version/1, lowest_protocol_version/2,
highest_protocol_version/1, supported_protocol_versions/0,
is_acceptable_version/2, cipher/4, decipher/2]).
-%%--------------------------------------------------------------------
--spec init_connection_state_seq(tls_version(), #connection_states{}) ->
- #connection_state{}.
-%%
-%% Description: Copy the read sequence number to the write sequence number
-%% This is only valid for DTLS in the first client_hello
-%%--------------------------------------------------------------------
-init_connection_state_seq({254, _},
- #connection_states{
- current_read = Read = #connection_state{epoch = 0},
- current_write = Write = #connection_state{epoch = 0}} = CS0) ->
- CS0#connection_states{current_write =
- Write#connection_state{
- sequence_number = Read#connection_state.sequence_number}};
-init_connection_state_seq(_, CS) ->
- CS.
+-export([init_connection_state_seq/2, current_connection_state_epoch/2,
+ set_connection_state_by_epoch/3, connection_state_by_epoch/3]).
-%%--------------------------------------------------------
--spec current_connection_state_epoch(#connection_states{}, read | write) ->
- integer().
-%%
-%% Description: Returns the epoch the connection_state record
-%% that is currently defined as the current conection state.
-%%--------------------------------------------------------------------
-current_connection_state_epoch(#connection_states{current_read = Current},
- read) ->
- Current#connection_state.epoch;
-current_connection_state_epoch(#connection_states{current_write = Current},
- write) ->
- Current#connection_state.epoch.
+-compile(inline).
-%%--------------------------------------------------------------------
+%%====================================================================
+%% Internal application API
+%%====================================================================
--spec connection_state_by_epoch(#connection_states{}, integer(), read | write) ->
- #connection_state{}.
-%%
-%% Description: Returns the instance of the connection_state record
-%% that is defined by the Epoch.
-%%--------------------------------------------------------------------
-connection_state_by_epoch(#connection_states{current_read = CS}, Epoch, read)
- when CS#connection_state.epoch == Epoch ->
- CS;
-connection_state_by_epoch(#connection_states{pending_read = CS}, Epoch, read)
- when CS#connection_state.epoch == Epoch ->
- CS;
-connection_state_by_epoch(#connection_states{current_write = CS}, Epoch, write)
- when CS#connection_state.epoch == Epoch ->
- CS;
-connection_state_by_epoch(#connection_states{pending_write = CS}, Epoch, write)
- when CS#connection_state.epoch == Epoch ->
- CS.
%%--------------------------------------------------------------------
--spec set_connection_state_by_epoch(#connection_states{},
- #connection_state{}, read | write) -> ok.
+-spec get_dtls_records(binary(), binary()) -> {[binary()], binary()} | #alert{}.
%%
-%% Description: Returns the instance of the connection_state record
-%% that is defined by the Epoch.
+%% Description: Given old buffer and new data from UDP/SCTP, packs up a records
+%% and returns it as a list of tls_compressed binaries also returns leftover
+%% data
%%--------------------------------------------------------------------
-set_connection_state_by_epoch(ConnectionStates0 =
- #connection_states{current_read = CS},
- NewCS = #connection_state{epoch = Epoch}, read)
- when CS#connection_state.epoch == Epoch ->
- ConnectionStates0#connection_states{current_read = NewCS};
+get_dtls_records(Data, <<>>) ->
+ get_dtls_records_aux(Data, []);
+get_dtls_records(Data, Buffer) ->
+ get_dtls_records_aux(list_to_binary([Buffer, Data]), []).
-set_connection_state_by_epoch(ConnectionStates0 =
- #connection_states{pending_read = CS},
- NewCS = #connection_state{epoch = Epoch}, read)
- when CS#connection_state.epoch == Epoch ->
- ConnectionStates0#connection_states{pending_read = NewCS};
+get_dtls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer),
+ ?UINT16(Epoch), ?UINT48(SequenceNumber),
+ ?UINT16(Length), Data:Length/binary, Rest/binary>>,
+ Acc) ->
+ get_dtls_records_aux(Rest, [#ssl_tls{type = ?APPLICATION_DATA,
+ version = {MajVer, MinVer},
+ epoch = Epoch, record_seq = SequenceNumber,
+ fragment = Data} | Acc]);
+get_dtls_records_aux(<<?BYTE(?HANDSHAKE),?BYTE(MajVer),?BYTE(MinVer),
+ ?UINT16(Epoch), ?UINT48(SequenceNumber),
+ ?UINT16(Length),
+ Data:Length/binary, Rest/binary>>, Acc) when MajVer >= 128 ->
+ get_dtls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE,
+ version = {MajVer, MinVer},
+ epoch = Epoch, record_seq = SequenceNumber,
+ fragment = Data} | Acc]);
+get_dtls_records_aux(<<?BYTE(?ALERT),?BYTE(MajVer),?BYTE(MinVer),
+ ?UINT16(Epoch), ?UINT48(SequenceNumber),
+ ?UINT16(Length), Data:Length/binary,
+ Rest/binary>>, Acc) ->
+ get_dtls_records_aux(Rest, [#ssl_tls{type = ?ALERT,
+ version = {MajVer, MinVer},
+ epoch = Epoch, record_seq = SequenceNumber,
+ fragment = Data} | Acc]);
+get_dtls_records_aux(<<?BYTE(?CHANGE_CIPHER_SPEC),?BYTE(MajVer),?BYTE(MinVer),
+ ?UINT16(Epoch), ?UINT48(SequenceNumber),
+ ?UINT16(Length), Data:Length/binary, Rest/binary>>,
+ Acc) ->
+ get_dtls_records_aux(Rest, [#ssl_tls{type = ?CHANGE_CIPHER_SPEC,
+ version = {MajVer, MinVer},
+ epoch = Epoch, record_seq = SequenceNumber,
+ fragment = Data} | Acc]);
-set_connection_state_by_epoch(ConnectionStates0 =
- #connection_states{current_write = CS},
- NewCS = #connection_state{epoch = Epoch}, write)
- when CS#connection_state.epoch == Epoch ->
- ConnectionStates0#connection_states{current_write = NewCS};
+get_dtls_records_aux(<<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer),
+ ?UINT16(Length), _/binary>>,
+ _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH ->
+ ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
-set_connection_state_by_epoch(ConnectionStates0 =
- #connection_states{pending_write = CS},
- NewCS = #connection_state{epoch = Epoch}, write)
- when CS#connection_state.epoch == Epoch ->
- ConnectionStates0#connection_states{pending_write = NewCS}.
+get_dtls_records_aux(<<1:1, Length0:15, _/binary>>,_Acc)
+ when Length0 > ?MAX_CIPHER_TEXT_LENGTH ->
+ ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
+
+get_dtls_records_aux(Data, Acc) ->
+ case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of
+ true ->
+ {lists:reverse(Acc), Data};
+ false ->
+ ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE)
+ end.
+
+encode_plain_text(Type, Version, Data,
+ #connection_state{
+ compression_state = CompS0,
+ epoch = Epoch,
+ sequence_number = Seq,
+ security_parameters=
+ #security_parameters{compression_algorithm = CompAlg}
+ }= CS0) ->
+ {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0),
+ CS1 = CS0#connection_state{compression_state = CompS1},
+ {CipherText, CS2} = cipher(Type, Version, Comp, CS1),
+ CTBin = encode_tls_cipher_text(Type, Version, Epoch, Seq, CipherText),
+ {CTBin, CS2}.
+
+decode_cipher_text(CipherText, ConnnectionStates0) ->
+ ReadState0 = ConnnectionStates0#connection_states.current_read,
+ #connection_state{compression_state = CompressionS0,
+ security_parameters = SecParams} = ReadState0,
+ CompressAlg = SecParams#security_parameters.compression_algorithm,
+ case decipher(CipherText, ReadState0) of
+ {Compressed, ReadState1} ->
+ {Plain, CompressionS1} = ssl_record:uncompress(CompressAlg,
+ Compressed, CompressionS0),
+ ConnnectionStates = ConnnectionStates0#connection_states{
+ current_read = ReadState1#connection_state{
+ compression_state = CompressionS1}},
+ {Plain, ConnnectionStates};
+ #alert{} = Alert ->
+ Alert
+ end.
%%--------------------------------------------------------------------
-spec protocol_version(tls_atom_version() | tls_version()) ->
@@ -161,67 +191,6 @@ highest_protocol_version(Version = {M,_}, [{N,_} | Rest]) when M < N ->
highest_protocol_version(_, [Version | Rest]) ->
highest_protocol_version(Version, Rest).
-%%--------------------------------------------------------------------
--spec get_dtls_records(binary(), binary()) -> {[binary()], binary()} | #alert{}.
-%%
-%% Description: Given old buffer and new data from UDP/SCTP, packs up a records
-%% and returns it as a list of tls_compressed binaries also returns leftover
-%% data
-%%--------------------------------------------------------------------
-get_dtls_records(Data, <<>>) ->
- get_dtls_records_aux(Data, []);
-get_dtls_records(Data, Buffer) ->
- get_dtls_records_aux(list_to_binary([Buffer, Data]), []).
-
-get_dtls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer),
- ?UINT16(Epoch), ?UINT48(SequenceNumber),
- ?UINT16(Length), Data:Length/binary, Rest/binary>>,
- Acc) ->
- get_dtls_records_aux(Rest, [#ssl_tls{type = ?APPLICATION_DATA,
- version = {MajVer, MinVer},
- epoch = Epoch, record_seq = SequenceNumber,
- fragment = Data} | Acc]);
-get_dtls_records_aux(<<?BYTE(?HANDSHAKE),?BYTE(MajVer),?BYTE(MinVer),
- ?UINT16(Epoch), ?UINT48(SequenceNumber),
- ?UINT16(Length),
- Data:Length/binary, Rest/binary>>, Acc) when MajVer >= 128 ->
- get_dtls_records_aux(Rest, [#ssl_tls{type = ?HANDSHAKE,
- version = {MajVer, MinVer},
- epoch = Epoch, record_seq = SequenceNumber,
- fragment = Data} | Acc]);
-get_dtls_records_aux(<<?BYTE(?ALERT),?BYTE(MajVer),?BYTE(MinVer),
- ?UINT16(Epoch), ?UINT48(SequenceNumber),
- ?UINT16(Length), Data:Length/binary,
- Rest/binary>>, Acc) ->
- get_dtls_records_aux(Rest, [#ssl_tls{type = ?ALERT,
- version = {MajVer, MinVer},
- epoch = Epoch, record_seq = SequenceNumber,
- fragment = Data} | Acc]);
-get_dtls_records_aux(<<?BYTE(?CHANGE_CIPHER_SPEC),?BYTE(MajVer),?BYTE(MinVer),
- ?UINT16(Epoch), ?UINT48(SequenceNumber),
- ?UINT16(Length), Data:Length/binary, Rest/binary>>,
- Acc) ->
- get_dtls_records_aux(Rest, [#ssl_tls{type = ?CHANGE_CIPHER_SPEC,
- version = {MajVer, MinVer},
- epoch = Epoch, record_seq = SequenceNumber,
- fragment = Data} | Acc]);
-
-get_dtls_records_aux(<<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer),
- ?UINT16(Length), _/binary>>,
- _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH ->
- ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
-
-get_dtls_records_aux(<<1:1, Length0:15, _/binary>>,_Acc)
- when Length0 > ?MAX_CIPHER_TEXT_LENGTH ->
- ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
-
-get_dtls_records_aux(Data, Acc) ->
- case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of
- true ->
- {lists:reverse(Acc), Data};
- false ->
- ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE)
- end.
%%--------------------------------------------------------------------
-spec supported_protocol_versions() -> [tls_version()].
@@ -264,13 +233,103 @@ is_acceptable_version(Version, Versions) ->
lists:member(Version, Versions).
+%%--------------------------------------------------------------------
+-spec init_connection_state_seq(tls_version(), #connection_states{}) ->
+ #connection_state{}.
+%%
+%% Description: Copy the read sequence number to the write sequence number
+%% This is only valid for DTLS in the first client_hello
+%%--------------------------------------------------------------------
+init_connection_state_seq({254, _},
+ #connection_states{
+ current_read = Read = #connection_state{epoch = 0},
+ current_write = Write = #connection_state{epoch = 0}} = CS0) ->
+ CS0#connection_states{current_write =
+ Write#connection_state{
+ sequence_number = Read#connection_state.sequence_number}};
+init_connection_state_seq(_, CS) ->
+ CS.
+
+%%--------------------------------------------------------
+-spec current_connection_state_epoch(#connection_states{}, read | write) ->
+ integer().
+%%
+%% Description: Returns the epoch the connection_state record
+%% that is currently defined as the current conection state.
+%%--------------------------------------------------------------------
+current_connection_state_epoch(#connection_states{current_read = Current},
+ read) ->
+ Current#connection_state.epoch;
+current_connection_state_epoch(#connection_states{current_write = Current},
+ write) ->
+ Current#connection_state.epoch.
+
+%%--------------------------------------------------------------------
+
+-spec connection_state_by_epoch(#connection_states{}, integer(), read | write) ->
+ #connection_state{}.
+%%
+%% Description: Returns the instance of the connection_state record
+%% that is defined by the Epoch.
+%%--------------------------------------------------------------------
+connection_state_by_epoch(#connection_states{current_read = CS}, Epoch, read)
+ when CS#connection_state.epoch == Epoch ->
+ CS;
+connection_state_by_epoch(#connection_states{pending_read = CS}, Epoch, read)
+ when CS#connection_state.epoch == Epoch ->
+ CS;
+connection_state_by_epoch(#connection_states{current_write = CS}, Epoch, write)
+ when CS#connection_state.epoch == Epoch ->
+ CS;
+connection_state_by_epoch(#connection_states{pending_write = CS}, Epoch, write)
+ when CS#connection_state.epoch == Epoch ->
+ CS.
+%%--------------------------------------------------------------------
+-spec set_connection_state_by_epoch(#connection_states{},
+ #connection_state{}, read | write) -> ok.
+%%
+%% Description: Returns the instance of the connection_state record
+%% that is defined by the Epoch.
+%%--------------------------------------------------------------------
+set_connection_state_by_epoch(ConnectionStates0 =
+ #connection_states{current_read = CS},
+ NewCS = #connection_state{epoch = Epoch}, read)
+ when CS#connection_state.epoch == Epoch ->
+ ConnectionStates0#connection_states{current_read = NewCS};
+
+set_connection_state_by_epoch(ConnectionStates0 =
+ #connection_states{pending_read = CS},
+ NewCS = #connection_state{epoch = Epoch}, read)
+ when CS#connection_state.epoch == Epoch ->
+ ConnectionStates0#connection_states{pending_read = NewCS};
+
+set_connection_state_by_epoch(ConnectionStates0 =
+ #connection_states{current_write = CS},
+ NewCS = #connection_state{epoch = Epoch}, write)
+ when CS#connection_state.epoch == Epoch ->
+ ConnectionStates0#connection_states{current_write = NewCS};
+
+set_connection_state_by_epoch(ConnectionStates0 =
+ #connection_states{pending_write = CS},
+ NewCS = #connection_state{epoch = Epoch}, write)
+ when CS#connection_state.epoch == Epoch ->
+ ConnectionStates0#connection_states{pending_write = NewCS}.
+
+%%--------------------------------------------------------------------
+%%% Internal functions
+%%--------------------------------------------------------------------
+encode_tls_cipher_text(Type, {MajVer, MinVer}, Epoch, Seq, Fragment) ->
+ Length = erlang:iolist_size(Fragment),
+ [<<?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer), ?UINT16(Epoch),
+ ?UINT48(Seq), ?UINT16(Length)>>, Fragment].
+
cipher(Type, Version, Fragment, CS0) ->
Length = erlang:iolist_size(Fragment),
{MacHash, CS1=#connection_state{cipher_state = CipherS0,
- security_parameters=
- #security_parameters{bulk_cipher_algorithm =
- BCA}
- }} =
+ security_parameters=
+ #security_parameters{bulk_cipher_algorithm =
+ BCA}
+ }} =
hash_and_bump_seqno(CS0, Type, Version, Length, Fragment),
{Ciphered, CipherS1} = ssl_cipher:cipher(BCA, CipherS0, MacHash, Fragment, Version),
CS2 = CS1#connection_state{cipher_state=CipherS1},
@@ -299,10 +358,10 @@ decipher(TLS=#ssl_tls{type=Type, version=Version={254, _},
end.
hash_with_seqno(#connection_state{mac_secret = MacSecret,
- security_parameters =
- SecPars},
- Type, Version = {254, _},
- Epoch, SeqNo, Length, Fragment) ->
+ security_parameters =
+ SecPars},
+ Type, Version = {254, _},
+ Epoch, SeqNo, Length, Fragment) ->
mac_hash(Version,
SecPars#security_parameters.mac_algorithm,
MacSecret, (Epoch bsl 48) + SeqNo, Type,