-export([client_hello/8, client_hello/9, cookie/4, hello/4, 
	 hello_verify_request/2, get_dtls_handshake/3, fragment_handshake/2,
	 handshake_bin/2, encode_handshake/3]).

%% Union of the handshake record types handled by this module.
%% NOTE(review): the alternatives after #hello_verify_request{} are not
%% visible in this view; the type is unterminated as shown.
-type dtls_handshake() :: #client_hello{} | #hello_verify_request{} | 

%% Internal application API
-spec client_hello(host(), inet:port_number(), ssl_record:connection_states(),
		   #ssl_options{}, integer(), atom(), boolean(), der_cert()) ->
			  #client_hello{}.
%% Description: Creates the first client hello message of a handshake.
%%
%% DTLS sends two client hellos; the first one carries an empty cookie, so
%% this simply delegates to client_hello/9 with Cookie = <<>>.
client_hello(Host, Port, ConnectionStates, SslOpts,
	     Cache, CacheCb, Renegotiation, OwnCert) ->
    EmptyCookie = <<>>,
    client_hello(Host, Port, EmptyCookie, ConnectionStates, SslOpts,
		 Cache, CacheCb, Renegotiation, OwnCert).

%% Builds a #client_hello{} record carrying the given Cookie (client_hello/8
%% calls this with <<>>).
%% The highest configured protocol version is used as client_version; the
%% cipher suites and hello extensions are chosen by the TLS code using the
%% corresponding TLS version; the session id comes from
%% ssl_session:client_id/4 and the random from the pending read state.
%% NOTE(review): the -spec below has no return type and the record
%% expression at the end is not closed in this view; lines appear to be
%% missing.
-spec client_hello(host(), inet:port_number(), term(), ssl_record:connection_states(),
		   #ssl_options{}, integer(), atom(), boolean(), der_cert()) ->
%% Description: Creates a client hello message.
client_hello(Host, Port, Cookie, ConnectionStates,
	     #ssl_options{versions = Versions,
			  ciphers = UserSuites
			 } = SslOpts,
	     Cache, CacheCb, Renegotiation, OwnCert) ->
    Version =  dtls_record:highest_protocol_version(Versions),
    %% The client random is taken from the pending read state's security
    %% parameters (used for the random field below).
    Pending = ssl_record:pending_connection_state(ConnectionStates, read),
    SecParams = maps:get(security_parameters, Pending),
    %% Suite and extension selection is shared with TLS, so map the DTLS
    %% version to its TLS counterpart first.
    TLSVersion = dtls_v1:corresponding_tls_version(Version),
    CipherSuites = ssl_handshake:available_suites(UserSuites, TLSVersion),

    Extensions = ssl_handshake:client_hello_extensions(Host, TLSVersion, CipherSuites,
                                                       SslOpts, ConnectionStates, Renegotiation),

    Id = ssl_session:client_id({Host, Port, SslOpts}, Cache, CacheCb, OwnCert),

    #client_hello{session_id = Id,
		  client_version = Version,
		  cipher_suites = ssl_handshake:cipher_suites(CipherSuites, Renegotiation),
		  compression_methods = ssl_record:compressions(),
		  random = SecParams#security_parameters.client_random,
		  cookie = Cookie,
		  extensions = Extensions

%% Processes a received hello message.
%% #server_hello{} clause: if the server's version is acceptable according
%% to dtls_record:is_acceptable_version/2, its extensions are processed by
%% handle_server_hello_extensions/9.
%% #client_hello{} clause: a version is selected from the client's offer and
%% the configured versions, then handle_client_hello/5 takes over.
%% NOTE(review): the body of the 'false' branch and the closing 'end;' of
%% the first clause are not visible in this view.
hello(#server_hello{server_version = Version, random = Random,
		    cipher_suite = CipherSuite,
		    compression_method = Compression,
		    session_id = SessionId, extensions = HelloExt},
      #ssl_options{versions = SupportedVersions} = SslOpt,
      ConnectionStates0, Renegotiation) ->
    case dtls_record:is_acceptable_version(Version, SupportedVersions) of
	true ->
	    handle_server_hello_extensions(Version, SessionId, Random, CipherSuite,
					   Compression, HelloExt, SslOpt, ConnectionStates0, Renegotiation);
	false ->

