From a3d68814e1cd1ef062582901e0102f60a323bae5 Mon Sep 17 00:00:00 2001
From: Ingela Anderton Andin <ingela@erlang.org>
Date: Thu, 4 Jan 2018 17:17:50 +0100
Subject: ssl: Add record version sanity check

---
 lib/ssl/src/dtls_connection.erl | 13 ++++++++++---
 lib/ssl/src/dtls_record.erl     | 30 ++++++++++++++++++------------
 lib/ssl/src/tls_connection.erl  | 26 +++++++++++++++++++++-----
 lib/ssl/src/tls_record.erl      | 37 +++++++++++++++++++++++++++++--------
 4 files changed, 78 insertions(+), 28 deletions(-)

(limited to 'lib/ssl/src')

diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl
index 9cb6934dce..03725089dd 100644
--- a/lib/ssl/src/dtls_connection.erl
+++ b/lib/ssl/src/dtls_connection.erl
@@ -758,10 +758,12 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, _}, User,
            flight_state = {retransmit, ?INITIAL_RETRANSMIT_TIMEOUT}
 	  }.
 
-next_dtls_record(Data, #state{protocol_buffers = #protocol_buffers{
+next_dtls_record(Data, StateName, #state{protocol_buffers = #protocol_buffers{
 						   dtls_record_buffer = Buf0,
 						   dtls_cipher_texts = CT0} = Buffers} = State0) ->
