From f3dfe10d8ee4a65362ef75803016b7b2e4368719 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?P=C3=A9ter=20Dimitrov?= <peterdmv@erlang.org>
Date: Tue, 23 Oct 2018 16:47:53 +0200
Subject: ssl: Implement decode of "supported_groups"

Change-Id: I42d7779bb3558aa3a2bea5be065c559d01c0a32b
---
 lib/ssl/src/ssl_handshake.erl     | 143 +++++++++++++++++++++-----------------
 lib/ssl/src/ssl_handshake.hrl     |   2 +-
 lib/ssl/src/tls_handshake_1_3.erl |   2 +-
 lib/ssl/src/tls_v1.erl            |  11 ++-
 4 files changed, 93 insertions(+), 65 deletions(-)

diff --git a/lib/ssl/src/ssl_handshake.erl b/lib/ssl/src/ssl_handshake.erl
index 360c82084c..1e812c92a8 100644
--- a/lib/ssl/src/ssl_handshake.erl
+++ b/lib/ssl/src/ssl_handshake.erl
@@ -60,7 +60,7 @@
 -export([encode_handshake/2, encode_hello_extensions/1, encode_extensions/1, encode_extensions/2,
 	 encode_client_protocol_negotiation/2, encode_protocols_advertised_on_server/1]).
 %% Decode
--export([decode_handshake/3, decode_vector/1, decode_hello_extensions/3, decode_extensions/1,
+-export([decode_handshake/3, decode_vector/1, decode_hello_extensions/3, decode_extensions/2,
 	 decode_server_key/3, decode_client_key/3,
 	 decode_suites/2
 	]).
@@ -780,15 +780,15 @@ decode_vector(<<?UINT16(Len), Vector:Len/binary>>) ->
 %% Description: Decodes TLS hello extensions
 %%--------------------------------------------------------------------
 decode_hello_extensions(Extensions, Version, Role) ->
-    decode_extensions(Extensions, empty_hello_extensions(Version, Role)).
+    decode_extensions(Extensions, Version, empty_hello_extensions(Version, Role)).
 
 %%--------------------------------------------------------------------
--spec decode_extensions(binary()) -> map().
+-spec decode_extensions(binary(),tuple()) -> map().
 %%
 %% Description: Decodes TLS hello extensions
 %%--------------------------------------------------------------------
-decode_extensions(Extensions) ->
-    decode_extensions(Extensions, empty_extensions()).
+decode_extensions(Extensions, Version) ->
+    decode_extensions(Extensions, Version, empty_extensions()).
 
 %%--------------------------------------------------------------------
 -spec decode_server_key(binary(), ssl_cipher_format:key_algo(), ssl_record:ssl_version()) ->
@@ -1014,23 +1014,20 @@ add_tls12_extensions(Version,
      }.
 
 
-add_common_extensions({3,4} = Version,
+add_common_extensions({3,4},
                       HelloExtensions,
                       _CipherSuites,
                       #ssl_options{eccs = SupportedECCs,
                                    supported_groups = Groups}) ->
-    {EcPointFormats, EllipticCurves} =
+    {EcPointFormats, _} =
         client_ecc_extensions(SupportedECCs),
     HelloExtensions#{ec_point_formats => EcPointFormats,
-                     elliptic_curves => maybe_supported_groups(Version,
-                                                               EllipticCurves,
-                                                               Groups)};
+                     elliptic_curves => Groups};
 