hello(#client_hello{client_version = ClientVersion} = Hello,
      #ssl_options{versions = Versions} = SslOpts,
      Info, Renegotiation) ->
    Version = ssl_handshake:select_version(dtls_record, ClientVersion, Versions),
    handle_client_hello(Version, Hello, SslOpts, Info, Renegotiation).

%% Computes the cookie for a client hello as an HMAC-SHA1 keyed with Key.
%% The HMAC covers the following, in this order:
%%   1. the serialised peer Address/Port;
%%   2. the client version bytes;
%%   3. the random, session id, cipher suites and compression methods
%%      taken from the hello.
%% NOTE(review): crypto:hmac/3 is deprecated since OTP 22.1 and removed in
%% OTP 24; crypto:mac(hmac, sha, Key, Data) replaces it once the minimum
%% supported OTP release allows.
cookie(Key, Address, Port, #client_hello{client_version = {Major, Minor},
					 random = Random,
					 session_id = SessionId,
					 cipher_suites = CipherSuites,
					 compression_methods = CompressionMethods}) ->
    Peer = address_to_bin(Address, Port),
    ClientVersion = <<?BYTE(Major), ?BYTE(Minor)>>,
    HelloFields = [Random, SessionId, CipherSuites, CompressionMethods],
    crypto:hmac(sha, Key, [Peer, ClientVersion | HelloFields]).

-spec hello_verify_request(binary(),  dtls_record:dtls_version()) -> #hello_verify_request{}.
%% Description: Builds the hello verify request that the server sends to
%% verify the client. It carries the given Cookie and protocol Version.
hello_verify_request(Cookie, Version) ->
    #hello_verify_request{cookie = Cookie,
			  protocol_version = Version}.


%% Encodes Handshake as a single, unfragmented DTLS handshake message with
%% message sequence number Seq. The result is the list:
%%   [MsgType, Length, Seq, FragmentOffset = 0, FragmentLength, Body]
%% FragmentLength equals Length because the message is not fragmented.
encode_handshake(Handshake, Version, Seq) ->
    {MsgType, Body} = enc_handshake(Handshake, Version),
    BodyLen = ?uint24(byte_size(Body)),
    FragOffset = ?uint24(0),
    [MsgType, BodyLen, ?uint16(Seq), FragOffset, BodyLen, Body].

%% Splits an encoded handshake, in the list form produced by
%% encode_handshake/3, into DTLS handshake fragment binaries whose bodies
%% are at most Size bytes each.
%% A plain binary is matched by the first clause. Per its comment, that is
%% the change_cipher_spec that travels in the same flight.
%% NOTE(review): the body of the binary clause is not visible in this view.
fragment_handshake(Bin, _) when is_binary(Bin)-> 
    %% This is the change_cipher_spec not a "real handshake" but part of the flight
%% Len appears twice in the pattern, so only input whose fragment length
%% equals the total length (i.e. not yet fragmented) matches.
fragment_handshake([MsgType, Len, Seq, _, Len, Bin], Size) ->
    Bins = bin_fragments(Bin, Size),
    handshake_fragments(MsgType, Seq, Len, Bins, []).

%% Re-encodes a [Type, Length, Data] triple as a complete DTLS handshake
%% binary (fragment offset 0) with message sequence number Seq.
handshake_bin([MsgType, MsgLength, Body], MessageSeq) ->
    handshake_bin(MsgType, MsgLength, MessageSeq, Body).