-    case dtls_record:get_dtls_records(Data, Buf0) of
+    case dtls_record:get_dtls_records(Data,
+                                      acceptable_record_versions(StateName, State0), 
+                                      Buf0) of
 	{Records, Buf1} ->
 	    CT1 = CT0 ++ Records,
 	    next_record(State0#state{protocol_buffers =
@@ -771,6 +773,11 @@ next_dtls_record(Data, #state{protocol_buffers = #protocol_buffers{
 	    Alert
     end.
 
+acceptable_record_versions(hello, _) ->
+    [dtls_record:protocol_version(Vsn) || Vsn <- ?ALL_DATAGRAM_SUPPORTED_VERSIONS];
+acceptable_record_versions(_, #state{negotiated_version = Version}) ->
+    [Version].
+
 dtls_handshake_events(Packets) ->
     lists:map(fun(Packet) ->
 		      {next_event, internal, {handshake, Packet}}
@@ -828,7 +835,7 @@ handle_client_hello(#client_hello{client_version = ClientVersion} = Hello,
 %% raw data from socket, unpack records
 handle_info({Protocol, _, _, _, Data}, StateName,
             #state{data_tag = Protocol} = State0) ->
-    case next_dtls_record(Data, State0) of
+    case next_dtls_record(Data, StateName, State0) of
 	{Record, State} ->
 	    next_event(StateName, Record, State);
 	#alert{} = Alert ->
diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl
index 2dcc6efc91..316de05532 100644
--- a/lib/ssl/src/dtls_record.erl
+++ b/lib/ssl/src/dtls_record.erl
@@ -30,7 +30,7 @@
 -include("ssl_cipher.hrl").
 
 %% Handling of incoming data
--export([get_dtls_records/2,  init_connection_states/2, empty_connection_state/1]).
+-export([get_dtls_records/3,  init_connection_states/2, empty_connection_state/1]).
 
 -export([save_current_connection_state/2, next_epoch/2, get_connection_state_by_epoch/3, replay_detect/2,
          init_connection_state_seq/2, current_connection_state_epoch/2]).
@@ -163,17 +163,25 @@ current_connection_state_epoch(#{current_write := #{epoch := Epoch}},
     Epoch.
 
 %%--------------------------------------------------------------------
--spec get_dtls_records(binary(), binary()) -> {[binary()], binary()} | #alert{}.
+-spec get_dtls_records(binary(), [dtls_version()], binary()) -> {[binary()], binary()} | #alert{}.
 %%
 %% Description: Given old buffer and new data from UDP/SCTP, packs up a records
 %% and returns it as a list of tls_compressed binaries also returns leftover
 %% data
 %%--------------------------------------------------------------------
-get_dtls_records(Data, <<>>) ->
-    get_dtls_records_aux(Data, []);
-get_dtls_records(Data, Buffer) ->
-    get_dtls_records_aux(list_to_binary([Buffer, Data]), []).
-
+get_dtls_records(Data, Versions, Buffer) ->
+    BinData = list_to_binary([Buffer, Data]),
+    case erlang:byte_size(BinData) of
+        N when N >= 3 ->
+            case assert_version(BinData, Versions) of
+                true ->
+                    get_dtls_records_aux(BinData, []);
+                false ->
+                    ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC)
+            end;
+        _ ->
+            get_dtls_records_aux(BinData, [])
+    end.
 
 %%====================================================================
 %% Encoding DTLS records
@@ -397,6 +405,8 @@ initial_connection_state(ConnectionEnd, BeastMitigation) ->
       client_verify_data => undefined,
       server_verify_data => undefined
      }.
+assert_version(<<?BYTE(_), ?BYTE(MajVer), ?BYTE(MinVer), _/binary>>, Versions) ->
+    is_acceptable_version({MajVer, MinVer}, Versions).
 
 get_dtls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer),
 		       ?UINT16(Epoch), ?UINT48(SequenceNumber),
@@ -431,15 +441,11 @@ get_dtls_records_aux(<<?BYTE(?CHANGE_CIPHER_SPEC),?BYTE(MajVer),?BYTE(MinVer),
 					 epoch = Epoch, sequence_number = SequenceNumber,
 					 fragment = Data} | Acc]);
 
-get_dtls_records_aux(<<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer),
+get_dtls_records_aux(<<?BYTE(_), ?BYTE(_MajVer), ?BYTE(_MinVer),
 		       ?UINT16(Length), _/binary>>,
 		     _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH ->
     ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
 
-get_dtls_records_aux(<<1:1, Length0:15, _/binary>>,_Acc)
-  when Length0 > ?MAX_CIPHER_TEXT_LENGTH ->
-    ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
-
 get_dtls_records_aux(Data, Acc) ->
     case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of
 	true ->
diff --git a/lib/ssl/src/tls_connection.erl b/lib/ssl/src/tls_connection.erl
index 39f3ed996e..914ee9f22f 100644
--- a/lib/ssl/src/tls_connection.erl
+++ b/lib/ssl/src/tls_connection.erl
@@ -629,18 +629,34 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, Tracker}, Us
 	   flight_buffer = []
 	  }.
 
-next_tls_record(Data, #state{protocol_buffers = #protocol_buffers{tls_record_buffer = Buf0,
-						tls_cipher_texts = CT0} = Buffers} = State0) ->
-    case tls_record:get_tls_records(Data, Buf0) of
+next_tls_record(Data, StateName, #state{protocol_buffers = 
+                                            #protocol_buffers{tls_record_buffer = Buf0,
+                                                              tls_cipher_texts = CT0} = Buffers}
+                                        = State0) ->
+    case tls_record:get_tls_records(Data, 
+                                    acceptable_record_versions(StateName, State0),
+                                    Buf0) of
 	{Records, Buf1} ->
 	    CT1 = CT0 ++ Records,
 	    next_record(State0#state{protocol_buffers =
 					 Buffers#protocol_buffers{tls_record_buffer = Buf1,
 								  tls_cipher_texts = CT1}});
 	#alert{} = Alert ->
-	    Alert
+	    handle_record_alert(Alert, State0)
     end.
 