-add_common_extensions(Version,
+add_common_extensions(_Version,
                       HelloExtensions,
                       CipherSuites,
-                      #ssl_options{eccs = SupportedECCs,
-                                   supported_groups = Groups}) ->
+                      #ssl_options{eccs = SupportedECCs}) ->
 
     {EcPointFormats, EllipticCurves} =
         case advertises_ec_ciphers(
@@ -1042,9 +1039,7 @@ add_common_extensions(Version,
                 {undefined, undefined}
         end,
     HelloExtensions#{ec_point_formats => EcPointFormats,
-                     elliptic_curves => maybe_supported_groups(Version,
-                                                               EllipticCurves,
-                                                               Groups)}.
+                     elliptic_curves => EllipticCurves}.
 
 
 maybe_add_tls13_extensions({3,4},
@@ -1058,12 +1053,6 @@ maybe_add_tls13_extensions({3,4},
 maybe_add_tls13_extensions(_, HelloExtensions, _) ->
     HelloExtensions.
 
-maybe_supported_groups({3,4}, _, SupportedGroups) ->
-    SupportedGroups;
-maybe_supported_groups(_, EllipticCurves, _) ->
-    EllipticCurves.
-
-
 signature_scheme_list(undefined) ->
     undefined;
 signature_scheme_list(SignatureSchemes) ->
@@ -2081,16 +2070,19 @@ dec_server_key_signature(Params, <<?UINT16(Len), Signature:Len/binary>>, _) ->
 dec_server_key_signature(_, _, _) ->
     throw(?ALERT_REC(?FATAL, ?HANDSHAKE_FAILURE, failed_to_decrypt_server_key_sign)).
 
-decode_extensions(<<>>, Acc) ->
+decode_extensions(<<>>, _Version, Acc) ->
     Acc;
-decode_extensions(<<?UINT16(?ALPN_EXT), ?UINT16(ExtLen), ?UINT16(Len), ExtensionData:Len/binary, Rest/binary>>, Acc)
-        when Len + 2 =:= ExtLen ->
+decode_extensions(<<?UINT16(?ALPN_EXT), ?UINT16(ExtLen), ?UINT16(Len),
+                    ExtensionData:Len/binary, Rest/binary>>, Version, Acc)
+  when Len + 2 =:= ExtLen ->
     ALPN = #alpn{extension_data = ExtensionData},
-    decode_extensions(Rest, Acc#{alpn => ALPN});
-decode_extensions(<<?UINT16(?NEXTPROTONEG_EXT), ?UINT16(Len), ExtensionData:Len/binary, Rest/binary>>, Acc) ->
+    decode_extensions(Rest, Version, Acc#{alpn => ALPN});
+decode_extensions(<<?UINT16(?NEXTPROTONEG_EXT), ?UINT16(Len),
+                    ExtensionData:Len/binary, Rest/binary>>, Version, Acc) ->
     NextP = #next_protocol_negotiation{extension_data = ExtensionData},
-    decode_extensions(Rest, Acc#{next_protocol_negotiation => NextP});
-decode_extensions(<<?UINT16(?RENEGOTIATION_EXT), ?UINT16(Len), Info:Len/binary, Rest/binary>>, Acc) ->
+    decode_extensions(Rest, Version, Acc#{next_protocol_negotiation => NextP});
+decode_extensions(<<?UINT16(?RENEGOTIATION_EXT), ?UINT16(Len),
+                    Info:Len/binary, Rest/binary>>, Version, Acc) ->
     RenegotiateInfo = case Len of
 			  1 ->  % Initial handshake
 			      Info; % should be <<0>> will be matched in handle_renegotiation_info
@@ -2099,35 +2091,38 @@ decode_extensions(<<?UINT16(?RENEGOTIATION_EXT), ?UINT16(Len), Info:Len/binary,
 			      <<?BYTE(VerifyLen), VerifyInfo/binary>> = Info,
 			      VerifyInfo
 		      end,
-    decode_extensions(Rest, Acc#{renegotiation_info =>
-                                     #renegotiation_info{renegotiated_connection =
-                                                             RenegotiateInfo}});
+    decode_extensions(Rest, Version, Acc#{renegotiation_info =>
+                                              #renegotiation_info{renegotiated_connection =
+                                                                      RenegotiateInfo}});
 
-decode_extensions(<<?UINT16(?SRP_EXT), ?UINT16(Len), ?BYTE(SRPLen), SRP:SRPLen/binary, Rest/binary>>, Acc)
+decode_extensions(<<?UINT16(?SRP_EXT), ?UINT16(Len), ?BYTE(SRPLen),
+                    SRP:SRPLen/binary, Rest/binary>>, Version, Acc)
   when Len == SRPLen + 2 ->
-    decode_extensions(Rest,  Acc#{srp => #srp{username = SRP}});
+    decode_extensions(Rest, Version, Acc#{srp => #srp{username = SRP}});
 
 decode_extensions(<<?UINT16(?SIGNATURE_ALGORITHMS_EXT), ?UINT16(Len),
-		       ExtData:Len/binary, Rest/binary>>, Acc) ->
+		       ExtData:Len/binary, Rest/binary>>, Version, Acc) ->
     SignAlgoListLen = Len - 2,
     <<?UINT16(SignAlgoListLen), SignAlgoList/binary>> = ExtData,
     HashSignAlgos = [{ssl_cipher:hash_algorithm(Hash), ssl_cipher:sign_algorithm(Sign)} ||
 			<<?BYTE(Hash), ?BYTE(Sign)>> <= SignAlgoList],
-    decode_extensions(Rest, Acc#{signature_algs =>
-                                     #hash_sign_algos{hash_sign_algos = HashSignAlgos}});
+    decode_extensions(Rest, Version, Acc#{signature_algs =>
+                                              #hash_sign_algos{hash_sign_algos =
+                                                                   HashSignAlgos}});
 
 decode_extensions(<<?UINT16(?SIGNATURE_ALGORITHMS_CERT_EXT), ?UINT16(Len),
-		       ExtData:Len/binary, Rest/binary>>, Acc) ->
+		       ExtData:Len/binary, Rest/binary>>, Version, Acc) ->
     SignSchemeListLen = Len - 2,
     <<?UINT16(SignSchemeListLen), SignSchemeList/binary>> = ExtData,
     SignSchemes = [ssl_cipher:signature_scheme(SignScheme) ||
 			<<?UINT16(SignScheme)>> <= SignSchemeList],
-    decode_extensions(Rest, Acc#{signature_algs_cert =>
-                                     #signature_scheme_list{
-                                        signature_scheme_list = SignSchemes}});
+    decode_extensions(Rest, Version, Acc#{signature_algs_cert =>
+                                              #signature_scheme_list{
+                                                 signature_scheme_list = SignSchemes}});
 
 decode_extensions(<<?UINT16(?ELLIPTIC_CURVES_EXT), ?UINT16(Len),
-		       ExtData:Len/binary, Rest/binary>>, Acc) ->
+		       ExtData:Len/binary, Rest/binary>>, Version, Acc)
+  when Version < {3,4} ->
     <<?UINT16(_), EllipticCurveList/binary>> = ExtData,
     %% Ignore unknown curves
     Pick = fun(Enum) ->