-spec get_dtls_handshake(dtls_record:dtls_version(), binary(), #protocol_buffers{}) ->
                                {[{dtls_handshake(), binary()}], #protocol_buffers{}}.
%% Description: Given buffered and new data from dtls_record, collects and
%% returns the handshake messages that became complete.
%% Each message is paired with its raw, reassembled encoding, i.e.
%% {DecodedHandshake, RawHandshakeBinary}, in arrival order.
%% Also returns possible leftover data (incomplete or future fragments) in
%% the new "protocol_buffers".
get_dtls_handshake(Version, Fragment, ProtocolBuffers) ->
    handle_fragments(Version, Fragment, ProtocolBuffers, []).

%%% Internal functions
%% Processes a #client_hello{} for which Version has already been selected.
%% If the version is acceptable, it:
%%   1. picks the available signature algorithms and the ECC curve;
%%   2. selects a session via ssl_handshake:select_session/11;
%%   3. unless the session has no cipher suite (no_suite), selects the
%%      hash/sign algorithm for the suite's key exchange;
%%   4. continues with handle_client_hello_extensions/10.
%% The Info tuple is destructured as
%% {Port, Session0, Cache, CacheCb, ConnectionStates0, Cert, _}.
%% NOTE(review): the bodies of the no_suite, #alert{} and false branches and
%% the closing 'end's are not visible in this view.
handle_client_hello(Version, #client_hello{session_id = SugesstedId,
					   cipher_suites = CipherSuites,
					   compression_methods = Compressions,
					   random = Random,
					   extensions =
					       #hello_extensions{elliptic_curves = Curves,
								 signature_algs = ClientHashSigns} = HelloExt},
		    #ssl_options{versions = Versions,
				 signature_algs = SupportedHashSigns} = SslOpts,
		    {Port, Session0, Cache, CacheCb, ConnectionStates0, Cert, _}, Renegotiation) ->
    case dtls_record:is_acceptable_version(Version, Versions) of
	true ->
	    TLSVersion = dtls_v1:corresponding_tls_version(Version),
	    AvailableHashSigns = ssl_handshake:available_signature_algs(
				   ClientHashSigns, SupportedHashSigns, Cert,TLSVersion),
	    ECCCurve = ssl_handshake:select_curve(Curves, ssl_handshake:supported_ecc(TLSVersion)),
	    {Type, #session{cipher_suite = CipherSuite} = Session1}
		= ssl_handshake:select_session(SugesstedId, CipherSuites, AvailableHashSigns, Compressions,
					       Port, Session0#session{ecc = ECCCurve}, TLSVersion,
					       SslOpts, Cache, CacheCb, Cert),
	    case CipherSuite of
		no_suite ->
		_ ->
		    %% Only the key exchange algorithm of the chosen suite is
		    %% needed to pick the hash/sign algorithm.
		    {KeyExAlg,_,_,_} = ssl_cipher:suite_definition(CipherSuite),
		    case ssl_handshake:select_hashsign(ClientHashSigns, Cert, KeyExAlg, 
						       SupportedHashSigns, TLSVersion) of
			#alert{} = Alert ->
			HashSign ->
			    handle_client_hello_extensions(Version, Type, Random, CipherSuites, HelloExt,
							   SslOpts, Session1, ConnectionStates0,
							   Renegotiation, HashSign)
	false ->

%% Runs the shared TLS client hello extension handling for DTLS.
%% It passes dtls_record and the TLS version corresponding to Version.
%% On success it returns
%% {Version, {Type, Session}, ConnectionStates, Protocol, ServerHelloExt, HashSign}.
%% Both an #alert{} result and a thrown Alert are handled here.
%% NOTE(review): the bodies of the #alert{} and catch clauses and the
%% closing 'end.' are not visible in this view.
handle_client_hello_extensions(Version, Type, Random, CipherSuites,
			HelloExt, SslOpts, Session0, ConnectionStates0, Renegotiation, HashSign) ->
    try ssl_handshake:handle_client_hello_extensions(dtls_record, Random, CipherSuites,
						     HelloExt, dtls_v1:corresponding_tls_version(Version),
						     SslOpts, Session0, ConnectionStates0, Renegotiation) of
	#alert{} = Alert ->
	{Session, ConnectionStates, Protocol, ServerHelloExt} ->
	    {Version, {Type, Session}, ConnectionStates, Protocol, ServerHelloExt, HashSign}
    catch throw:Alert ->

%% Runs the shared TLS server hello extension handling for DTLS (passing
%% dtls_record). On success it returns
%% {Version, SessionId, ConnectionStates, ProtoExt, Protocol}.
%% NOTE(review): the body of the #alert{} branch and the closing 'end.' are
%% not visible in this view.
handle_server_hello_extensions(Version, SessionId, Random, CipherSuite,
			       Compression, HelloExt, SslOpt, ConnectionStates0, Renegotiation) ->
    case ssl_handshake:handle_server_hello_extensions(dtls_record, Random, CipherSuite,
						      Compression, HelloExt,
						      SslOpt, ConnectionStates0, Renegotiation) of
	#alert{} = Alert ->
	{ConnectionStates, ProtoExt, Protocol} ->
	    {Version, SessionId, ConnectionStates, ProtoExt, Protocol}

%%%%%%%  Encoding   %%%%%%%%%%%%%

%% Encodes a handshake record into {MsgType, BodyBinary}, by message type:
%%   - hello_verify_request, hello_request and client_hello: encoded here;
%%   - server_hello: encoded by the TLS code, after which its leading
%%     version bytes are replaced by the corresponding DTLS version;
%%   - any other message: delegated to ssl_handshake:encode_handshake/2 with
%%     the corresponding TLS version.
%% NOTE(review): the remainder of the hello_verify_request clause (cookie
%% length and data segments, closing of the tuple) is not visible here.
enc_handshake(#hello_verify_request{protocol_version = {Major, Minor},
 				       cookie = Cookie}, _Version) ->
    CookieLength = byte_size(Cookie),
    {?HELLO_VERIFY_REQUEST, <<?BYTE(Major), ?BYTE(Minor),

enc_handshake(#hello_request{}, _Version) ->
    {?HELLO_REQUEST, <<>>};
%% The DTLS client hello carries a one-byte-length-prefixed cookie after the
%% session id, which is presumably why it is encoded here and not by the TLS
%% encoder.
enc_handshake(#client_hello{client_version = {Major, Minor},
			       random = Random,
			       session_id = SessionID,
			       cookie = Cookie,
			       cipher_suites = CipherSuites,
			       compression_methods = CompMethods,
			       extensions = HelloExtensions}, _Version) ->
    SIDLength = byte_size(SessionID),
    CookieLength = byte_size(Cookie),
    BinCompMethods = list_to_binary(CompMethods),
    CmLength = byte_size(BinCompMethods),
    BinCipherSuites = list_to_binary(CipherSuites),
    CsLength = byte_size(BinCipherSuites),
    ExtensionsBin = ssl_handshake:encode_hello_extensions(HelloExtensions),

    {?CLIENT_HELLO, <<?BYTE(Major), ?BYTE(Minor), Random:32/binary,
 		      ?BYTE(SIDLength), SessionID/binary,
		      ?BYTE(CookieLength), Cookie/binary,
		      ?UINT16(CsLength), BinCipherSuites/binary,
 		      ?BYTE(CmLength), BinCompMethods/binary, ExtensionsBin/binary>>};

%% Encode with the TLS encoder, then swap the two leading version bytes for
%% their DTLS equivalents.
enc_handshake(#server_hello{} = HandshakeMsg, Version) ->
    {Type, <<?BYTE(Major), ?BYTE(Minor), Rest/binary>>} = 
	ssl_handshake:encode_handshake(HandshakeMsg, Version),
    {DTLSMajor, DTLSMinor} = dtls_v1:corresponding_dtls_version({Major, Minor}),
    {Type,  <<?BYTE(DTLSMajor), ?BYTE(DTLSMinor), Rest/binary>>};

enc_handshake(HandshakeMsg, Version) ->
    ssl_handshake:encode_handshake(HandshakeMsg, dtls_v1:corresponding_tls_version(Version)).

%% Splits Bin into {Fragment, Offset} pairs in ascending offset order.
%% Each Fragment is at most Size bytes.
%% byte_size/1 is used instead of the generic size/1, since Bin is always a
%% binary (it is handed to binary:part/2 by bin_fragments/5).
bin_fragments(Bin, Size) ->
    bin_fragments(Bin, byte_size(Bin), Size, 0, []).

%% Walks Bin from Offset in steps of FragSize, accumulating {Fragment, Offset}
%% pairs. Pairs are prepended, then the list is reversed at the end.
%% Once no more than FragSize bytes remain, the final fragment takes the
%% remaining BinSize - Offset bytes.
%% NOTE(review): the closing 'end.' is not visible in this view.
bin_fragments(Bin, BinSize,  FragSize, Offset, Fragments) ->
    case (BinSize - Offset - FragSize)  > 0 of
	true ->
	    Frag = binary:part(Bin, {Offset, FragSize}),
	    bin_fragments(Bin, BinSize, FragSize, Offset + FragSize, [{Frag, Offset} | Fragments]);
	false ->
	    Frag = binary:part(Bin, {Offset, BinSize-Offset}),
	    lists:reverse([{Frag, Offset} | Fragments])

%% Builds one DTLS handshake fragment binary per {Bin, Offset} pair.
%% Each carries MsgType, the total length Len, the sequence Seq, its own
%% Offset and fragment length, and the data; results are prepended onto Acc.
%% Len and Seq are spliced in as already-encoded binaries (Len/binary,
%% Seq/binary), matching the list form built by encode_handshake/3.
%% NOTE(review): the body of the [] clause is not visible in this view;
%% byte_size/1 would be the more specific BIF than size/1 below.
handshake_fragments(_, _, _, [], Acc) ->
handshake_fragments(MsgType, Seq, Len, [{Bin, Offset} | Bins], Acc) ->
    FragLen = size(Bin),
    handshake_fragments(MsgType, Seq, Len, Bins, 
      [<<?BYTE(MsgType), Len/binary, Seq/binary, ?UINT24(Offset),
	 ?UINT24(FragLen), Bin/binary>> | Acc]).

%% Serialises a peer address and port for use in cookie/4.
%% The clauses match 4-element and 8-element address tuples, presumably inet
%% IPv4 and IPv6 addresses.
%% NOTE(review): both clause bodies are missing from this view.
address_to_bin({A,B,C,D}, Port) ->
address_to_bin({A,B,C,D,E,F,G,H}, Port) ->

%%%%%%%  Decoding   %%%%%%%%%%%%%

%% Decodes the raw fragment data into handshake fragments and feeds them
%% through reassembly against the current buffers, collecting into Acc.
handle_fragments(Version, FragmentData, Buffers0, Acc) ->
    do_handle_fragments(Version, decode_handshake_fragments(FragmentData),
                        Buffers0, Acc).

%% Feeds each decoded fragment to reassemble/3.
%% Every completed handshake (HsPacket) is collected into Acc; the result
%% is {CompletedHandshakes, Buffers} in arrival order.
%% NOTE(review): the closing 'end' of the case is not visible in this view.
do_handle_fragments(_, [], Buffers, Acc) ->
    {lists:reverse(Acc), Buffers};
do_handle_fragments(Version, [Fragment | Fragments], Buffers0, Acc) ->
    case reassemble(Version, Fragment, Buffers0) of
	{more_data, Buffers} when Fragments == [] ->
	    {lists:reverse(Acc), Buffers};
	{more_data, Buffers} ->
	    do_handle_fragments(Version, Fragments, Buffers, Acc);
	{HsPacket, Buffers} ->
	    do_handle_fragments(Version, Fragments, Buffers, [HsPacket | Acc])

%% Splits off the leading message type byte and dispatches to
%% decode_handshake/3 with the remaining bytes.
decode_handshake(Version, <<?BYTE(MsgType), Body/binary>>) ->
    decode_handshake(Version, MsgType, Body).

%% Decodes a reassembled DTLS handshake whose message type byte has already
%% been removed.
%% hello_request, client_hello and hello_verify_request are decoded here.
%% For any other Tag, the DTLS length/sequence/fragment header fields are
%% skipped and the rest is decoded by the TLS code.
%% NOTE(review): in this view the hello_request clause body and the opening
%% of the #client_hello{} record are missing, and CookieLength in the
%% hello_verify_request pattern is not bound by any visible segment.
%% NOTE(review): decode_handshake/2 passes the body with the 11 remaining
%% header bytes still attached, so the <<>> hello_request clause looks
%% unreachable (the generic clause would handle it) — confirm.
decode_handshake(_, ?HELLO_REQUEST, <<>>) ->
decode_handshake(_Version, ?CLIENT_HELLO, <<?UINT24(_), ?UINT16(_),
					    ?UINT24(_),  ?UINT24(_), 
					    ?BYTE(Major), ?BYTE(Minor), Random:32/binary,
					    ?BYTE(SID_length), Session_ID:SID_length/binary,
					    ?BYTE(CookieLength), Cookie:CookieLength/binary,
					    ?UINT16(Cs_length), CipherSuites:Cs_length/binary,
					    ?BYTE(Cm_length), Comp_methods:Cm_length/binary,
					    Extensions/binary>>) ->
    DecodedExtensions = ssl_handshake:decode_hello_extensions({client, Extensions}),

       client_version = {Major,Minor},
       random = Random,
       cookie = Cookie,
       session_id = Session_ID,
       cipher_suites = ssl_handshake:decode_suites('2_bytes', CipherSuites),
       compression_methods = Comp_methods,
       extensions = DecodedExtensions

decode_handshake(_Version, ?HELLO_VERIFY_REQUEST, <<?UINT24(_), ?UINT16(_),
						    ?UINT24(_),  ?UINT24(_),
						    ?BYTE(Major), ?BYTE(Minor),
						    Cookie:CookieLength/binary>>) ->
    #hello_verify_request{protocol_version = {Major, Minor},
			  cookie = Cookie};

decode_handshake(Version, Tag,  <<?UINT24(_), ?UINT16(_),
				  ?UINT24(_),  ?UINT24(_), Msg/binary>>) -> 
    %% DTLS specifics stripped
    decode_tls_thandshake(Version, Tag, Msg).

%% Decodes a handshake body whose DTLS-specific header fields have already
%% been stripped. It delegates to the TLS decoder, using the TLS version that
%% corresponds to the DTLS Version.
decode_tls_thandshake(Version, Tag, Msg) ->
    ssl_handshake:decode_handshake(dtls_v1:corresponding_tls_version(Version),
                                   Tag, Msg).

%% Splits a binary of concatenated DTLS handshake fragments into a list of
%% #handshake_fragment{} records, preserving their order.
%% NOTE(review): the body of the <<>> clause is not visible here.
%% NOTE(review): MessageSeq is used in the record below but is not bound by
%% any visible segment of the pattern. encode_handshake/3 and
%% handshake_bin/4 place a 16-bit sequence field between Length and the
%% fragment offset, so that segment appears to be missing.
decode_handshake_fragments(<<>>) ->
decode_handshake_fragments(<<?BYTE(Type), ?UINT24(Length),
			     ?UINT24(FragmentOffset), ?UINT24(FragmentLength),
			    Fragment:FragmentLength/binary, Rest/binary>>) ->
    [#handshake_fragment{type = Type, 
			length = Length,
			message_seq = MessageSeq,
			fragment_offset = FragmentOffset,
			fragment_length = FragmentLength,
			fragment = Fragment} | decode_handshake_fragments(Rest)].