+acceptable_record_versions(hello, #state{ssl_options = #ssl_options{v2_hello_compatible = true}}) ->
+    [tls_record:protocol_version(Vsn) || Vsn <- ?ALL_AVAILABLE_VERSIONS ++ ['sslv2']];
+acceptable_record_versions(hello, _) ->
+    [tls_record:protocol_version(Vsn) || Vsn <- ?ALL_AVAILABLE_VERSIONS];
+acceptable_record_versions(_, #state{negotiated_version = Version}) ->
+    [Version].
+handle_record_alert(#alert{description = ?BAD_RECORD_MAC}, 
+                    #state{ssl_options = #ssl_options{v2_hello_compatible = true}}) ->
+    ?ALERT_REC(?FATAL, ?PROTOCOL_VERSION);
+handle_record_alert(Alert, _) ->
+    Alert.
+
 tls_handshake_events(Packets) ->
     lists:map(fun(Packet) ->
 		      {next_event, internal, {handshake, Packet}}
@@ -649,7 +665,7 @@ tls_handshake_events(Packets) ->
 %% raw data from socket, upack records
 handle_info({Protocol, _, Data}, StateName,
             #state{data_tag = Protocol} = State0) ->
-    case next_tls_record(Data, State0) of
+    case next_tls_record(Data, StateName, State0) of
 	{Record, State} ->
 	    next_event(StateName, Record, State);
 	#alert{} = Alert ->
diff --git a/lib/ssl/src/tls_record.erl b/lib/ssl/src/tls_record.erl
index ab179c1bf0..188ec6809d 100644
--- a/lib/ssl/src/tls_record.erl
+++ b/lib/ssl/src/tls_record.erl
@@ -32,7 +32,7 @@
 -include("ssl_cipher.hrl").
 
 %% Handling of incoming data
--export([get_tls_records/2, init_connection_states/2]).
+-export([get_tls_records/3, init_connection_states/2]).
 
 %% Encoding TLS records
 -export([encode_handshake/3, encode_alert_record/3,
@@ -75,16 +75,25 @@ init_connection_states(Role, BeastMitigation) ->
       pending_write => Pending}.
 
 %%--------------------------------------------------------------------
--spec get_tls_records(binary(), binary()) -> {[binary()], binary()} | #alert{}.
+-spec get_tls_records(binary(), [tls_version()], binary()) -> {[binary()], binary()} | #alert{}.
 %%			     
 %% and returns it as a list of tls_compressed binaries also returns leftover
 %% Description: Given old buffer and new data from TCP, packs up a records
 %% data
 %%--------------------------------------------------------------------
-get_tls_records(Data, <<>>) ->
-    get_tls_records_aux(Data, []);
-get_tls_records(Data, Buffer) ->
-    get_tls_records_aux(list_to_binary([Buffer, Data]), []).
+get_tls_records(Data, Versions, Buffer) ->
+    BinData = list_to_binary([Buffer, Data]),
+    case erlang:byte_size(BinData) of
+        N when N >= 3 ->
+            case assert_version(BinData, Versions) of
+                true ->
+                    get_tls_records_aux(BinData, []);
+                false ->
+                    ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC)
+            end;
+        _ ->
+            get_tls_records_aux(BinData, [])
+    end.
 
 %%====================================================================
 %% Encoding
@@ -385,6 +394,19 @@ initial_connection_state(ConnectionEnd, BeastMitigation) ->
       server_verify_data => undefined
      }.
 
+assert_version(<<1:1, Length0:15, Data0:Length0/binary, _/binary>>, Versions) ->
+    case Data0 of
+        <<?BYTE(?CLIENT_HELLO), ?BYTE(Major), ?BYTE(Minor), _/binary>> ->
+            %% First check v2_hello_compatible mode is active 
+            lists:member({2,0}, Versions) andalso
+            %% andalso we want to negotiate higher version
+                lists:member({Major, Minor}, Versions -- [{2,0}]); 
+        _ ->
+            false
+    end;
+assert_version(<<?BYTE(_), ?BYTE(MajVer), ?BYTE(MinVer), _/binary>>, Versions) ->
+    is_acceptable_version({MajVer, MinVer}, Versions).
+                   
 get_tls_records_aux(<<?BYTE(?APPLICATION_DATA),?BYTE(MajVer),?BYTE(MinVer),
 		     ?UINT16(Length), Data:Length/binary, Rest/binary>>, 
 		    Acc) ->
@@ -428,10 +450,9 @@ get_tls_records_aux(<<1:1, Length0:15, Data0:Length0/binary, Rest/binary>>,
     end;
 
 get_tls_records_aux(<<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer),
-                     ?UINT16(Length), _/binary>>,
+                      ?UINT16(Length), _/binary>>,
                     _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH ->
     ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
-
 get_tls_records_aux(<<1:1, Length0:15, _/binary>>,_Acc) 
   when Length0 > ?MAX_CIPHER_TEXT_LENGTH ->
     ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW);
-- 
cgit v1.2.3