@@ -2139,42 +2134,66 @@ decode_extensions(<<?UINT16(?ELLIPTIC_CURVES_EXT), ?UINT16(Len),
 		   end
 	   end,
     EllipticCurves = lists:filtermap(Pick, [ECC || <<ECC:16>> <= EllipticCurveList]),
-    decode_extensions(Rest, Acc#{elliptic_curves =>
-                                     #elliptic_curves{elliptic_curve_list =
-                                                          EllipticCurves}});
+    decode_extensions(Rest, Version, Acc#{elliptic_curves =>
+                                              #elliptic_curves{elliptic_curve_list =
+                                                                   EllipticCurves}});
+
+
+decode_extensions(<<?UINT16(?ELLIPTIC_CURVES_EXT), ?UINT16(Len),
+		       ExtData:Len/binary, Rest/binary>>, Version, Acc)
+  when Version =:= {3,4} ->
+    <<?UINT16(_), GroupList/binary>> = ExtData,
+    %% Ignore unknown curves
+    Pick = fun(Enum) ->
+		   case tls_v1:enum_to_group(Enum) of
+		       undefined ->
+			   false;
+		       Group ->
+			   {true, Group}
+		   end
+	   end,
+    SupportedGroups = lists:filtermap(Pick, [Group || <<Group:16>> <= GroupList]),
+    decode_extensions(Rest, Version, Acc#{elliptic_curves =>
+                                              #supported_groups{supported_groups =
+                                                                   SupportedGroups}});
+
 decode_extensions(<<?UINT16(?EC_POINT_FORMATS_EXT), ?UINT16(Len),
-		       ExtData:Len/binary, Rest/binary>>, Acc) ->
+                    ExtData:Len/binary, Rest/binary>>, Version, Acc) ->
     <<?BYTE(_), ECPointFormatList/binary>> = ExtData,
     ECPointFormats = binary_to_list(ECPointFormatList),