%% Adds Fragment to the protocol buffers, depending on its sequence number:
%%   - The expected one (dtls_handshake_next_seq): merged with the fragments
%%     already collected for it. Once the message is complete it is decoded
%%     and {{Handshake, RawHandshake}, Buffers} is returned. The expected
%%     sequence number is then advanced and the buffered fragments for the
%%     new number are moved out of the "later" list.
%%   - A future one: stored in the "later" list.
%%   - An older one: ignored.
%% NOTE(review): in this view the first clause head is not closed (Buffers0
%% is not bound by the visible pattern), the 'end;' of its case is missing,
%% and the second clause's result tuple is not opened.
reassemble(Version,  #handshake_fragment{message_seq = Seq} = Fragment, 
	   #protocol_buffers{dtls_handshake_next_seq = Seq,
			     dtls_handshake_next_fragments = Fragments0,
			     dtls_handshake_later_fragments = LaterFragments0} = 
    case reassemble_fragments(Fragment, Fragments0) of
	{more_data, Fragments} ->
	    {more_data,  Buffers0#protocol_buffers{dtls_handshake_next_fragments = Fragments}};
	{raw, RawHandshake} ->
	    Handshake = decode_handshake(Version, RawHandshake),
	    {NextFragments, LaterFragments} = next_fragments(LaterFragments0),
	    {{Handshake, RawHandshake}, Buffers0#protocol_buffers{dtls_handshake_next_seq = Seq + 1,
						  dtls_handshake_next_fragments = NextFragments,
						  dtls_handshake_later_fragments = LaterFragments}}