-    decode_extensions(Rest, Acc#{ec_point_formats =>
-                                     #ec_point_formats{ec_point_format_list =
-                                                           ECPointFormats}});
+    decode_extensions(Rest, Version, Acc#{ec_point_formats =>
+                                               #ec_point_formats{ec_point_format_list =
+                                                                     ECPointFormats}});
 
-decode_extensions(<<?UINT16(?SNI_EXT), ?UINT16(Len), Rest/binary>>, Acc) when Len == 0 ->
-    decode_extensions(Rest, Acc#{sni => #sni{hostname = ""}}); %% Server may send an empy SNI
+decode_extensions(<<?UINT16(?SNI_EXT), ?UINT16(Len),
+                    Rest/binary>>, Version, Acc) when Len == 0 ->
+    decode_extensions(Rest, Version, Acc#{sni => #sni{hostname = ""}}); %% Server may send an empy SNI
 
 decode_extensions(<<?UINT16(?SNI_EXT), ?UINT16(Len),
-                ExtData:Len/binary, Rest/binary>>, Acc) ->
+                ExtData:Len/binary, Rest/binary>>, Version, Acc) ->
     <<?UINT16(_), NameList/binary>> = ExtData,
-    decode_extensions(Rest, Acc#{sni => dec_sni(NameList)});
+    decode_extensions(Rest, Version, Acc#{sni => dec_sni(NameList)});
 
 decode_extensions(<<?UINT16(?SUPPORTED_VERSIONS_EXT), ?UINT16(Len),
-                       ExtData:Len/binary, Rest/binary>>, Acc) when Len > 2 ->
+                       ExtData:Len/binary, Rest/binary>>, Version, Acc) when Len > 2 ->
     <<?UINT16(_),Versions/binary>> = ExtData,
-    decode_extensions(Rest, Acc#{client_hello_versions =>
-                                     #client_hello_versions{versions = decode_versions(Versions)}});
+    decode_extensions(Rest, Version, Acc#{client_hello_versions =>
+                                              #client_hello_versions{
+                                                 versions = decode_versions(Versions)}});
 
 decode_extensions(<<?UINT16(?SUPPORTED_VERSIONS_EXT), ?UINT16(Len),
-                       ?UINT16(Version), Rest/binary>>, Acc) when Len =:= 2, Version =:= 16#0304 ->
-    decode_extensions(Rest, Acc#{server_hello_selected_version =>
-                                     #server_hello_selected_version{selected_version = {3,4}}});
+                       ?UINT16(SelectedVersion), Rest/binary>>, Version, Acc)
+  when Len =:= 2, SelectedVersion =:= 16#0304 ->
+    decode_extensions(Rest, Version, Acc#{server_hello_selected_version =>
+                                              #server_hello_selected_version{selected_version =
+                                                                                 {3,4}}});
 
 %% Ignore data following the ClientHello (i.e.,
 %% extensions) if not understood.
-decode_extensions(<<?UINT16(_), ?UINT16(Len), _Unknown:Len/binary, Rest/binary>>, Acc) ->
-    decode_extensions(Rest, Acc);
+decode_extensions(<<?UINT16(_), ?UINT16(Len), _Unknown:Len/binary, Rest/binary>>, Version, Acc) ->
+    decode_extensions(Rest, Version, Acc);
 %% This theoretically should not happen if the protocol is followed, but if it does it is ignored.