reassemble(_,  #handshake_fragment{message_seq = FragSeq} = Fragment, 
	   #protocol_buffers{dtls_handshake_next_seq = Seq,
			     dtls_handshake_later_fragments = LaterFragments} = Buffers0) when FragSeq > Seq-> 
      Buffers0#protocol_buffers{dtls_handshake_later_fragments = [Fragment | LaterFragments]}};
reassemble(_, _, Buffers) -> 
    %% Disregard fragments FragSeq < Seq
    {more_data, Buffers}.

%% Sorts Current together with the fragments collected so far by fragment
%% offset, then merges them with merge_fragment/2.
%% Returns {raw, Binary} when the first merged fragment's fragment_length
%% equals the message length, otherwise {more_data, MergedFragments}.
%% NOTE(review): the closing 'end.' is not visible in this view.
reassemble_fragments(Current, Fragments0) ->
    [Frag1 | Frags] = lists:keysort(#handshake_fragment.fragment_offset, [Current | Fragments0]),
    [Fragment | _] = Fragments = merge_fragment(Frag1, Frags),
    case is_complete_handshake(Fragment) of
	true ->
	    {raw, handshake_bin(Fragment)};
	false ->
	    {more_data, Fragments}
%% Folds the offset-sorted fragments into Frag0 from the left.
%% As soon as merge_fragments/2 returns a list (a gap between fragments),
%% the remaining fragments are appended unmerged.
%% NOTE(review): the body of the [] clause and the closing 'end.' are not
%% visible in this view.
merge_fragment(Frag0, []) ->
merge_fragment(Frag0, [Frag1 | Rest]) ->
    case merge_fragments(Frag0, Frag1) of
	[_|_] = Frags ->
	    Frags ++ Rest;
	Frag ->
	    merge_fragment(Frag, Rest)
%% Predicate used by reassemble_fragments/2: the first clause matches when
%% fragment_length equals the total length (both fields bind Length).
%% NOTE(review): both clause bodies are missing from this view; presumably
%% true and false respectively, given the true/false case at the caller —
%% confirm.
is_complete_handshake(#handshake_fragment{length = Length, fragment_length = Length}) ->
is_complete_handshake(_) ->

%% Sorts the "later" fragments by message_seq and splits off those that
%% share the lowest sequence number.
%% Returns {FragmentsForThatSeq, RemainingFragments}.
%% NOTE(review): the closing 'end.' is not visible in this view.
next_fragments(LaterFragments) ->
    case lists:keysort(#handshake_fragment.message_seq, LaterFragments) of
	[] ->
	    {[], []}; 
	[#handshake_fragment{message_seq = Seq} | _] = Fragments ->
	    split_frags(Fragments, Seq, [])

%% Splits Frags into the leading run of fragments whose message_seq is Seq
%% and the remainder. Returns {Taken, Remainder}.
%% Acc holds fragments already taken, most recent first; they are restored
%% to original order and placed ahead of the run in Taken.
split_frags(Frags, Seq, Acc) ->
    {Run, Rest} = lists:splitwith(
                    fun(#handshake_fragment{message_seq = S}) -> S =:= Seq;
                       (_) -> false
                    end, Frags),
    {lists:reverse(Acc, Run), Rest}.

%% merge_fragments(Previous, Current): merges two offset-sorted fragments of
%% the same message. The visible cases are:
%%   - an exact duplicate;
%%   - a fragment at the same offset carrying more data;
%%   - a fragment at the same offset carrying less data;
%%   - a following fragment that touches or overlaps Previous;
%%   - a fragment already fully contained in Previous.
%% When there is a gap, both fragments are returned unmerged as
%% [Previous, Current].
%% NOTE(review): the clause heads, several clause bodies and the record
%% update expressions are missing from this view.
%% Duplicate
		   fragment_offset = PreviousOffSet, 
		   fragment_length = PreviousLen,
		   fragment = PreviousData
		  } = Previous, 
		   fragment_offset = PreviousOffSet,
		   fragment_length = PreviousLen,
		   fragment = PreviousData}) ->

%% Larger fragment at the same offset: keep only the bytes beyond PreviousLen
		   fragment_offset = PreviousOffSet, 
		   fragment_length = PreviousLen,
		   fragment = PreviousData
		  } = Previous, 
		   fragment_offset = PreviousOffSet,
		   fragment_length = CurrentLen,
		   fragment = CurrentData}) when CurrentLen > PreviousLen ->
    NewLength = CurrentLen - PreviousLen,
    <<_:PreviousLen/binary, NewData/binary>> = CurrentData, 
      fragment_length = PreviousLen + NewLength,
      fragment = <<PreviousData/binary, NewData/binary>>