-decode_extensions(_, Acc) ->
+decode_extensions(_, _, Acc) ->
     Acc.
 
 dec_hashsign(<<?BYTE(HashAlgo), ?BYTE(SignAlgo)>>) ->
diff --git a/lib/ssl/src/ssl_handshake.hrl b/lib/ssl/src/ssl_handshake.hrl
index 81ff73baf9..4336742e76 100644
--- a/lib/ssl/src/ssl_handshake.hrl
+++ b/lib/ssl/src/ssl_handshake.hrl
@@ -53,7 +53,7 @@
 -define(NUM_OF_SESSION_ID_BYTES, 32).  % TSL 1.1 & SSL 3
 -define(NUM_OF_PREMASTERSECRET_BYTES, 48).
 -define(DEFAULT_DIFFIE_HELLMAN_GENERATOR, ssl_dh_groups:modp2048_generator()).
--define(DEFAULT_DIFFIE_HELLMAN_PRIME, ssl_sh_groups:modp2048_prime()).
+-define(DEFAULT_DIFFIE_HELLMAN_PRIME, ssl_dh_groups:modp2048_prime()).
 
 %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
 %%% Handsake protocol - RFC 4346 section 7.4
diff --git a/lib/ssl/src/tls_handshake_1_3.erl b/lib/ssl/src/tls_handshake_1_3.erl
index b9ebf2e502..104017b67c 100644
--- a/lib/ssl/src/tls_handshake_1_3.erl
+++ b/lib/ssl/src/tls_handshake_1_3.erl
@@ -150,7 +150,7 @@ decode_cert_entries(<<?UINT24(DSize), Data:DSize/binary, ?UINT16(Esize), BinExts
 encode_extensions(Exts)->
     ssl_handshake:encode_extensions(extensions_list(Exts)).
 decode_extensions(Exts) ->
-    ssl_handshake:decode_extensions(Exts).
+    ssl_handshake:decode_extensions(Exts, {3,4}).
 
 extensions_list(HelloExtensions) ->
     [Ext || {_, Ext} <- maps:to_list(HelloExtensions)].
diff --git a/lib/ssl/src/tls_v1.erl b/lib/ssl/src/tls_v1.erl
index a535df3dc3..e7218c8c8a 100644
--- a/lib/ssl/src/tls_v1.erl
+++ b/lib/ssl/src/tls_v1.erl
@@ -34,7 +34,7 @@
 	 ecc_curves/1, ecc_curves/2, oid_to_enum/1, enum_to_oid/1, 
 	 default_signature_algs/1, signature_algs/2,
          default_signature_schemes/1, signature_schemes/2,
-         groups/1, group_to_enum/1]).
+         groups/1, groups/2, group_to_enum/1, enum_to_group/1]).
 
 -type named_curve() :: sect571r1 | sect571k1 | secp521r1 | brainpoolP512r1 |
                        sect409k1 | sect409r1 | brainpoolP384r1 | secp384r1 |
@@ -516,6 +516,15 @@ group_to_enum(ffdhe4096) -> 258;
 group_to_enum(ffdhe6144) -> 259;
 group_to_enum(ffdhe8192) -> 260.
 
+enum_to_group(23) -> secp256r1;
+enum_to_group(24) -> secp384r1;
+enum_to_group(25) -> secp521r1;
+enum_to_group(256) -> ffdhe2048;
+enum_to_group(257) -> ffdhe3072;
+enum_to_group(258) -> ffdhe4096;
+enum_to_group(259) -> ffdhe6144;
+enum_to_group(260) -> ffdhe8192;
+enum_to_group(_) -> undefined.
 
 %% ECC curves from draft-ietf-tls-ecc-12.txt (Oct. 17, 2005)
 oid_to_enum(?sect163k1) -> 1;
-- 
cgit v1.2.3