%% Smaller fragment
		   fragment_offset = PreviousOffSet, 
		   fragment_length = PreviousLen
		  } = Previous, 
		   fragment_offset = PreviousOffSet,
		   fragment_length = CurrentLen}) when CurrentLen < PreviousLen ->
%% Next fragment, might be overlapping: drop the CurrentStart bytes of
%% Current already covered by Previous and append the rest
		   fragment_offset = PreviousOffSet, 
		   fragment_length = PreviousLen,
		   fragment = PreviousData
		  } = Previous, 
		   fragment_offset = CurrentOffSet,
		   fragment_length = CurrentLen,
                  fragment = CurrentData})
  when PreviousOffSet + PreviousLen >= CurrentOffSet andalso
       PreviousOffSet + PreviousLen < CurrentOffSet + CurrentLen ->
    CurrentStart = PreviousOffSet + PreviousLen - CurrentOffSet,
    <<_:CurrentStart/bytes, Data/binary>> = CurrentData,
      fragment_length =  PreviousLen + CurrentLen - CurrentStart,
      fragment = <<PreviousData/binary, Data/binary>>};
%% already fully contained fragment
                   fragment_offset = PreviousOffSet, 
                   fragment_length = PreviousLen
                  } = Previous, 
                   fragment_offset = CurrentOffSet,
                   fragment_length = CurrentLen})
  when PreviousOffSet + PreviousLen >= CurrentOffSet andalso
       PreviousOffSet + PreviousLen >= CurrentOffSet + CurrentLen ->

%% No merge: there is a gap between the fragments
merge_fragments(Previous, Current) ->
    [Previous, Current].
%% Tail of what is presumably the handshake_bin/1 clause called by
%% reassemble_fragments/2 (its head is not visible here).
%% It matches a complete fragment (offset 0, fragment_length equal to
%% length) and re-encodes it via handshake_bin/4.
		 type = Type,
		 length = Len, 
		 message_seq = Seq,
		 fragment_length = Len,
		 fragment_offset = 0,
		 fragment = Fragment}) ->	    
    handshake_bin(Type, Len, Seq, Fragment).

handshake_bin(Type, Length, Seq, FragmentData) -> 
    <<?BYTE(Type), ?UINT24(Length),
      ?UINT16(Seq), ?UINT24(0), ?UINT24(Length),