diff options
author | Raimo Niskanen <[email protected]> | 2019-02-21 14:42:07 +0100 |
---|---|---|
committer | Raimo Niskanen <[email protected]> | 2019-02-21 14:42:07 +0100 |
commit | 9e9884640b7d0ee69cd39909842ec5fc8826859f (patch) | |
tree | 08afd3c1a804077457d164b63d1c70b56a139518 /lib/ssl/src | |
parent | f06f5bf23a3cd1f040c8ab6f059097d22161abc7 (diff) | |
parent | c4fb3492f48d00214d509de5ec9336e4adf51c58 (diff) | |
download | otp-9e9884640b7d0ee69cd39909842ec5fc8826859f.tar.gz otp-9e9884640b7d0ee69cd39909842ec5fc8826859f.tar.bz2 otp-9e9884640b7d0ee69cd39909842ec5fc8826859f.zip |
Merge branch 'raimo/ssl/tls-optimization/OTP-15529' into maint
* raimo/ssl/tls-optimization/OTP-15529:
Inline local function
Optimize binary matching
Clean up module boundaries
Remove redundant return of CipherState
Use iovec() internally in send path
Small binary handling optimizations
Optimize read_application_data with Okasaki queue
Try to optimize decode_cipher_text/3
Optimize application data aggregation
Optimize TLS record parsing with Okasaki queue
Cache strong_random_bytes for IV
Optimize padding
Produce less garbage in encrypt loop
Reorganize #data{}
Tidy up state machine
Add server GC info to bench results
Diffstat (limited to 'lib/ssl/src')
-rw-r--r-- | lib/ssl/src/dtls_connection.erl | 27 | ||||
-rw-r--r-- | lib/ssl/src/dtls_record.erl | 16 | ||||
-rw-r--r-- | lib/ssl/src/ssl.erl | 4 | ||||
-rw-r--r-- | lib/ssl/src/ssl_cipher.erl | 101 | ||||
-rw-r--r-- | lib/ssl/src/ssl_connection.erl | 369 | ||||
-rw-r--r-- | lib/ssl/src/ssl_connection.hrl | 6 | ||||
-rw-r--r-- | lib/ssl/src/ssl_record.erl | 70 | ||||
-rw-r--r-- | lib/ssl/src/ssl_record.hrl | 4 | ||||
-rw-r--r-- | lib/ssl/src/tls_connection.erl | 132 | ||||
-rw-r--r-- | lib/ssl/src/tls_record.erl | 412 | ||||
-rw-r--r-- | lib/ssl/src/tls_sender.erl | 286 |
11 files changed, 862 insertions, 565 deletions
diff --git a/lib/ssl/src/dtls_connection.erl b/lib/ssl/src/dtls_connection.erl index 7a91578fe2..2c6b71c97a 100644 --- a/lib/ssl/src/dtls_connection.erl +++ b/lib/ssl/src/dtls_connection.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2013-2018. All Rights Reserved. +%% Copyright Ericsson AB 2013-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -50,8 +50,7 @@ -export([encode_alert/3, send_alert/2, send_alert_in_connection/2, close/5, protocol_name/0]). %% Data handling --export([encode_data/3, next_record/1, - send/3, socket/5, setopts/3, getopts/3]). +-export([next_record/1, socket/4, setopts/3, getopts/3]). %% gen_statem state functions -export([init/3, error/3, downgrade/3, %% Initiation and take down states @@ -392,16 +391,13 @@ protocol_name() -> %% Data handling %%==================================================================== -encode_data(Data, Version, ConnectionStates0)-> - dtls_record:encode_data(Data, Version, ConnectionStates0). +send(Transport, {Listener, Socket}, Data) when is_pid(Listener) -> % Server socket + dtls_socket:send(Transport, Socket, Data); +send(Transport, Socket, Data) -> % Client socket + dtls_socket:send(Transport, Socket, Data). -send(Transport, {_, {{_,_}, _} = Socket}, Data) -> - send(Transport, Socket, Data); -send(Transport, Socket, Data) -> - dtls_socket:send(Transport, Socket, Data). - -socket(Pid, Transport, Socket, Connection, _) -> - dtls_socket:socket(Pid, Transport, Socket, Connection). +socket(Pid, Transport, Socket, _Tracker) -> + dtls_socket:socket(Pid, Transport, Socket, ?MODULE). setopts(Transport, Socket, Other) -> dtls_socket:setopts(Transport, Socket, Other). @@ -805,7 +801,7 @@ initial_state(Role, Host, Port, Socket, {SSLOptions, SocketOptions, _}, User, session = #session{is_resumable = new}, connection_states = ConnectionStates, protocol_buffers = #protocol_buffers{}, - user_data_buffer = <<>>, + user_data_buffer = {[],0,[]}, start_or_recv_from = undefined, flight_buffer = new_flight(), protocol_specific = #{flight_state => initial_flight_state(DataTag)} @@ -1173,7 +1169,6 @@ log_ignore_alert(false, _, _,_) -> send_application_data(Data, From, _StateName, #state{static_env = #static_env{socket = Socket, - protocol_cb = Connection, transport_cb = Transport}, connection_env = #connection_env{negotiated_version = Version}, handshake_env = HsEnv, @@ -1186,9 +1181,9 @@ send_application_data(Data, From, _StateName, [{next_event, {call, From}, {application_data, Data}}]); false -> {Msgs, ConnectionStates} = - Connection:encode_data(Data, Version, ConnectionStates0), + dtls_record:encode_data(Data, Version, ConnectionStates0), State = State0#state{connection_states = ConnectionStates}, - case Connection:send(Transport, Socket, Msgs) of + case send(Transport, Socket, Msgs) of ok -> ssl_connection:hibernate_after(connection, State, [{reply, From, ok}]); Result -> diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl index dd33edfd77..2fe875da31 100644 --- a/lib/ssl/src/dtls_record.erl +++ b/lib/ssl/src/dtls_record.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2013-2018. All Rights Reserved. +%% Copyright Ericsson AB 2013-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -546,15 +546,15 @@ decode_cipher_text(#ssl_tls{type = Type, version = Version, compression_algorithm = CompAlg}} = ReadState0, ConnnectionStates0) -> AAD = start_additional_data(Type, Version, Epoch, Seq), - CipherS1 = ssl_record:nonce_seed(BulkCipherAlgo, <<?UINT16(Epoch), ?UINT48(Seq)>>, CipherS0), + CipherS = ssl_record:nonce_seed(BulkCipherAlgo, <<?UINT16(Epoch), ?UINT48(Seq)>>, CipherS0), TLSVersion = dtls_v1:corresponding_tls_version(Version), - case ssl_record:decipher_aead(BulkCipherAlgo, CipherS1, AAD, CipherFragment, TLSVersion) of - {PlainFragment, CipherState} -> - {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, + case ssl_record:decipher_aead(BulkCipherAlgo, CipherS, AAD, CipherFragment, TLSVersion) of + PlainFragment when is_binary(PlainFragment) -> + {Plain, CompressionS} = ssl_record:uncompress(CompAlg, PlainFragment, CompressionS0), - ReadState0 = ReadState0#{compression_state => CompressionS1, - cipher_state => CipherState}, - ReadState = update_replay_window(Seq, ReadState0), + ReadState1 = ReadState0#{compression_state := CompressionS, + cipher_state := CipherS}, + ReadState = update_replay_window(Seq, ReadState1), ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read), {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; #alert{} = Alert -> diff --git a/lib/ssl/src/ssl.erl b/lib/ssl/src/ssl.erl index c39a6f1603..bdb2283ad6 100644 --- a/lib/ssl/src/ssl.erl +++ b/lib/ssl/src/ssl.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 1999-2018. All Rights Reserved. +%% Copyright Ericsson AB 1999-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -626,7 +626,7 @@ close(#sslsocket{pid = {ListenSocket, #config{transport_info={Transport,_, _, _} send(#sslsocket{pid = [Pid]}, Data) when is_pid(Pid) -> ssl_connection:send(Pid, Data); send(#sslsocket{pid = [_, Pid]}, Data) when is_pid(Pid) -> - tls_sender:send_data(Pid, erlang:iolist_to_binary(Data)); + tls_sender:send_data(Pid, erlang:iolist_to_iovec(Data)); send(#sslsocket{pid = {_, #config{transport_info={_, udp, _, _}}}}, _) -> {error,enotconn}; %% Emulate connection behaviour send(#sslsocket{pid = {dtls,_}}, _) -> diff --git a/lib/ssl/src/ssl_cipher.erl b/lib/ssl/src/ssl_cipher.erl index cf1bec6332..fce48d1678 100644 --- a/lib/ssl/src/ssl_cipher.erl +++ b/lib/ssl/src/ssl_cipher.erl @@ -1,7 +1,7 @@ % %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2007-2018. All Rights Reserved. +%% Copyright Ericsson AB 2007-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -41,7 +41,7 @@ rc4_suites/1, des_suites/1, rsa_suites/1, filter/3, filter_suites/1, filter_suites/2, hash_algorithm/1, sign_algorithm/1, is_acceptable_hash/2, is_fallback/1, - random_bytes/1, calc_mac_hash/4, + random_bytes/1, calc_mac_hash/4, calc_mac_hash/6, is_stream_ciphersuite/1]). -compile(inline). @@ -97,7 +97,8 @@ cipher_init(?AES_GCM, IV, Key) -> cipher_init(?CHACHA20_POLY1305, IV, Key) -> #cipher_state{iv = IV, key = Key, tag_len = 16}; cipher_init(_BCA, IV, Key) -> - #cipher_state{iv = IV, key = Key}. + %% Initialize random IV cache, not used for aead ciphers + #cipher_state{iv = IV, key = Key, state = <<>>}. nonce_seed(Seed, CipherState) -> CipherState#cipher_state{nonce = Seed}. @@ -112,12 +113,11 @@ nonce_seed(Seed, CipherState) -> %% data is calculated and the data plus the HMAC is ecncrypted. %%------------------------------------------------------------------- cipher(?NULL, CipherState, <<>>, Fragment, _Version) -> - GenStreamCipherList = [Fragment, <<>>], - {GenStreamCipherList, CipherState}; + {iolist_to_binary(Fragment), CipherState}; cipher(?RC4, CipherState = #cipher_state{state = State0}, Mac, Fragment, _Version) -> GenStreamCipherList = [Fragment, Mac], {State1, T} = crypto:stream_encrypt(State0, GenStreamCipherList), - {T, CipherState#cipher_state{state = State1}}; + {iolist_to_binary(T), CipherState#cipher_state{state = State1}}; cipher(?DES, CipherState, Mac, Fragment, Version) -> block_cipher(fun(Key, IV, T) -> crypto:block_encrypt(des_cbc, Key, IV, T) @@ -146,8 +146,7 @@ aead_type(?CHACHA20_POLY1305) -> build_cipher_block(BlockSz, Mac, Fragment) -> TotSz = byte_size(Mac) + erlang:iolist_size(Fragment) + 1, - {PaddingLength, Padding} = get_padding(TotSz, BlockSz), - [Fragment, Mac, PaddingLength, Padding]. + [Fragment, Mac, padding_with_len(TotSz, BlockSz)]. block_cipher(Fun, BlockSz, #cipher_state{key=Key, iv=IV} = CS0, Mac, Fragment, {3, N}) @@ -157,14 +156,21 @@ block_cipher(Fun, BlockSz, #cipher_state{key=Key, iv=IV} = CS0, NextIV = next_iv(T, IV), {T, CS0#cipher_state{iv=NextIV}}; -block_cipher(Fun, BlockSz, #cipher_state{key=Key, iv=IV} = CS0, +block_cipher(Fun, BlockSz, #cipher_state{key=Key, iv=IV, state = IV_Cache0} = CS0, Mac, Fragment, {3, N}) when N == 2; N == 3 -> - NextIV = random_iv(IV), + IV_Size = byte_size(IV), + <<NextIV:IV_Size/binary, IV_Cache/binary>> = + case IV_Cache0 of + <<>> -> + random_bytes(IV_Size bsl 5); % 32 IVs + _ -> + IV_Cache0 + end, L0 = build_cipher_block(BlockSz, Mac, Fragment), L = [NextIV|L0], T = Fun(Key, IV, L), - {T, CS0#cipher_state{iv=NextIV}}. + {T, CS0#cipher_state{iv=NextIV, state = IV_Cache}}. %%-------------------------------------------------------------------- -spec decipher(cipher_enum(), integer(), #cipher_state{}, binary(), @@ -633,12 +639,13 @@ random_bytes(N) -> calc_mac_hash(Type, Version, PlainFragment, #{sequence_number := SeqNo, mac_secret := MacSecret, - security_parameters:= - SecPars}) -> + security_parameters := + #security_parameters{mac_algorithm = MacAlgorithm}}) -> + calc_mac_hash(Type, Version, PlainFragment, MacAlgorithm, MacSecret, SeqNo). +%% +calc_mac_hash(Type, Version, PlainFragment, MacAlgorithm, MacSecret, SeqNo) -> Length = erlang:iolist_size(PlainFragment), - mac_hash(Version, SecPars#security_parameters.mac_algorithm, - MacSecret, SeqNo, Type, - Length, PlainFragment). + mac_hash(Version, MacAlgorithm, MacSecret, SeqNo, Type, Length, PlainFragment). is_stream_ciphersuite(#{cipher := rc4_128}) -> true; @@ -722,7 +729,6 @@ expanded_key_material(Cipher) when Cipher == aes_128_cbc; Cipher == chacha20_poly1305 -> unknown. - effective_key_bits(null) -> 0; effective_key_bits(des_cbc) -> @@ -742,18 +748,15 @@ iv_size(Cipher) when Cipher == null; Cipher == rc4_128; Cipher == chacha20_poly1305-> 0; - iv_size(Cipher) when Cipher == aes_128_gcm; Cipher == aes_256_gcm -> 4; - iv_size(Cipher) -> block_size(Cipher). block_size(Cipher) when Cipher == des_cbc; Cipher == '3des_ede_cbc' -> 8; - block_size(Cipher) when Cipher == aes_128_cbc; Cipher == aes_256_cbc; Cipher == aes_128_gcm; @@ -888,21 +891,51 @@ is_correct_padding(GenBlockCipher, {3, 1}, false) -> %% Padding must be checked in TLS 1.1 and after is_correct_padding(#generic_block_cipher{padding_length = Len, padding = Padding}, _, _) -> - Len == byte_size(Padding) andalso - binary:copy(?byte(Len), Len) == Padding. - -get_padding(Length, BlockSize) -> - get_padding_aux(BlockSize, Length rem BlockSize). - -get_padding_aux(_, 0) -> - {0, <<>>}; -get_padding_aux(BlockSize, PadLength) -> - N = BlockSize - PadLength, - {N, binary:copy(?byte(N), N)}. + (Len == byte_size(Padding)) andalso (padding(Len) == Padding). + +padding(PadLen) -> + case PadLen of + 0 -> <<>>; + 1 -> <<1>>; + 2 -> <<2,2>>; + 3 -> <<3,3,3>>; + 4 -> <<4,4,4,4>>; + 5 -> <<5,5,5,5,5>>; + 6 -> <<6,6,6,6,6,6>>; + 7 -> <<7,7,7,7,7,7,7>>; + 8 -> <<8,8,8,8,8,8,8,8>>; + 9 -> <<9,9,9,9,9,9,9,9,9>>; + 10 -> <<10,10,10,10,10,10,10,10,10,10>>; + 11 -> <<11,11,11,11,11,11,11,11,11,11,11>>; + 12 -> <<12,12,12,12,12,12,12,12,12,12,12,12>>; + 13 -> <<13,13,13,13,13,13,13,13,13,13,13,13,13>>; + 14 -> <<14,14,14,14,14,14,14,14,14,14,14,14,14,14>>; + 15 -> <<15,15,15,15,15,15,15,15,15,15,15,15,15,15,15>>; + _ -> + binary:copy(<<PadLen>>, PadLen) + end. -random_iv(IV) -> - IVSz = byte_size(IV), - random_bytes(IVSz). +padding_with_len(TextLen, BlockSize) -> + case BlockSize - (TextLen rem BlockSize) of + 0 -> <<0>>; + 1 -> <<1,1>>; + 2 -> <<2,2,2>>; + 3 -> <<3,3,3,3>>; + 4 -> <<4,4,4,4,4>>; + 5 -> <<5,5,5,5,5,5>>; + 6 -> <<6,6,6,6,6,6,6>>; + 7 -> <<7,7,7,7,7,7,7,7>>; + 8 -> <<8,8,8,8,8,8,8,8,8>>; + 9 -> <<9,9,9,9,9,9,9,9,9,9>>; + 10 -> <<10,10,10,10,10,10,10,10,10,10,10>>; + 11 -> <<11,11,11,11,11,11,11,11,11,11,11,11>>; + 12 -> <<12,12,12,12,12,12,12,12,12,12,12,12,12>>; + 13 -> <<13,13,13,13,13,13,13,13,13,13,13,13,13,13>>; + 14 -> <<14,14,14,14,14,14,14,14,14,14,14,14,14,14,14>>; + 15 -> <<15,15,15,15,15,15,15,15,15,15,15,15,15,15,15,15>>; + PadLen -> + binary:copy(<<PadLen>>, PadLen + 1) + end. next_iv(Bin, IV) -> BinSz = byte_size(Bin), diff --git a/lib/ssl/src/ssl_connection.erl b/lib/ssl/src/ssl_connection.erl index 6206d15c13..9e037313bb 100644 --- a/lib/ssl/src/ssl_connection.erl +++ b/lib/ssl/src/ssl_connection.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2013-2018. All Rights Reserved. +%% Copyright Ericsson AB 2013-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -70,7 +70,7 @@ -export([terminate/3, format_status/2]). %% Erlang Distribution export --export([get_sslsocket/1, dist_handshake_complete/2]). +-export([dist_handshake_complete/2]). %%==================================================================== %% Setup @@ -182,19 +182,19 @@ socket_control(Connection, Socket, Pid, Transport) -> %%-------------------------------------------------------------------- socket_control(Connection, Socket, Pids, Transport, udp_listener) -> %% dtls listener process must have the socket control - {ok, Connection:socket(Pids, Transport, Socket, Connection, undefined)}; + {ok, Connection:socket(Pids, Transport, Socket, undefined)}; socket_control(tls_connection = Connection, Socket, [Pid|_] = Pids, Transport, ListenTracker) -> case Transport:controlling_process(Socket, Pid) of ok -> - {ok, Connection:socket(Pids, Transport, Socket, Connection, ListenTracker)}; + {ok, Connection:socket(Pids, Transport, Socket, ListenTracker)}; {error, Reason} -> {error, Reason} end; socket_control(dtls_connection = Connection, {_, Socket}, [Pid|_] = Pids, Transport, ListenTracker) -> case Transport:controlling_process(Socket, Pid) of ok -> - {ok, Connection:socket(Pids, Transport, Socket, Connection, ListenTracker)}; + {ok, Connection:socket(Pids, Transport, Socket, ListenTracker)}; {error, Reason} -> {error, Reason} end. @@ -211,9 +211,9 @@ socket_control(dtls_connection = Connection, {_, Socket}, [Pid|_] = Pids, Transp %%-------------------------------------------------------------------- send(Pid, Data) -> call(Pid, {application_data, - %% iolist_to_binary should really - %% be called iodata_to_binary() - erlang:iolist_to_binary(Data)}). + %% iolist_to_iovec should really + %% be called iodata_to_iovec() + erlang:iolist_to_iovec(Data)}). %%-------------------------------------------------------------------- -spec recv(pid(), integer(), timeout()) -> @@ -311,9 +311,6 @@ renegotiation(ConnectionPid) -> internal_renegotiation(ConnectionPid, #{current_write := WriteState}) -> gen_statem:cast(ConnectionPid, {internal_renegotiate, WriteState}). -get_sslsocket(ConnectionPid) -> - call(ConnectionPid, get_sslsocket). - dist_handshake_complete(ConnectionPid, DHandle) -> gen_statem:cast(ConnectionPid, {dist_handshake_complete, DHandle}). @@ -442,9 +439,9 @@ handle_alert(#alert{level = ?WARNING} = Alert, StateName, %%==================================================================== %% Data handling %%==================================================================== -passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName, Connection, StartTimerAction) -> - case Buffer of - <<>> -> +passive_receive(State0 = #state{user_data_buffer = {_,BufferSize,_}}, StateName, Connection, StartTimerAction) -> + case BufferSize of + 0 -> {Record, State} = Connection:next_record(State0), Connection:next_event(StateName, Record, State, StartTimerAction); _ -> @@ -466,101 +463,227 @@ passive_receive(State0 = #state{user_data_buffer = Buffer}, StateName, Connectio read_application_data( Data, #state{ - user_data_buffer = Buffer0, + user_data_buffer = {Front0,BufferSize0,Rear0}, connection_env = #connection_env{erl_dist_handle = DHandle}} = State) -> %% - Buffer = bincat(Buffer0, Data), + Front = Front0, + BufferSize = BufferSize0 + byte_size(Data), + Rear = [Data|Rear0], case DHandle of undefined -> - #state{ - socket_options = SocketOpts, - bytes_to_read = BytesToRead, - start_or_recv_from = RecvFrom} = State, - read_application_data( - Buffer, State, SocketOpts, RecvFrom, BytesToRead); + read_application_data(State, Front, BufferSize, Rear); _ -> - try read_application_dist_data(Buffer, State, DHandle) + try read_application_dist_data(DHandle, Front, BufferSize, Rear) of + Buffer -> + {no_record, State#state{user_data_buffer = Buffer}} catch error:_ -> {stop,disconnect, - State#state{ - user_data_buffer = Buffer, - bytes_to_read = undefined}} + State#state{user_data_buffer = {Front,BufferSize,Rear}}} end end. -read_application_dist_data(Buffer, State, DHandle) -> - case Buffer of - <<Size:32,Data:Size/binary>> -> - erlang:dist_ctrl_put_data(DHandle, Data), + +read_application_data(#state{ + socket_options = SocketOpts, + bytes_to_read = BytesToRead, + start_or_recv_from = RecvFrom} = State, Front, BufferSize, Rear) -> + read_application_data(State, Front, BufferSize, Rear, SocketOpts, RecvFrom, BytesToRead). + +%% Pick binary from queue front, if empty wait for more data +read_application_data(State, [Bin|Front], BufferSize, Rear, SocketOpts, RecvFrom, BytesToRead) -> + read_application_data_bin(State, Front, BufferSize, Rear, SocketOpts, RecvFrom, BytesToRead, Bin); +read_application_data(State, [] = Front, BufferSize, [] = Rear, SocketOpts, RecvFrom, BytesToRead) -> + 0 = BufferSize, % Assert + {no_record, State#state{socket_options = SocketOpts, + bytes_to_read = BytesToRead, + start_or_recv_from = RecvFrom, + user_data_buffer = {Front,BufferSize,Rear}}}; +read_application_data(State, [], BufferSize, Rear, SocketOpts, RecvFrom, BytesToRead) -> + [Bin|Front] = lists:reverse(Rear), + read_application_data_bin(State, Front, BufferSize, [], SocketOpts, RecvFrom, BytesToRead, Bin). + +read_application_data_bin(State, Front, BufferSize, Rear, SocketOpts, RecvFrom, BytesToRead, <<>>) -> + %% Done with this binary - get next + read_application_data(State, Front, BufferSize, Rear, SocketOpts, RecvFrom, BytesToRead); +read_application_data_bin(State, Front0, BufferSize0, Rear0, SocketOpts0, RecvFrom, BytesToRead, Bin0) -> + %% Decode one packet from a binary + case get_data(SocketOpts0, BytesToRead, Bin0) of + {ok, Data, Bin} -> % Send data + BufferSize = BufferSize0 - (byte_size(Bin0) - byte_size(Bin)), + read_application_data_deliver( + State, [Bin|Front0], BufferSize, Rear0, SocketOpts0, RecvFrom, Data); + {more, undefined} -> + %% We need more data, do not know how much + if + byte_size(Bin0) < BufferSize0 -> + %% We have more data in the buffer besides the first binary - concatenate all and retry + Bin = iolist_to_binary([Bin0,Front0|lists:reverse(Rear0)]), + read_application_data_bin( + State, [], BufferSize0, [], SocketOpts0, RecvFrom, BytesToRead, Bin); + true -> + %% All data is in the first binary, no use to retry - wait for more + {no_record, State#state{socket_options = SocketOpts0, + bytes_to_read = BytesToRead, + start_or_recv_from = RecvFrom, + user_data_buffer = {[Bin0|Front0],BufferSize0,Rear0}}} + end; + {more, Size} when Size =< BufferSize0 -> + %% We have a packet in the buffer - collect it in a binary and decode + {Data,Front,Rear} = iovec_from_front(Size - byte_size(Bin0), Front0, Rear0, [Bin0]), + Bin = iolist_to_binary(Data), + read_application_data_bin( + State, Front, BufferSize0, Rear, SocketOpts0, RecvFrom, BytesToRead, Bin); + {more, _Size} -> + %% We do not have a packet in the buffer - wait for more + {no_record, State#state{socket_options = SocketOpts0, + bytes_to_read = BytesToRead, + start_or_recv_from = RecvFrom, + user_data_buffer = {[Bin0|Front0],BufferSize0,Rear0}}}; + passive -> + {no_record, State#state{socket_options = SocketOpts0, + bytes_to_read = BytesToRead, + start_or_recv_from = RecvFrom, + user_data_buffer = {[Bin0|Front0],BufferSize0,Rear0}}}; + {error,_Reason} -> + %% Invalid packet in packet mode + #state{ + static_env = + #static_env{ + socket = Socket, + protocol_cb = Connection, + transport_cb = Transport, + tracker = Tracker}, + connection_env = + #connection_env{user_application = {_Mon, Pid}}} = State, + Buffer = iolist_to_binary([Bin0,Front0|lists:reverse(Rear0)]), + deliver_packet_error( + Connection:pids(State), Transport, Socket, SocketOpts0, + Buffer, Pid, RecvFrom, Tracker, Connection), + {stop, {shutdown, normal}, State#state{socket_options = SocketOpts0, + bytes_to_read = BytesToRead, + start_or_recv_from = RecvFrom, + user_data_buffer = {[Buffer],BufferSize0,[]}}} + end. + +read_application_data_deliver(State, Front, BufferSize, Rear, SocketOpts0, RecvFrom, Data) -> + #state{ + static_env = + #static_env{ + socket = Socket, + protocol_cb = Connection, + transport_cb = Transport, + tracker = Tracker}, + connection_env = + #connection_env{user_application = {_Mon, Pid}}} = State, + SocketOpts = + deliver_app_data( + Connection:pids(State), Transport, Socket, SocketOpts0, Data, Pid, RecvFrom, Tracker, Connection), + if + SocketOpts#socket_options.active =:= false -> + %% Passive mode, wait for active once or recv {no_record, State#state{ - user_data_buffer = <<>>, - bytes_to_read = undefined}}; - <<Size:32,Data:Size/binary,Rest/binary>> -> + user_data_buffer = {Front,BufferSize,Rear}, + start_or_recv_from = undefined, + bytes_to_read = undefined, + socket_options = SocketOpts + }}; + true -> %% Try to deliver more data + read_application_data(State, Front, BufferSize, Rear, SocketOpts, undefined, undefined) + end. + + +read_application_dist_data(DHandle, [Bin|Front], BufferSize, Rear) -> + read_application_dist_data(DHandle, Front, BufferSize, Rear, Bin); +read_application_dist_data(_DHandle, [] = Front, BufferSize, [] = Rear) -> + BufferSize = 0, + {Front,BufferSize,Rear}; +read_application_dist_data(DHandle, [], BufferSize, Rear) -> + [Bin|Front] = lists:reverse(Rear), + read_application_dist_data(DHandle, Front, BufferSize, [], Bin). +%% +read_application_dist_data(DHandle, Front0, BufferSize, Rear0, Bin0) -> + case Bin0 of + %% + %% START Optimization + %% It is cheaper to match out several packets in one match operation than to loop for each + <<SizeA:32, DataA:SizeA/binary, + SizeB:32, DataB:SizeB/binary, + SizeC:32, DataC:SizeC/binary, + SizeD:32, DataD:SizeD/binary, Rest/binary>> -> + %% We have 4 complete packets in the first binary + erlang:dist_ctrl_put_data(DHandle, DataA), + erlang:dist_ctrl_put_data(DHandle, DataB), + erlang:dist_ctrl_put_data(DHandle, DataC), + erlang:dist_ctrl_put_data(DHandle, DataD), + read_application_dist_data( + DHandle, Front0, BufferSize - (4*4+SizeA+SizeB+SizeC+SizeD), Rear0, Rest); + <<SizeA:32, DataA:SizeA/binary, + SizeB:32, DataB:SizeB/binary, + SizeC:32, DataC:SizeC/binary, Rest/binary>> -> + %% We have 3 complete packets in the first binary + erlang:dist_ctrl_put_data(DHandle, DataA), + erlang:dist_ctrl_put_data(DHandle, DataB), + erlang:dist_ctrl_put_data(DHandle, DataC), + read_application_dist_data( + DHandle, Front0, BufferSize - (3*4+SizeA+SizeB+SizeC), Rear0, Rest); + <<SizeA:32, DataA:SizeA/binary, + SizeB:32, DataB:SizeB/binary, Rest/binary>> -> + %% We have 2 complete packets in the first binary + erlang:dist_ctrl_put_data(DHandle, DataA), + erlang:dist_ctrl_put_data(DHandle, DataB), + read_application_dist_data( + DHandle, Front0, BufferSize - (2*4+SizeA+SizeB), Rear0, Rest); + %% END Optimization + %% + %% Basic one packet code path + <<Size:32, Data:Size/binary, Rest/binary>> -> + %% We have a complete packet in the first binary erlang:dist_ctrl_put_data(DHandle, Data), - read_application_dist_data(Rest, State, DHandle); - _ -> - {no_record, - State#state{ - user_data_buffer = Buffer, - bytes_to_read = undefined}} + read_application_dist_data(DHandle, Front0, BufferSize - (4+Size), Rear0, Rest); + <<Size:32, FirstData/binary>> when 4+Size =< BufferSize -> + %% We have a complete packet in the buffer + %% - fetch the missing content from the buffer front + {Data,Front,Rear} = iovec_from_front(Size - byte_size(FirstData), Front0, Rear0, [FirstData]), + erlang:dist_ctrl_put_data(DHandle, Data), + read_application_dist_data(DHandle, Front, BufferSize - (4+Size), Rear); + <<Bin/binary>> -> + %% In OTP-21 the match context reuse optimization fails if we use Bin0 in recursion, so here we + %% match out the whole binary which will trick the optimization into keeping the match context + %% for the first binary contains complete packet code above + case Bin of + <<_Size:32, _InsufficientData/binary>> -> + %% We have a length field in the first binary but there is not enough data + %% in the buffer to form a complete packet - await more data + {[Bin|Front0],BufferSize,Rear0}; + <<IncompleteLengthField/binary>> when 4 < BufferSize -> + %% We do not have a length field in the first binary but the buffer + %% contains enough data to maybe form a packet + %% - fetch a tiny binary from the buffer front to complete the length field + {LengthField,Front,Rear} = + iovec_from_front(4 - byte_size(IncompleteLengthField), Front0, Rear0, [IncompleteLengthField]), + LengthBin = iolist_to_binary(LengthField), + read_application_dist_data(DHandle, Front, BufferSize, Rear, LengthBin); + <<IncompleteLengthField/binary>> -> + %% We do not have enough data in the buffer to even form a length field - await more data + {[IncompleteLengthField|Front0],BufferSize,Rear0} + end end. -read_application_data( - Buffer0, State, SocketOpts0, RecvFrom, BytesToRead) -> - %% - case get_data(SocketOpts0, BytesToRead, Buffer0) of - {ok, ClientData, Buffer} -> % Send data - #state{static_env = - #static_env{ - socket = Socket, - protocol_cb = Connection, - transport_cb = Transport, - tracker = Tracker}, - connection_env = - #connection_env{user_application = {_Mon, Pid}}} - = State, - SocketOpts = - deliver_app_data( - Connection:pids(State), - Transport, Socket, SocketOpts0, - ClientData, Pid, RecvFrom, Tracker, Connection), - if - SocketOpts#socket_options.active =:= false -> - %% Passive mode, wait for active once or recv - %% Active and empty, get more data - {no_record, - State#state{ - user_data_buffer = Buffer, - start_or_recv_from = undefined, - bytes_to_read = undefined, - socket_options = SocketOpts - }}; - true -> %% We have more data - read_application_data( - Buffer, State, SocketOpts, - undefined, undefined) - end; - {more, Buffer} -> % no reply, we need more data - {no_record, State#state{user_data_buffer = Buffer}}; - {passive, Buffer} -> - {no_record, State#state{user_data_buffer = Buffer}}; - {error,_Reason} -> %% Invalid packet in packet mode - #state{static_env = - #static_env{ - socket = Socket, - protocol_cb = Connection, - transport_cb = Transport, - tracker = Tracker}, - connection_env = - #connection_env{user_application = {_Mon, Pid}}} - = State, - deliver_packet_error( - Connection:pids(State), Transport, Socket, SocketOpts0, - Buffer0, Pid, RecvFrom, Tracker, Connection), - {stop, {shutdown, normal}, State} +iovec_from_front(Size, [], Rear, Acc) -> + iovec_from_front(Size, lists:reverse(Rear), [], Acc); +iovec_from_front(Size, [Bin|Front], Rear, Acc) -> + case Bin of + <<Last:Size/binary>> -> % Just enough + {lists:reverse(Acc, [Last]),Front,Rear}; + <<Last:Size/binary, Rest/binary>> -> % More than enough, split here + {lists:reverse(Acc, [Last]),[Rest|Front],Rear}; + <<_/binary>> -> % Not enough + BinSize = byte_size(Bin), + iovec_from_front(Size - BinSize, Front, Rear, [Bin|Acc]) end. + %%==================================================================== %% Help functions for tls|dtls_connection.erl %%==================================================================== @@ -1268,10 +1391,6 @@ handle_call({set_opts, Opts0}, From, StateName, handle_call(renegotiate, From, StateName, _, _) when StateName =/= connection -> {keep_state_and_data, [{reply, From, {error, already_renegotiating}}]}; -handle_call(get_sslsocket, From, _StateName, State, Connection) -> - SslSocket = Connection:socket(State), - {keep_state_and_data, [{reply, From, SslSocket}]}; - handle_call({prf, Secret, Label, Seed, WantedLength}, From, _, #state{connection_states = ConnectionStates, connection_env = #connection_env{negotiated_version = Version}}, _) -> @@ -2537,7 +2656,7 @@ handle_active_option(false, connection = StateName, To, Reply, State) -> hibernate_after(StateName, State, [{reply, To, Reply}]); handle_active_option(_, connection = StateName0, To, Reply, #state{static_env = #static_env{protocol_cb = Connection}, - user_data_buffer = <<>>} = State0) -> + user_data_buffer = {_,0,_}} = State0) -> case Connection:next_event(StateName0, no_record, State0) of {next_state, StateName, State} -> hibernate_after(StateName, State, [{reply, To, Reply}]); @@ -2546,11 +2665,11 @@ handle_active_option(_, connection = StateName0, To, Reply, #state{static_env = {stop, _, _} = Stop -> Stop end; -handle_active_option(_, StateName, To, Reply, #state{user_data_buffer = <<>>} = State) -> +handle_active_option(_, StateName, To, Reply, #state{user_data_buffer = {_,0,_}} = State) -> %% Active once already set {next_state, StateName, State, [{reply, To, Reply}]}; -%% user_data_buffer =/= <<>> +%% user_data_buffer nonempty handle_active_option(_, StateName0, To, Reply, #state{static_env = #static_env{protocol_cb = Connection}} = State0) -> case read_application_data(<<>>, State0) of @@ -2570,33 +2689,25 @@ handle_active_option(_, StateName0, To, Reply, %% Picks ClientData -get_data(_, _, <<>>) -> - {more, <<>>}; -%% Recv timed out save buffer data until next recv -get_data(#socket_options{active=false}, undefined, Buffer) -> - {passive, Buffer}; -get_data(#socket_options{active=Active, packet=Raw}, BytesToRead, Buffer) +get_data(#socket_options{active=false}, undefined, _Bin) -> + %% Recv timed out save buffer data until next recv + passive; +get_data(#socket_options{active=Active, packet=Raw}, BytesToRead, Bin) when Raw =:= raw; Raw =:= 0 -> %% Raw Mode - if - Active =/= false orelse BytesToRead =:= 0 -> + case Bin of + <<_/binary>> when Active =/= false orelse BytesToRead =:= 0 -> %% Active true or once, or passive mode recv(0) - {ok, Buffer, <<>>}; - byte_size(Buffer) >= BytesToRead -> + {ok, Bin, <<>>}; + <<Data:BytesToRead/binary, Rest/binary>> -> %% Passive Mode, recv(Bytes) - <<Data:BytesToRead/binary, Rest/binary>> = Buffer, - {ok, Data, Rest}; - true -> + {ok, Data, Rest}; + <<_/binary>> -> %% Passive Mode not enough data - {more, Buffer} + {more, BytesToRead} end; -get_data(#socket_options{packet=Type, packet_size=Size}, _, Buffer) -> +get_data(#socket_options{packet=Type, packet_size=Size}, _, Bin) -> PacketOpts = [{packet_size, Size}], - case decode_packet(Type, Buffer, PacketOpts) of - {more, _} -> - {more, Buffer}; - Decoded -> - Decoded - end. + decode_packet(Type, Bin, PacketOpts). decode_packet({http, headers}, Buffer, PacketOpts) -> decode_packet(httph, Buffer, PacketOpts); @@ -2648,7 +2759,7 @@ format_reply(_, _, _,#socket_options{active = false, mode = Mode, packet = Packe {ok, do_format_reply(Mode, Packet, Header, Data)}; format_reply(CPids, Transport, Socket, #socket_options{active = _, mode = Mode, packet = Packet, header = Header}, Data, Tracker, Connection) -> - {ssl, Connection:socket(CPids, Transport, Socket, Connection, Tracker), + {ssl, Connection:socket(CPids, Transport, Socket, Tracker), do_format_reply(Mode, Packet, Header, Data)}. deliver_packet_error(CPids, Transport, Socket, @@ -2660,7 +2771,7 @@ format_packet_error(_, _, _,#socket_options{active = false, mode = Mode}, Data, {error, {invalid_packet, do_format_reply(Mode, raw, 0, Data)}}; format_packet_error(CPids, Transport, Socket, #socket_options{active = _, mode = Mode}, Data, Tracker, Connection) -> - {ssl_error, Connection:socket(CPids, Transport, Socket, Connection, Tracker), + {ssl_error, Connection:socket(CPids, Transport, Socket, Tracker), {invalid_packet, do_format_reply(Mode, raw, 0, Data)}}. do_format_reply(binary, _, N, Data) when N > 0 -> % Header mode @@ -2716,12 +2827,10 @@ alert_user(Pids, Transport, Tracker, Socket, Active, Pid, From, Alert, Role, Con case ssl_alert:reason_code(Alert, Role) of closed -> send_or_reply(Active, Pid, From, - {ssl_closed, Connection:socket(Pids, - Transport, Socket, Connection, Tracker)}); + {ssl_closed, Connection:socket(Pids, Transport, Socket, Tracker)}); ReasonCode -> send_or_reply(Active, Pid, From, - {ssl_error, Connection:socket(Pids, - Transport, Socket, Connection, Tracker), ReasonCode}) + {ssl_error, Connection:socket(Pids, Transport, Socket, Tracker), ReasonCode}) end. log_alert(true, Role, ProtocolName, StateName, #alert{role = Role} = Alert) -> @@ -2793,11 +2902,3 @@ new_emulated([], EmOpts) -> EmOpts; new_emulated(NewEmOpts, _) -> NewEmOpts. - --compile({inline, [bincat/2]}). -bincat(<<>>, B) -> - B; -bincat(A, <<>>) -> - A; -bincat(A, B) -> - <<A/binary, B/binary>>. diff --git a/lib/ssl/src/ssl_connection.hrl b/lib/ssl/src/ssl_connection.hrl index 83013e7fba..b6b23701bb 100644 --- a/lib/ssl/src/ssl_connection.hrl +++ b/lib/ssl/src/ssl_connection.hrl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2013-2018. All Rights Reserved. +%% Copyright Ericsson AB 2013-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -83,7 +83,7 @@ downgrade, terminated = false ::boolean() | closed, negotiated_version :: ssl_record:ssl_version() | 'undefined', - erl_dist_handle = undefined :: erlang:dist_handle() | undefined, + erl_dist_handle = undefined :: erlang:dist_handle() | 'undefined', private_key :: public_key:private_key() | secret_printout() | 'undefined' }). @@ -109,7 +109,7 @@ %% Data shuffling %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% connection_states :: ssl_record:connection_states() | secret_printout(), protocol_buffers :: term() | secret_printout() , %% #protocol_buffers{} from tls_record.hrl or dtls_recor.hr - user_data_buffer :: undefined | binary() | secret_printout(), + user_data_buffer :: undefined | {[binary()],non_neg_integer(),[binary()]} | secret_printout(), bytes_to_read :: undefined | integer(), %% bytes to read in passive mode %% recv and start handling diff --git a/lib/ssl/src/ssl_record.erl b/lib/ssl/src/ssl_record.erl index b9d1320ef3..1a36b2dba8 100644 --- a/lib/ssl/src/ssl_record.erl +++ b/lib/ssl/src/ssl_record.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2013-2018. All Rights Reserved. +%% Copyright Ericsson AB 2013-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -45,14 +45,16 @@ -export([compress/3, uncompress/3, compressions/0]). %% Payload encryption/decryption --export([cipher/4, decipher/4, cipher_aead/4, decipher_aead/5, is_correct_mac/2, nonce_seed/3]). +-export([cipher/4, cipher/5, decipher/4, + cipher_aead/4, cipher_aead/5, decipher_aead/5, + is_correct_mac/2, nonce_seed/3]). -export_type([ssl_version/0, ssl_atom_version/0, connection_states/0, connection_state/0]). -type ssl_version() :: {integer(), integer()}. -type ssl_atom_version() :: tls_record:tls_atom_version(). --type connection_states() :: term(). %% Map --type connection_state() :: term(). %% Map +-type connection_states() :: map(). %% Map +-type connection_state() :: map(). %% Map %%==================================================================== %% Connection state handling @@ -302,27 +304,49 @@ cipher(Version, Fragment, #security_parameters{bulk_cipher_algorithm = BulkCipherAlgo} } = WriteState0, MacHash) -> - + %% {CipherFragment, CipherS1} = ssl_cipher:cipher(BulkCipherAlgo, CipherS0, MacHash, Fragment, Version), {CipherFragment, WriteState0#{cipher_state => CipherS1}}. + +%%-------------------------------------------------------------------- +-spec cipher(ssl_version(), iodata(), #cipher_state{}, MacHash::binary(), #security_parameters{}) -> + {CipherFragment::binary(), #cipher_state{}}. +%% +%% Description: Payload encryption +%%-------------------------------------------------------------------- +cipher(Version, Fragment, CipherS0, MacHash, + #security_parameters{bulk_cipher_algorithm = BulkCipherAlgo}) -> + %% + ssl_cipher:cipher(BulkCipherAlgo, CipherS0, MacHash, Fragment, Version). + %%-------------------------------------------------------------------- -spec cipher_aead(ssl_version(), iodata(), connection_state(), AAD::binary()) -> {CipherFragment::binary(), connection_state()}. %% Description: Payload encryption %% %%-------------------------------------------------------------------- -cipher_aead(Version, Fragment, +cipher_aead(_Version, Fragment, #{cipher_state := CipherS0, security_parameters := #security_parameters{bulk_cipher_algorithm = BulkCipherAlgo} } = WriteState0, AAD) -> {CipherFragment, CipherS1} = - cipher_aead(BulkCipherAlgo, CipherS0, AAD, Fragment, Version), + do_cipher_aead(BulkCipherAlgo, Fragment, CipherS0, AAD), {CipherFragment, WriteState0#{cipher_state => CipherS1}}. %%-------------------------------------------------------------------- +-spec cipher_aead(ssl_version(), iodata(), #cipher_state{}, AAD::binary(), #security_parameters{}) -> + {CipherFragment::binary(), #cipher_state{}}. + +%% Description: Payload encryption +%% %%-------------------------------------------------------------------- +cipher_aead(_Version, Fragment, CipherS, AAD, + #security_parameters{bulk_cipher_algorithm = BulkCipherAlgo}) -> + do_cipher_aead(BulkCipherAlgo, Fragment, CipherS, AAD). + +%%-------------------------------------------------------------------- -spec decipher(ssl_version(), binary(), connection_state(), boolean()) -> {binary(), binary(), connection_state()} | #alert{}. %% @@ -343,9 +367,8 @@ decipher(Version, CipherFragment, Alert end. %%-------------------------------------------------------------------- --spec decipher_aead(ssl_cipher:cipher_enum(), #cipher_state{}, - binary(), binary(), ssl_record:ssl_version()) -> - {binary(), #cipher_state{}} | #alert{}. +-spec decipher_aead(ssl_cipher:cipher_enum(), #cipher_state{}, binary(), binary(), ssl_record:ssl_version()) -> + binary() | #alert{}. %% %% Description: Decrypts the data and checks the associated data (AAD) MAC using %% cipher described by cipher_enum() and updating the cipher state. @@ -357,7 +380,7 @@ decipher_aead(Type, #cipher_state{key = Key} = CipherState, AAD0, CipherFragment {AAD, CipherText, CipherTag} = aead_ciphertext_split(Type, CipherState, CipherFragment, AAD0), case ssl_cipher:aead_decrypt(Type, Key, Nonce, CipherText, CipherTag, AAD) of Content when is_binary(Content) -> - {Content, CipherState}; + Content; _ -> ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC, decryption_failed) end @@ -399,11 +422,13 @@ random() -> Random_28_bytes = ssl_cipher:random_bytes(28), <<?UINT32(Secs_since_1970), Random_28_bytes/binary>>. +-compile({inline, [is_correct_mac/2]}). is_correct_mac(Mac, Mac) -> true; is_correct_mac(_M,_H) -> false. +-compile({inline, [record_protocol_role/1]}). record_protocol_role(client) -> ?CLIENT; record_protocol_role(server) -> @@ -427,13 +452,15 @@ initial_security_params(ConnectionEnd) -> compression_algorithm = ?NULL}, ssl_cipher:security_parameters(?TLS_NULL_WITH_NULL_NULL, SecParams). -cipher_aead(?CHACHA20_POLY1305 = Type, #cipher_state{key=Key} = CipherState, AAD0, Fragment, _Version) -> - AAD = end_additional_data(AAD0, erlang:iolist_size(Fragment)), +-define(end_additional_data(AAD, Len), << (begin(AAD)end)/binary, ?UINT16(begin(Len)end) >>). + +do_cipher_aead(?CHACHA20_POLY1305 = Type, Fragment, #cipher_state{key=Key} = CipherState, AAD0) -> + AAD = ?end_additional_data(AAD0, erlang:iolist_size(Fragment)), Nonce = encrypt_nonce(Type, CipherState), {Content, CipherTag} = ssl_cipher:aead_encrypt(Type, Key, Nonce, Fragment, AAD), {<<Content/binary, CipherTag/binary>>, CipherState}; -cipher_aead(Type, #cipher_state{key=Key, nonce = ExplicitNonce} = CipherState, AAD0, Fragment, _Version) -> - AAD = end_additional_data(AAD0, erlang:iolist_size(Fragment)), +do_cipher_aead(Type, Fragment, #cipher_state{key=Key, nonce = ExplicitNonce} = CipherState, AAD0) -> + AAD = ?end_additional_data(AAD0, erlang:iolist_size(Fragment)), Nonce = encrypt_nonce(Type, CipherState), {Content, CipherTag} = ssl_cipher:aead_encrypt(Type, Key, Nonce, Fragment, AAD), {<<ExplicitNonce:64/integer, Content/binary, CipherTag/binary>>, CipherState#cipher_state{nonce = ExplicitNonce + 1}}. @@ -449,15 +476,12 @@ decrypt_nonce(?CHACHA20_POLY1305, #cipher_state{nonce = Nonce, iv = IV}, _) -> decrypt_nonce(?AES_GCM, #cipher_state{iv = <<Salt:4/bytes, _/binary>>}, <<ExplicitNonce:8/bytes, _/binary>>) -> <<Salt/binary, ExplicitNonce/binary>>. +-compile({inline, [aead_ciphertext_split/4]}). aead_ciphertext_split(?CHACHA20_POLY1305, #cipher_state{tag_len = Len}, CipherTextFragment, AAD) -> - CipherLen = size(CipherTextFragment) - Len, + CipherLen = byte_size(CipherTextFragment) - Len, <<CipherText:CipherLen/bytes, CipherTag:Len/bytes>> = CipherTextFragment, - {end_additional_data(AAD, CipherLen), CipherText, CipherTag}; + {?end_additional_data(AAD, CipherLen), CipherText, CipherTag}; aead_ciphertext_split(?AES_GCM, #cipher_state{tag_len = Len}, CipherTextFragment, AAD) -> - CipherLen = size(CipherTextFragment) - (Len + 8), %% 8 is length of explicit Nonce + CipherLen = byte_size(CipherTextFragment) - (Len + 8), %% 8 is length of explicit Nonce << _:8/bytes, CipherText:CipherLen/bytes, CipherTag:Len/bytes>> = CipherTextFragment, - {end_additional_data(AAD, CipherLen), CipherText, CipherTag}. - -end_additional_data(AAD, Len) -> - <<AAD/binary, ?UINT16(Len)>>. - + {?end_additional_data(AAD, CipherLen), CipherText, CipherTag}. diff --git a/lib/ssl/src/ssl_record.hrl b/lib/ssl/src/ssl_record.hrl index ed007f58d7..a927fba0de 100644 --- a/lib/ssl/src/ssl_record.hrl +++ b/lib/ssl/src/ssl_record.hrl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2007-2016. All Rights Reserved. +%% Copyright Ericsson AB 2007-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -140,6 +140,8 @@ -define(ALERT, 21). -define(HANDSHAKE, 22). -define(APPLICATION_DATA, 23). +-define(KNOWN_RECORD_TYPE(Type), + (is_integer(Type) andalso (20 =< (Type)) andalso ((Type) =< 23))). -define(MAX_PLAIN_TEXT_LENGTH, 16384). -define(MAX_COMPRESSED_LENGTH, (?MAX_PLAIN_TEXT_LENGTH+1024)). -define(MAX_CIPHER_TEXT_LENGTH, (?MAX_PLAIN_TEXT_LENGTH+2048)). diff --git a/lib/ssl/src/tls_connection.erl b/lib/ssl/src/tls_connection.erl index dfae13f6d7..3229004c9d 100644 --- a/lib/ssl/src/tls_connection.erl +++ b/lib/ssl/src/tls_connection.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2007-2018. All Rights Reserved. +%% Copyright Ericsson AB 2007-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -57,11 +57,10 @@ %% Alert and close handling -export([send_alert/2, send_alert_in_connection/2, send_sync_alert/2, - encode_alert/3, close/5, protocol_name/0]). + close/5, protocol_name/0]). %% Data handling --export([encode_data/3, next_record/1, - send/3, socket/5, setopts/3, getopts/3]). +-export([next_record/1, socket/4, setopts/3, getopts/3]). %% gen_statem state functions -export([init/3, error/3, downgrade/3, %% Initiation and take down states @@ -149,18 +148,10 @@ next_record(#state{handshake_env = {no_record, State#state{handshake_env = HsEnv#handshake_env{unprocessed_handshake_events = N-1}}}; next_record(#state{protocol_buffers = - #protocol_buffers{tls_cipher_texts = [#ssl_tls{type = Type}| _] = CipherTexts0} - = Buffers, - connection_states = ConnectionStates0, + #protocol_buffers{tls_cipher_texts = [_|_] = CipherTexts}, + connection_states = ConnectionStates, ssl_options = #ssl_options{padding_check = Check}} = State) -> - case decode_cipher_texts(Type, CipherTexts0, ConnectionStates0, Check, <<>>) of - {#ssl_tls{} = Record, ConnectionStates, CipherTexts} -> - {Record, State#state{protocol_buffers = Buffers#protocol_buffers{tls_cipher_texts = CipherTexts}, - connection_states = ConnectionStates}}; - {#alert{} = Alert, ConnectionStates, CipherTexts} -> - {Alert, State#state{protocol_buffers = Buffers#protocol_buffers{tls_cipher_texts = CipherTexts}, - connection_states = ConnectionStates}} - end; + next_record(State, CipherTexts, ConnectionStates, Check); next_record(#state{protocol_buffers = #protocol_buffers{tls_cipher_texts = []}, protocol_specific = #{active_n_toggle := true, active_n := N} = ProtocolSpec, static_env = #static_env{socket = Socket, @@ -177,16 +168,48 @@ next_record(#state{protocol_buffers = #protocol_buffers{tls_cipher_texts = []}, next_record(State) -> {no_record, State}. +%% Decipher next record and concatenate consecutive ?APPLICATION_DATA records into one +%% +next_record(State, CipherTexts, ConnectionStates, Check) -> + next_record(State, CipherTexts, ConnectionStates, Check, []). +%% +next_record(State, [#ssl_tls{type = ?APPLICATION_DATA} = CT|CipherTexts], ConnectionStates0, Check, Acc) -> + case tls_record:decode_cipher_text(CT, ConnectionStates0, Check) of + {#ssl_tls{fragment = Fragment}, ConnectionStates} -> + next_record(State, CipherTexts, ConnectionStates, Check, [Fragment|Acc]); + #alert{} = Alert -> + Alert + end; +next_record(State, [CT|CipherTexts], ConnectionStates0, Check, []) -> + case tls_record:decode_cipher_text(CT, ConnectionStates0, Check) of + {Record, ConnectionStates} -> + next_record_done(State, CipherTexts, ConnectionStates, Record); + #alert{} = Alert -> + Alert + end; +next_record(State, CipherTexts, ConnectionStates, _Check, Acc) -> + %% Not ?APPLICATION_DATA but we have a nonempty Acc + %% -> build an ?APPLICATION_DATA record with the accumulated fragments + next_record_done(State, CipherTexts, ConnectionStates, + #ssl_tls{type = ?APPLICATION_DATA, fragment = iolist_to_binary(lists:reverse(Acc))}). + +next_record_done(#state{protocol_buffers = Buffers} = State, CipherTexts, ConnectionStates, Record) -> + {Record, + State#state{protocol_buffers = Buffers#protocol_buffers{tls_cipher_texts = CipherTexts}, + connection_states = ConnectionStates}}. + + next_event(StateName, Record, State) -> next_event(StateName, Record, State, []). +%% next_event(StateName, no_record, State0, Actions) -> case next_record(State0) of {no_record, State} -> {next_state, StateName, State, Actions}; {#ssl_tls{} = Record, State} -> {next_state, StateName, State, [{next_event, internal, {protocol_record, Record}} | Actions]}; - {#alert{} = Alert, State} -> - {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} + #alert{} = Alert -> + {next_state, StateName, State0, [{next_event, internal, Alert} | Actions]} end; next_event(StateName, Record, State, Actions) -> case Record of @@ -198,21 +221,6 @@ next_event(StateName, Record, State, Actions) -> {next_state, StateName, State, [{next_event, internal, Alert} | Actions]} end. -decode_cipher_texts(Type, [] = CipherTexts, ConnectionStates, _, Acc) -> - {#ssl_tls{type = Type, fragment = Acc}, ConnectionStates, CipherTexts}; -decode_cipher_texts(Type, - [#ssl_tls{type = Type} = CT | CipherTexts], ConnectionStates0, Check, Acc) -> - case tls_record:decode_cipher_text(CT, ConnectionStates0, Check) of - {#ssl_tls{type = ?APPLICATION_DATA, fragment = Plain}, ConnectionStates} -> - decode_cipher_texts(Type, CipherTexts, - ConnectionStates, Check, <<Acc/binary, Plain/binary>>); - {#ssl_tls{type = Type, fragment = Plain}, ConnectionStates} -> - {#ssl_tls{type = Type, fragment = Plain}, ConnectionStates, CipherTexts}; - #alert{} = Alert -> - {Alert, ConnectionStates0, CipherTexts} - end; -decode_cipher_texts(Type, CipherTexts, ConnectionStates, _, Acc) -> - {#ssl_tls{type = Type, fragment = Acc}, ConnectionStates, CipherTexts}. %%% TLS record protocol level application data messages @@ -303,7 +311,7 @@ renegotiate(#state{static_env = #static_env{role = server, Hs0 = ssl_handshake:init_handshake_history(), {BinMsg, ConnectionStates} = tls_record:encode_handshake(Frag, Version, ConnectionStates0), - send(Transport, Socket, BinMsg), + tls_socket:send(Transport, Socket, BinMsg), State = State0#state{connection_states = ConnectionStates, handshake_env = HsEnv#handshake_env{tls_handshake_history = Hs0}}, @@ -325,7 +333,7 @@ queue_handshake(Handshake, #state{handshake_env = #handshake_env{tls_handshake_h send_handshake_flight(#state{static_env = #static_env{socket = Socket, transport_cb = Transport}, flight_buffer = Flight} = State0) -> - send(Transport, Socket, Flight), + tls_socket:send(Transport, Socket, Flight), {State0#state{flight_buffer = []}, []}. queue_change_cipher(Msg, #state{connection_env = #connection_env{negotiated_version = Version}, @@ -378,7 +386,7 @@ send_alert(Alert, #state{static_env = #static_env{socket = Socket, connection_states = ConnectionStates0} = StateData0) -> {BinMsg, ConnectionStates} = encode_alert(Alert, Version, ConnectionStates0), - send(Transport, Socket, BinMsg), + tls_socket:send(Transport, Socket, BinMsg), StateData0#state{connection_states = ConnectionStates}. %% If an ALERT sent in the connection state, should cause the TLS @@ -432,14 +440,9 @@ protocol_name() -> %%==================================================================== %% Data handling %%==================================================================== -encode_data(Data, Version, ConnectionStates0)-> - tls_record:encode_data(Data, Version, ConnectionStates0). -send(Transport, Socket, Data) -> - tls_socket:send(Transport, Socket, Data). - -socket(Pids, Transport, Socket, Connection, Tracker) -> - tls_socket:socket(Pids, Transport, Socket, Connection, Tracker). +socket(Pids, Transport, Socket, Tracker) -> + tls_socket:socket(Pids, Transport, Socket, ?MODULE, Tracker). setopts(Transport, Socket, Other) -> tls_socket:setopts(Transport, Socket, Other). @@ -478,7 +481,7 @@ init({call, From}, {start, Timeout}, Handshake0 = ssl_handshake:init_handshake_history(), {BinMsg, ConnectionStates, Handshake} = encode_handshake(Hello, HelloVersion, ConnectionStates0, Handshake0), - send(Transport, Socket, BinMsg), + tls_socket:send(Transport, Socket, BinMsg), State = State0#state{connection_states = ConnectionStates, connection_env = CEnv#connection_env{negotiated_version = Version}, %% Requested version session = @@ -703,12 +706,11 @@ connection(internal, #client_hello{} = Hello, }, [{next_event, internal, Hello}]); connection(internal, #client_hello{}, - #state{static_env = #static_env{role = server, - protocol_cb = Connection}, + #state{static_env = #static_env{role = server}, handshake_env = #handshake_env{allow_renegotiate = false}} = State0) -> Alert = ?ALERT_REC(?WARNING, ?NO_RENEGOTIATION), send_alert_in_connection(Alert, State0), - State = Connection:reinit_handshake_data(State0), + State = reinit_handshake_data(State0), next_event(?FUNCTION_NAME, no_record, State); connection(Type, Event, State) -> @@ -807,7 +809,7 @@ initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trac session = #session{is_resumable = new}, connection_states = ConnectionStates, protocol_buffers = #protocol_buffers{}, - user_data_buffer = <<>>, + user_data_buffer = {[],0,[]}, start_or_recv_from = undefined, flight_buffer = [], protocol_specific = #{sender => Sender, @@ -819,7 +821,6 @@ initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trac initialize_tls_sender(#state{static_env = #static_env{ role = Role, transport_cb = Transport, - protocol_cb = Connection, socket = Socket, tracker = Tracker }, @@ -833,19 +834,23 @@ initialize_tls_sender(#state{static_env = #static_env{ socket => Socket, socket_options => SockOpts, tracker => Tracker, - protocol_cb => Connection, transport_cb => Transport, negotiated_version => Version, renegotiate_at => RenegotiateAt}, tls_sender:initialize(Sender, Init). - -next_tls_record(Data, StateName, #state{protocol_buffers = - #protocol_buffers{tls_record_buffer = Buf0, - tls_cipher_texts = CT0} = Buffers} - = State0) -> - case tls_record:get_tls_records(Data, - acceptable_record_versions(StateName, State0), - Buf0) of + +next_tls_record(Data, StateName, + #state{protocol_buffers = + #protocol_buffers{tls_record_buffer = Buf0, + tls_cipher_texts = CT0} = Buffers} = State0) -> + Versions = + case StateName of + hello -> + [tls_record:protocol_version(Vsn) || Vsn <- ?ALL_AVAILABLE_VERSIONS]; + _ -> + State0#state.connection_env#connection_env.negotiated_version + end, + case tls_record:get_tls_records(Data, Versions, Buf0) of {Records, Buf1} -> CT1 = CT0 ++ Records, next_record(State0#state{protocol_buffers = @@ -856,11 +861,6 @@ next_tls_record(Data, StateName, #state{protocol_buffers = end. -acceptable_record_versions(StateName, #state{connection_env = #connection_env{negotiated_version = Version}}) when StateName =/= hello-> - Version; -acceptable_record_versions(hello, _) -> - [tls_record:protocol_version(Vsn) || Vsn <- ?ALL_AVAILABLE_VERSIONS]. - handle_record_alert(Alert, _) -> Alert. @@ -890,7 +890,7 @@ handle_info({CloseTag, Socket}, StateName, connection_env = #connection_env{negotiated_version = Version}, socket_options = #socket_options{active = Active}, protocol_buffers = #protocol_buffers{tls_cipher_texts = CTs}, - user_data_buffer = Buffer, + user_data_buffer = {_,BufferSize,_}, protocol_specific = PS} = State) -> %% Note that as of TLS 1.1, @@ -898,7 +898,7 @@ handle_info({CloseTag, Socket}, StateName, %% session not be resumed. This is a change from TLS 1.0 to conform %% with widespread implementation practice. - case (Active == false) andalso ((CTs =/= []) or (Buffer =/= <<>>)) of + case (Active == false) andalso ((CTs =/= []) or (BufferSize =/= 0)) of false -> case Version of {1, N} when N >= 1 -> @@ -933,9 +933,9 @@ handle_alerts(_, {stop, _, _} = Stop) -> handle_alerts([#alert{level = ?WARNING, description = ?CLOSE_NOTIFY} | _Alerts], {next_state, connection = StateName, #state{connection_env = CEnv, socket_options = #socket_options{active = false}, - user_data_buffer = Buffer, + user_data_buffer = {_,BufferSize,_}, protocol_buffers = #protocol_buffers{tls_cipher_texts = CTs}} = - State}) when (Buffer =/= <<>>) orelse + State}) when (BufferSize =/= 0) orelse (CTs =/= []) -> {next_state, StateName, State#state{connection_env = CEnv#connection_env{terminated = true}}}; handle_alerts([Alert | Alerts], {next_state, StateName, State}) -> diff --git a/lib/ssl/src/tls_record.erl b/lib/ssl/src/tls_record.erl index 1776ec2627..b456197398 100644 --- a/lib/ssl/src/tls_record.erl +++ b/lib/ssl/src/tls_record.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2007-2018. All Rights Reserved. +%% Copyright Ericsson AB 2007-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -75,15 +75,23 @@ init_connection_states(Role, BeastMitigation) -> pending_write => Pending}. %%-------------------------------------------------------------------- --spec get_tls_records(binary(), [tls_version()] | tls_version(), binary()) -> {[binary()], binary()} | #alert{}. +-spec get_tls_records( + binary(), [tls_version()] | tls_version(), + Buffer0 :: binary() | {'undefined' | #ssl_tls{}, {[binary()],non_neg_integer(),[binary()]}}) -> + {Records :: [#ssl_tls{}], + Buffer :: {'undefined' | #ssl_tls{}, {[binary()],non_neg_integer(),[binary()]}}} | + #alert{}. %% %% and returns it as a list of tls_compressed binaries also returns leftover %% Description: Given old buffer and new data from TCP, packs up a records %% data %%-------------------------------------------------------------------- -get_tls_records(Data, Version, Buffer) -> - get_tls_records_aux(Version, <<Buffer/binary, Data/binary>>, []). - + +get_tls_records(Data, Versions, Buffer) when is_binary(Buffer) -> + parse_tls_records(Versions, {[Data],byte_size(Data),[]}, undefined); +get_tls_records(Data, Versions, {Hdr, {Front,Size,Rear}}) -> + parse_tls_records(Versions, {Front,Size + byte_size(Data),[Data|Rear]}, Hdr). + %%==================================================================== %% Encoding %%==================================================================== @@ -102,8 +110,8 @@ encode_handshake(Frag, Version, ConnectionStates) -> case iolist_size(Frag) of N when N > ?MAX_PLAIN_TEXT_LENGTH -> - Data = split_bin(iolist_to_binary(Frag), Version, BCA, BeastMitigation), - encode_iolist(?HANDSHAKE, Data, Version, ConnectionStates); + Data = split_iovec(erlang:iolist_to_iovec(Frag), Version, BCA, BeastMitigation), + encode_fragments(?HANDSHAKE, Version, Data, ConnectionStates); _ -> encode_plain_text(?HANDSHAKE, Version, Frag, ConnectionStates) end. @@ -129,18 +137,18 @@ encode_change_cipher_spec(Version, ConnectionStates) -> encode_plain_text(?CHANGE_CIPHER_SPEC, Version, ?byte(?CHANGE_CIPHER_SPEC_PROTO), ConnectionStates). %%-------------------------------------------------------------------- --spec encode_data(binary(), tls_version(), ssl_record:connection_states()) -> - {iolist(), ssl_record:connection_states()}. +-spec encode_data([binary()], tls_version(), ssl_record:connection_states()) -> + {[[binary()]], ssl_record:connection_states()}. %% %% Description: Encodes data to send on the ssl-socket. %%-------------------------------------------------------------------- -encode_data(Frag, Version, +encode_data(Data, Version, #{current_write := #{beast_mitigation := BeastMitigation, security_parameters := #security_parameters{bulk_cipher_algorithm = BCA}}} = ConnectionStates) -> - Data = split_bin(Frag, Version, BCA, BeastMitigation), - encode_iolist(?APPLICATION_DATA, Data, Version, ConnectionStates). + Fragments = split_iovec(Data, Version, BCA, BeastMitigation), + encode_fragments(?APPLICATION_DATA, Version, Fragments, ConnectionStates). %%==================================================================== %% Decoding @@ -152,57 +160,59 @@ encode_data(Frag, Version, %% %% Description: Decode cipher text %%-------------------------------------------------------------------- -decode_cipher_text(#ssl_tls{type = Type, version = Version, - fragment = CipherFragment} = CipherText, +decode_cipher_text(CipherText, #{current_read := - #{compression_state := CompressionS0, - sequence_number := Seq, - cipher_state := CipherS0, + #{sequence_number := Seq, security_parameters := - #security_parameters{ - cipher_type = ?AEAD, - bulk_cipher_algorithm = - BulkCipherAlgo, - compression_algorithm = CompAlg} - } = ReadState0} = ConnnectionStates0, _) -> - AAD = start_additional_data(Type, Version, ReadState0), - CipherS1 = ssl_record:nonce_seed(BulkCipherAlgo, <<?UINT64(Seq)>>, CipherS0), - case ssl_record:decipher_aead(BulkCipherAlgo, CipherS1, AAD, CipherFragment, Version) of - {PlainFragment, CipherState} -> - {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, - PlainFragment, CompressionS0), - ConnnectionStates = ConnnectionStates0#{ + #security_parameters{cipher_type = ?AEAD, + bulk_cipher_algorithm = BulkCipherAlgo}, + cipher_state := CipherS0 + } + } = ConnectionStates0, _) -> + SeqBin = <<?UINT64(Seq)>>, + #ssl_tls{type = Type, version = {MajVer,MinVer} = Version, fragment = Fragment} = CipherText, + StartAdditionalData = <<SeqBin/binary, ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>, + CipherS = ssl_record:nonce_seed(BulkCipherAlgo, SeqBin, CipherS0), + case ssl_record:decipher_aead( + BulkCipherAlgo, CipherS, StartAdditionalData, Fragment, Version) + of + PlainFragment when is_binary(PlainFragment) -> + #{current_read := + #{security_parameters := SecParams, + compression_state := CompressionS0} = ReadState0} = ConnectionStates0, + {Plain, CompressionS} = ssl_record:uncompress(SecParams#security_parameters.compression_algorithm, + PlainFragment, CompressionS0), + ConnectionStates = ConnectionStates0#{ current_read => ReadState0#{ - cipher_state => CipherState, + cipher_state => CipherS, sequence_number => Seq + 1, - compression_state => CompressionS1}}, - {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; + compression_state => CompressionS}}, + {CipherText#ssl_tls{fragment = Plain}, ConnectionStates}; #alert{} = Alert -> Alert end; -decode_cipher_text(#ssl_tls{type = Type, version = Version, +decode_cipher_text(#ssl_tls{version = Version, fragment = CipherFragment} = CipherText, - #{current_read := - #{compression_state := CompressionS0, - sequence_number := Seq, - security_parameters := - #security_parameters{compression_algorithm = CompAlg} - } = ReadState0} = ConnnectionStates0, PaddingCheck) -> + #{current_read := ReadState0} = ConnnectionStates0, PaddingCheck) -> case ssl_record:decipher(Version, CipherFragment, ReadState0, PaddingCheck) of {PlainFragment, Mac, ReadState1} -> - MacHash = ssl_cipher:calc_mac_hash(Type, Version, PlainFragment, ReadState1), + MacHash = ssl_cipher:calc_mac_hash(CipherText#ssl_tls.type, Version, PlainFragment, ReadState1), case ssl_record:is_correct_mac(Mac, MacHash) of true -> + #{sequence_number := Seq, + compression_state := CompressionS0, + security_parameters := + #security_parameters{compression_algorithm = CompAlg}} = ReadState0, {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, PlainFragment, CompressionS0), - ConnnectionStates = ConnnectionStates0#{ - current_read => ReadState1#{ - sequence_number => Seq + 1, - compression_state => CompressionS1}}, + ConnnectionStates = + ConnnectionStates0#{current_read => + ReadState1#{sequence_number => Seq + 1, + compression_state => CompressionS1}}, {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; false -> - ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) + ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) end; #alert{} = Alert -> Alert @@ -384,124 +394,222 @@ initial_connection_state(ConnectionEnd, BeastMitigation) -> server_verify_data => undefined }. -get_tls_records_aux({MajVer, MinVer} = Version, <<?BYTE(Type),?BYTE(MajVer),?BYTE(MinVer), - ?UINT16(Length), Data:Length/binary, Rest/binary>>, - Acc) when Type == ?APPLICATION_DATA; - Type == ?HANDSHAKE; - Type == ?ALERT; - Type == ?CHANGE_CIPHER_SPEC -> - get_tls_records_aux(Version, Rest, [#ssl_tls{type = Type, - version = Version, - fragment = Data} | Acc]); -get_tls_records_aux(Versions, <<?BYTE(Type),?BYTE(MajVer),?BYTE(MinVer), - ?UINT16(Length), Data:Length/binary, Rest/binary>>, - Acc) when is_list(Versions) andalso - ((Type == ?APPLICATION_DATA) - orelse - (Type == ?HANDSHAKE) - orelse - (Type == ?ALERT) - orelse - (Type == ?CHANGE_CIPHER_SPEC)) -> - case is_acceptable_version({MajVer, MinVer}, Versions) of + +parse_tls_records(Versions, Q, undefined) -> + decode_tls_records(Versions, Q, [], undefined, undefined, undefined); +parse_tls_records(Versions, Q, #ssl_tls{type = Type, version = Version, fragment = Length}) -> + decode_tls_records(Versions, Q, [], Type, Version, Length). + +%% Generic code path +decode_tls_records(Versions, {_,Size,_} = Q0, Acc, undefined, _Version, _Length) -> + if + 5 =< Size -> + {<<?BYTE(Type),?BYTE(MajVer),?BYTE(MinVer), ?UINT16(Length)>>, Q} = binary_from_front(5, Q0), + validate_tls_records_type(Versions, Q, Acc, Type, {MajVer,MinVer}, Length); + 3 =< Size -> + {<<?BYTE(Type),?BYTE(MajVer),?BYTE(MinVer)>>, Q} = binary_from_front(3, Q0), + validate_tls_records_type(Versions, Q, Acc, Type, {MajVer,MinVer}, undefined); + 1 =< Size -> + {<<?BYTE(Type)>>, Q} = binary_from_front(1, Q0), + validate_tls_records_type(Versions, Q, Acc, Type, undefined, undefined); + true -> + validate_tls_records_type(Versions, Q0, Acc, undefined, undefined, undefined) + end; +decode_tls_records(Versions, {_,Size,_} = Q0, Acc, Type, undefined, _Length) -> + if + 4 =< Size -> + {<<?BYTE(MajVer),?BYTE(MinVer), ?UINT16(Length)>>, Q} = binary_from_front(4, Q0), + validate_tls_record_version(Versions, Q, Acc, Type, {MajVer,MinVer}, Length); + 2 =< Size -> + {<<?BYTE(MajVer),?BYTE(MinVer)>>, Q} = binary_from_front(2, Q0), + validate_tls_record_version(Versions, Q, Acc, Type, {MajVer,MinVer}, undefined); + true -> + validate_tls_record_version(Versions, Q0, Acc, Type, undefined, undefined) + end; +decode_tls_records(Versions, {_,Size,_} = Q0, Acc, Type, Version, undefined) -> + if + 2 =< Size -> + {<<?UINT16(Length)>>, Q} = binary_from_front(2, Q0), + validate_tls_record_length(Versions, Q, Acc, Type, Version, Length); + true -> + validate_tls_record_length(Versions, Q0, Acc, Type, Version, undefined) + end; +decode_tls_records(Versions, Q, Acc, Type, Version, Length) -> + validate_tls_record_length(Versions, Q, Acc, Type, Version, Length). + +validate_tls_records_type(_Versions, Q, Acc, undefined, _Version, _Length) -> + {lists:reverse(Acc), + {undefined, Q}}; +validate_tls_records_type(Versions, Q, Acc, Type, Version, Length) -> + if + ?KNOWN_RECORD_TYPE(Type) -> + validate_tls_record_version(Versions, Q, Acc, Type, Version, Length); + true -> + %% Not ?KNOWN_RECORD_TYPE(Type) + ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE) + end. + +validate_tls_record_version(_Versions, Q, Acc, Type, undefined, _Length) -> + {lists:reverse(Acc), + {#ssl_tls{type = Type, version = undefined, fragment = undefined}, Q}}; +validate_tls_record_version(Versions, Q, Acc, Type, Version, Length) -> + if + is_list(Versions) -> + case is_acceptable_version(Version, Versions) of + true -> + validate_tls_record_length(Versions, Q, Acc, Type, Version, Length); + false -> + ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) + end; + Version =:= Versions -> + %% Exact version match + validate_tls_record_length(Versions, Q, Acc, Type, Version, Length); true -> - get_tls_records_aux(Versions, Rest, [#ssl_tls{type = Type, - version = {MajVer, MinVer}, - fragment = Data} | Acc]); - false -> ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC) + end. + +validate_tls_record_length(_Versions, Q, Acc, Type, Version, undefined) -> + {lists:reverse(Acc), + {#ssl_tls{type = Type, version = Version, fragment = undefined}, Q}}; +validate_tls_record_length(Versions, {_,Size0,_} = Q0, Acc, Type, Version, Length) -> + if + Length =< ?MAX_CIPHER_TEXT_LENGTH -> + if + Length =< Size0 -> + %% Complete record + {Fragment, Q} = binary_from_front(Length, Q0), + Record = #ssl_tls{type = Type, version = Version, fragment = Fragment}, + decode_tls_records(Versions, Q, [Record|Acc], undefined, undefined, undefined); + true -> + {lists:reverse(Acc), + {#ssl_tls{type = Type, version = Version, fragment = Length}, Q0}} + end; + true -> + ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW) + end. + + +binary_from_front(SplitSize, {Front,Size,Rear}) -> + binary_from_front(SplitSize, Front, Size, Rear, []). +%% +binary_from_front(SplitSize, [], Size, [_] = Rear, Acc) -> + %% Optimize a simple case + binary_from_front(SplitSize, Rear, Size, [], Acc); +binary_from_front(SplitSize, [], Size, Rear, Acc) -> + binary_from_front(SplitSize, lists:reverse(Rear), Size, [], Acc); +binary_from_front(SplitSize, [Bin|Front], Size, Rear, []) -> + %% Optimize a frequent case + BinSize = byte_size(Bin), + if + SplitSize < BinSize -> + {RetBin, Rest} = erlang:split_binary(Bin, SplitSize), + {RetBin, {[Rest|Front],Size - SplitSize,Rear}}; + BinSize < SplitSize -> + binary_from_front(SplitSize - BinSize, Front, Size, Rear, [Bin]); + true -> % Perfect fit + {Bin, {Front,Size - SplitSize,Rear}} end; -get_tls_records_aux(_, <<?BYTE(Type),?BYTE(_MajVer),?BYTE(_MinVer), - ?UINT16(Length), _:Length/binary, _Rest/binary>>, - _) when Type == ?APPLICATION_DATA; - Type == ?HANDSHAKE; - Type == ?ALERT; - Type == ?CHANGE_CIPHER_SPEC -> - ?ALERT_REC(?FATAL, ?BAD_RECORD_MAC); -get_tls_records_aux(_, <<0:1, _CT:7, ?BYTE(_MajVer), ?BYTE(_MinVer), - ?UINT16(Length), _/binary>>, - _Acc) when Length > ?MAX_CIPHER_TEXT_LENGTH -> - ?ALERT_REC(?FATAL, ?RECORD_OVERFLOW); -get_tls_records_aux(_, Data, Acc) -> - case size(Data) =< ?MAX_CIPHER_TEXT_LENGTH + ?INITIAL_BYTES of - true -> - {lists:reverse(Acc), Data}; - false -> - ?ALERT_REC(?FATAL, ?UNEXPECTED_MESSAGE) +binary_from_front(SplitSize, [Bin|Front], Size, Rear, Acc) -> + BinSize = byte_size(Bin), + if + SplitSize < BinSize -> + {Last, Rest} = erlang:split_binary(Bin, SplitSize), + RetBin = iolist_to_binary(lists:reverse(Acc, [Last])), + {RetBin, {[Rest|Front],Size - byte_size(RetBin),Rear}}; + BinSize < SplitSize -> + binary_from_front(SplitSize - BinSize, Front, Size, Rear, [Bin|Acc]); + true -> % Perfect fit + RetBin = iolist_to_binary(lists:reverse(Acc, [Bin])), + {RetBin, {Front,Size - byte_size(RetBin),Rear}} end. + +%%-------------------------------------------------------------------- +encode_plain_text(Type, Version, Data, ConnectionStates0) -> + {[CipherText],ConnectionStates} = encode_fragments(Type, Version, [Data], ConnectionStates0), + {CipherText,ConnectionStates}. %%-------------------------------------------------------------------- -encode_plain_text(Type, Version, Data, #{current_write := Write0} = ConnectionStates) -> - {CipherFragment, Write1} = do_encode_plain_text(Type, Version, Data, Write0), - {CipherText, Write} = encode_tls_cipher_text(Type, Version, CipherFragment, Write1), - {CipherText, ConnectionStates#{current_write => Write}}. - -encode_tls_cipher_text(Type, {MajVer, MinVer}, Fragment, #{sequence_number := Seq} = Write) -> - Length = erlang:iolist_size(Fragment), - {[<<?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer), ?UINT16(Length)>>, Fragment], - Write#{sequence_number => Seq +1}}. - -encode_iolist(Type, Data, Version, ConnectionStates0) -> - {ConnectionStates, EncodedMsg} = - lists:foldl(fun(Text, {CS0, Encoded}) -> - {Enc, CS1} = - encode_plain_text(Type, Version, Text, CS0), - {CS1, [Enc | Encoded]} - end, {ConnectionStates0, []}, Data), - {lists:reverse(EncodedMsg), ConnectionStates}. -%%-------------------------------------------------------------------- -do_encode_plain_text(Type, Version, Data, #{compression_state := CompS0, - cipher_state := CipherS0, - sequence_number := Seq, - security_parameters := - #security_parameters{ - cipher_type = ?AEAD, - bulk_cipher_algorithm = BCAlg, - compression_algorithm = CompAlg} - } = WriteState0) -> - {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), - CipherS = ssl_record:nonce_seed(BCAlg, <<?UINT64(Seq)>>, CipherS0), - WriteState = WriteState0#{compression_state => CompS1, - cipher_state => CipherS}, - AAD = start_additional_data(Type, Version, WriteState), - ssl_record:cipher_aead(Version, Comp, WriteState, AAD); -do_encode_plain_text(Type, Version, Data, #{compression_state := CompS0, - security_parameters := - #security_parameters{compression_algorithm = CompAlg} - }= WriteState0) -> - {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), - WriteState1 = WriteState0#{compression_state => CompS1}, - MacHash = ssl_cipher:calc_mac_hash(Type, Version, Comp, WriteState1), - ssl_record:cipher(Version, Comp, WriteState1, MacHash); -do_encode_plain_text(_,_,_,CS) -> +encode_fragments(Type, Version, Data, + #{current_write := #{compression_state := CompS, + cipher_state := CipherS, + sequence_number := Seq}} = ConnectionStates) -> + encode_fragments(Type, Version, Data, ConnectionStates, CompS, CipherS, Seq, []). +%% +encode_fragments(_Type, _Version, [], #{current_write := WriteS} = CS, + CompS, CipherS, Seq, CipherFragments) -> + {lists:reverse(CipherFragments), + CS#{current_write := WriteS#{compression_state := CompS, + cipher_state := CipherS, + sequence_number := Seq}}}; +encode_fragments(Type, Version, [Text|Data], + #{current_write := #{security_parameters := + #security_parameters{cipher_type = ?AEAD, + bulk_cipher_algorithm = BCAlg, + compression_algorithm = CompAlg} = SecPars}} = CS, + CompS0, CipherS0, Seq, CipherFragments) -> + {CompText, CompS} = ssl_record:compress(CompAlg, Text, CompS0), + SeqBin = <<?UINT64(Seq)>>, + CipherS1 = ssl_record:nonce_seed(BCAlg, SeqBin, CipherS0), + {MajVer, MinVer} = Version, + VersionBin = <<?BYTE(MajVer), ?BYTE(MinVer)>>, + StartAdditionalData = <<SeqBin/binary, ?BYTE(Type), VersionBin/binary>>, + {CipherFragment,CipherS} = ssl_record:cipher_aead(Version, CompText, CipherS1, StartAdditionalData, SecPars), + Length = byte_size(CipherFragment), + CipherHeader = <<?BYTE(Type), VersionBin/binary, ?UINT16(Length)>>, + encode_fragments(Type, Version, Data, CS, CompS, CipherS, Seq + 1, + [[CipherHeader, CipherFragment] | CipherFragments]); +encode_fragments(Type, Version, [Text|Data], + #{current_write := #{security_parameters := + #security_parameters{compression_algorithm = CompAlg, + mac_algorithm = MacAlgorithm} = SecPars, + mac_secret := MacSecret}} = CS, + CompS0, CipherS0, Seq, CipherFragments) -> + {CompText, CompS} = ssl_record:compress(CompAlg, Text, CompS0), + MacHash = ssl_cipher:calc_mac_hash(Type, Version, CompText, MacAlgorithm, MacSecret, Seq), + {CipherFragment,CipherS} = ssl_record:cipher(Version, CompText, CipherS0, MacHash, SecPars), + Length = byte_size(CipherFragment), + {MajVer, MinVer} = Version, + CipherHeader = <<?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer), ?UINT16(Length)>>, + encode_fragments(Type, Version, Data, CS, CompS, CipherS, Seq + 1, + [[CipherHeader, CipherFragment] | CipherFragments]); +encode_fragments(_Type, _Version, _Data, CS, _CompS, _CipherS, _Seq, _CipherFragments) -> exit({cs, CS}). %%-------------------------------------------------------------------- -start_additional_data(Type, {MajVer, MinVer}, - #{sequence_number := SeqNo}) -> - <<?UINT64(SeqNo), ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. %% 1/n-1 splitting countermeasure Rizzo/Duong-Beast, RC4 chiphers are %% not vulnerable to this attack. -split_bin(<<FirstByte:8, Rest/binary>>, Version, BCA, one_n_minus_one) when - BCA =/= ?RC4 andalso ({3, 1} == Version orelse - {3, 0} == Version) -> - [[FirstByte]|do_split_bin(Rest)]; +split_iovec([<<FirstByte:8, Rest/binary>>|Data], Version, BCA, one_n_minus_one) + when (BCA =/= ?RC4) andalso ({3, 1} == Version orelse + {3, 0} == Version) -> + [[FirstByte]|split_iovec([Rest|Data])]; %% 0/n splitting countermeasure for clients that are incompatible with 1/n-1 %% splitting. -split_bin(Bin, Version, BCA, zero_n) when - BCA =/= ?RC4 andalso ({3, 1} == Version orelse - {3, 0} == Version) -> - [<<>>|do_split_bin(Bin)]; -split_bin(Bin, _, _, _) -> - do_split_bin(Bin). - -do_split_bin(<<>>) -> []; -do_split_bin(Bin) -> - case Bin of - <<Chunk:?MAX_PLAIN_TEXT_LENGTH/binary, Rest/binary>> -> - [Chunk|do_split_bin(Rest)]; - _ -> - [Bin] - end. +split_iovec(Data, Version, BCA, zero_n) + when (BCA =/= ?RC4) andalso ({3, 1} == Version orelse + {3, 0} == Version) -> + [<<>>|split_iovec(Data)]; +split_iovec(Data, _Version, _BCA, _BeatMitigation) -> + split_iovec(Data). + +split_iovec([]) -> + []; +split_iovec(Data) -> + {Part,Rest} = split_iovec(Data, ?MAX_PLAIN_TEXT_LENGTH, []), + [Part|split_iovec(Rest)]. +%% +split_iovec([Bin|Data], SplitSize, Acc) -> + BinSize = byte_size(Bin), + if + SplitSize < BinSize -> + {Last, Rest} = erlang:split_binary(Bin, SplitSize), + {lists:reverse(Acc, [Last]), [Rest|Data]}; + BinSize < SplitSize -> + split_iovec(Data, SplitSize - BinSize, [Bin|Acc]); + true -> % Perfect match + {lists:reverse(Acc, [Bin]), Data} + end; +split_iovec([], _SplitSize, Acc) -> + {lists:reverse(Acc),[]}. + %%-------------------------------------------------------------------- lowest_list_protocol_version(Ver, []) -> Ver; diff --git a/lib/ssl/src/tls_sender.erl b/lib/ssl/src/tls_sender.erl index 11fcc6def0..c07b7f49cd 100644 --- a/lib/ssl/src/tls_sender.erl +++ b/lib/ssl/src/tls_sender.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2018-2018. All Rights Reserved. +%% Copyright Ericsson AB 2018-2019. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -38,19 +38,23 @@ -define(SERVER, ?MODULE). --record(data, {connection_pid, - connection_states = #{}, - role, - socket, - socket_options, - tracker, - protocol_cb, - transport_cb, - negotiated_version, - renegotiate_at, - connection_monitor, - dist_handle - }). +-record(static, + {connection_pid, + role, + socket, + socket_options, + tracker, + transport_cb, + negotiated_version, + renegotiate_at, + connection_monitor, + dist_handle + }). + +-record(data, + {static = #static{}, + connection_states = #{} + }). %%%=================================================================== %%% API @@ -171,6 +175,10 @@ dist_tls_socket(Pid) -> callback_mode() -> state_functions. + +-define(HANDLE_COMMON, + ?FUNCTION_NAME(Type, Msg, StateData) -> + handle_common(Type, Msg, StateData)). %%-------------------------------------------------------------------- -spec init(Args :: term()) -> gen_statem:init_result(atom()). @@ -192,39 +200,35 @@ init({call, From}, {Pid, #{current_write := WriteState, socket := Socket, socket_options := SockOpts, tracker := Tracker, - protocol_cb := Connection, transport_cb := Transport, negotiated_version := Version, renegotiate_at := RenegotiateAt}}, - #data{connection_states = ConnectionStates} = StateData0) -> + #data{connection_states = ConnectionStates, static = Static0} = StateData0) -> Monitor = erlang:monitor(process, Pid), StateData = - StateData0#data{connection_pid = Pid, - connection_monitor = Monitor, - connection_states = - ConnectionStates#{current_write => WriteState}, - role = Role, - socket = Socket, - socket_options = SockOpts, - tracker = Tracker, - protocol_cb = Connection, - transport_cb = Transport, - negotiated_version = Version, - renegotiate_at = RenegotiateAt}, + StateData0#data{connection_states = ConnectionStates#{current_write => WriteState}, + static = Static0#static{connection_pid = Pid, + connection_monitor = Monitor, + role = Role, + socket = Socket, + socket_options = SockOpts, + tracker = Tracker, + transport_cb = Transport, + negotiated_version = Version, + renegotiate_at = RenegotiateAt}}, {next_state, handshake, StateData, [{reply, From, ok}]}; -init(info, Msg, StateData) -> - handle_info(Msg, ?FUNCTION_NAME, StateData). +init(_, _, _) -> + %% Just in case anything else sneeks through + {keep_state_and_data, [postpone]}. + %%-------------------------------------------------------------------- -spec connection(gen_statem:event_type(), Msg :: term(), StateData :: term()) -> gen_statem:event_handler_result(atom()). %%-------------------------------------------------------------------- -connection({call, From}, renegotiate, - #data{connection_states = #{current_write := Write}} = StateData) -> - {next_state, handshake, StateData, [{reply, From, {ok, Write}}]}; connection({call, From}, {application_data, AppData}, - #data{socket_options = #socket_options{packet = Packet}} = + #data{static = #static{socket_options = #socket_options{packet = Packet}}} = StateData) -> case encode_packet(Packet, AppData) of {error, _} = Error -> @@ -232,40 +236,40 @@ connection({call, From}, {application_data, AppData}, Data -> send_application_data(Data, From, ?FUNCTION_NAME, StateData) end; -connection({call, From}, {set_opts, _} = Call, StateData) -> - handle_call(From, Call, ?FUNCTION_NAME, StateData); +connection({call, From}, {ack_alert, #alert{} = Alert}, StateData0) -> + StateData = send_tls_alert(Alert, StateData0), + {next_state, ?FUNCTION_NAME, StateData, + [{reply,From,ok}]}; +connection({call, From}, renegotiate, + #data{connection_states = #{current_write := Write}} = StateData) -> + {next_state, handshake, StateData, [{reply, From, {ok, Write}}]}; +connection({call, From}, downgrade, #data{connection_states = + #{current_write := Write}} = StateData) -> + {next_state, death_row, StateData, [{reply,From, {ok, Write}}]}; +connection({call, From}, {set_opts, Opts}, StateData) -> + handle_set_opts(From, Opts, StateData); connection({call, From}, dist_get_tls_socket, - #data{protocol_cb = Connection, - transport_cb = Transport, - socket = Socket, - connection_pid = Pid, - tracker = Tracker} = StateData) -> - TLSSocket = Connection:socket([Pid, self()], Transport, Socket, Connection, Tracker), + #data{static = #static{transport_cb = Transport, + socket = Socket, + connection_pid = Pid, + tracker = Tracker}} = StateData) -> + TLSSocket = tls_connection:socket([Pid, self()], Transport, Socket, Tracker), {next_state, ?FUNCTION_NAME, StateData, [{reply, From, {ok, TLSSocket}}]}; connection({call, From}, {dist_handshake_complete, _Node, DHandle}, - #data{connection_pid = Pid, - socket_options = #socket_options{packet = Packet}} = - StateData) -> + #data{static = #static{connection_pid = Pid} = Static} = StateData) -> ok = erlang:dist_ctrl_input_handler(DHandle, Pid), ok = ssl_connection:dist_handshake_complete(Pid, DHandle), %% From now on we execute on normal priority process_flag(priority, normal), - {next_state, ?FUNCTION_NAME, StateData#data{dist_handle = DHandle}, - [{reply, From, ok} - | case dist_data(DHandle, Packet) of - [] -> - []; - Data -> - [{next_event, internal, - {application_packets,{self(),undefined},Data}}] - end]}; -connection({call, From}, {ack_alert, #alert{} = Alert}, StateData0) -> - StateData = send_tls_alert(Alert, StateData0), - {next_state, ?FUNCTION_NAME, StateData, - [{reply,From,ok}]}; -connection({call, From}, downgrade, #data{connection_states = - #{current_write := Write}} = StateData) -> - {next_state, death_row, StateData, [{reply,From, {ok, Write}}]}; + {keep_state, StateData#data{static = Static#static{dist_handle = DHandle}}, + [{reply,From,ok}| + case dist_data(DHandle) of + [] -> + []; + Data -> + [{next_event, internal, + {application_packets,{self(),undefined},erlang:iolist_to_iovec(Data)}}] + end]}; connection(internal, {application_packets, From, Data}, StateData) -> send_application_data(Data, From, ?FUNCTION_NAME, StateData); %% @@ -273,29 +277,26 @@ connection(cast, #alert{} = Alert, StateData0) -> StateData = send_tls_alert(Alert, StateData0), {next_state, ?FUNCTION_NAME, StateData}; connection(cast, {new_write, WritesState, Version}, - #data{connection_states = ConnectionStates0} = StateData) -> + #data{connection_states = ConnectionStates, static = Static} = StateData) -> {next_state, connection, StateData#data{connection_states = - ConnectionStates0#{current_write => WritesState}, - negotiated_version = Version}}; + ConnectionStates#{current_write => WritesState}, + static = Static#static{negotiated_version = Version}}}; %% -connection(info, dist_data, - #data{dist_handle = DHandle, - socket_options = #socket_options{packet = Packet}} = - StateData) -> - {next_state, ?FUNCTION_NAME, StateData, - case dist_data(DHandle, Packet) of +connection(info, dist_data, #data{static = #static{dist_handle = DHandle}}) -> + {keep_state_and_data, + case dist_data(DHandle) of [] -> []; Data -> [{next_event, internal, - {application_packets,{self(),undefined},Data}}] + {application_packets,{self(),undefined},erlang:iolist_to_iovec(Data)}}] end}; connection(info, tick, StateData) -> consume_ticks(), - {next_state, ?FUNCTION_NAME, StateData, - [{next_event, {call, {self(), undefined}}, - {application_data, <<>>}}]}; + Data = [<<0:32>>], % encode_packet(4, <<>>) + From = {self(), undefined}, + send_application_data(Data, From, ?FUNCTION_NAME, StateData); connection(info, {send, From, Ref, Data}, _StateData) -> %% This is for testing only! %% @@ -304,29 +305,37 @@ connection(info, {send, From, Ref, Data}, _StateData) -> From ! {Ref, ok}, {keep_state_and_data, [{next_event, {call, {self(), undefined}}, - {application_data, iolist_to_binary(Data)}}]}; -connection(info, Msg, StateData) -> - handle_info(Msg, ?FUNCTION_NAME, StateData). + {application_data, erlang:iolist_to_iovec(Data)}}]}; +?HANDLE_COMMON. + %%-------------------------------------------------------------------- -spec handshake(gen_statem:event_type(), Msg :: term(), StateData :: term()) -> gen_statem:event_handler_result(atom()). %%-------------------------------------------------------------------- -handshake({call, From}, {set_opts, _} = Call, StateData) -> - handle_call(From, Call, ?FUNCTION_NAME, StateData); +handshake({call, From}, {set_opts, Opts}, StateData) -> + handle_set_opts(From, Opts, StateData); handshake({call, _}, _, _) -> + %% Postpone all calls to the connection state + {keep_state_and_data, [postpone]}; +handshake(internal, {application_packets,_,_}, _) -> {keep_state_and_data, [postpone]}; handshake(cast, {new_write, WritesState, Version}, - #data{connection_states = ConnectionStates0} = StateData) -> + #data{connection_states = ConnectionStates, static = Static} = StateData) -> {next_state, connection, - StateData#data{connection_states = - ConnectionStates0#{current_write => WritesState}, - negotiated_version = Version}}; -handshake(internal, {application_packets,_,_}, _) -> + StateData#data{connection_states = ConnectionStates#{current_write => WritesState}, + static = Static#static{negotiated_version = Version}}}; +handshake(info, dist_data, _) -> {keep_state_and_data, [postpone]}; -handshake(info, Msg, StateData) -> - handle_info(Msg, ?FUNCTION_NAME, StateData). +handshake(info, tick, _) -> + %% Ignore - data is sent anyway during handshake + consume_ticks(), + keep_state_and_data; +handshake(info, {send, _, _, _}, _) -> + %% Testing only, OTP distribution test suites... + {keep_state_and_data, [postpone]}; +?HANDLE_COMMON. %%-------------------------------------------------------------------- -spec death_row(gen_statem:event_type(), @@ -361,49 +370,66 @@ code_change(_OldVsn, State, Data, _Extra) -> %%%=================================================================== %%% Internal functions %%%=================================================================== -handle_call(From, {set_opts, Opts}, StateName, #data{socket_options = SockOpts} = StateData) -> - {next_state, StateName, StateData#data{socket_options = set_opts(SockOpts, Opts)}, [{reply, From, ok}]}. - -handle_info({'DOWN', Monitor, _, _, Reason}, _, - #data{connection_monitor = Monitor, - dist_handle = Handle} = StateData) when Handle =/= undefined-> - {next_state, death_row, StateData, [{state_timeout, 5000, Reason}]}; -handle_info({'DOWN', Monitor, _, _, _}, _, - #data{connection_monitor = Monitor} = StateData) -> + +handle_set_opts( + From, Opts, #data{static = #static{socket_options = SockOpts} = Static} = StateData) -> + {keep_state, StateData#data{static = Static#static{socket_options = set_opts(SockOpts, Opts)}}, + [{reply, From, ok}]}. + +handle_common( + {call, From}, {set_opts, Opts}, + #data{static = #static{socket_options = SockOpts} = Static} = StateData) -> + {keep_state, StateData#data{static = Static#static{socket_options = set_opts(SockOpts, Opts)}}, + [{reply, From, ok}]}; +handle_common( + info, {'DOWN', Monitor, _, _, Reason}, + #data{static = #static{connection_monitor = Monitor, + dist_handle = Handle}} = StateData) when Handle =/= undefined -> + {next_state, death_row, StateData, + [{state_timeout, 5000, Reason}]}; +handle_common( + info, {'DOWN', Monitor, _, _, _}, + #data{static = #static{connection_monitor = Monitor}} = StateData) -> {stop, normal, StateData}; -handle_info(_,_,_) -> +handle_common(info, Msg, _) -> + Report = + io_lib:format("TLS sender: Got unexpected info: ~p ~n", [Msg]), + error_logger:info_report(Report), + keep_state_and_data; +handle_common(Type, Msg, _) -> + Report = + io_lib:format( + "TLS sender: Got unexpected event: ~p ~n", [{Type,Msg}]), + error_logger:error_report(Report), keep_state_and_data. -send_tls_alert(Alert, #data{negotiated_version = Version, - socket = Socket, - protocol_cb = Connection, - transport_cb = Transport, - connection_states = ConnectionStates0} = StateData0) -> +send_tls_alert(#alert{} = Alert, + #data{static = #static{negotiated_version = Version, + socket = Socket, + transport_cb = Transport}, + connection_states = ConnectionStates0} = StateData0) -> {BinMsg, ConnectionStates} = - Connection:encode_alert(Alert, Version, ConnectionStates0), - Connection:send(Transport, Socket, BinMsg), + tls_record:encode_alert_record(Alert, Version, ConnectionStates0), + tls_socket:send(Transport, Socket, BinMsg), StateData0#data{connection_states = ConnectionStates}. send_application_data(Data, From, StateName, - #data{connection_pid = Pid, - socket = Socket, - dist_handle = DistHandle, - negotiated_version = Version, - protocol_cb = Connection, - transport_cb = Transport, - connection_states = ConnectionStates0, - renegotiate_at = RenegotiateAt} = StateData0) -> + #data{static = #static{connection_pid = Pid, + socket = Socket, + dist_handle = DistHandle, + negotiated_version = Version, + transport_cb = Transport, + renegotiate_at = RenegotiateAt}, + connection_states = ConnectionStates0} = StateData0) -> case time_to_renegotiate(Data, ConnectionStates0, RenegotiateAt) of true -> ssl_connection:internal_renegotiation(Pid, ConnectionStates0), {next_state, handshake, StateData0, [{next_event, internal, {application_packets, From, Data}}]}; false -> - {Msgs, ConnectionStates} = - Connection:encode_data( - iolist_to_binary(Data), Version, ConnectionStates0), + {Msgs, ConnectionStates} = tls_record:encode_data(Data, Version, ConnectionStates0), StateData = StateData0#data{connection_states = ConnectionStates}, - case Connection:send(Transport, Socket, Msgs) of + case tls_socket:send(Transport, Socket, Msgs) of ok when DistHandle =/= undefined -> {next_state, StateName, StateData, []}; Reason when DistHandle =/= undefined -> @@ -419,9 +445,9 @@ send_application_data(Data, From, StateName, encode_packet(Packet, Data) -> Len = iolist_size(Data), case Packet of - 1 when Len < (1 bsl 8) -> [<<Len:8>>,Data]; - 2 when Len < (1 bsl 16) -> [<<Len:16>>,Data]; - 4 when Len < (1 bsl 32) -> [<<Len:32>>,Data]; + 1 when Len < (1 bsl 8) -> [<<Len:8>>|Data]; + 2 when Len < (1 bsl 16) -> [<<Len:16>>|Data]; + 4 when Len < (1 bsl 32) -> [<<Len:32>>|Data]; N when N =:= 1; N =:= 2; N =:= 4 -> {error, {badarg, {packet_to_large, Len, (1 bsl (Packet bsl 3)) - 1}}}; @@ -458,22 +484,30 @@ call(FsmPid, Event) -> {error, closed} end. -%%---------------Erlang distribution -------------------------------------- +%%-------------- Erlang distribution helpers ------------------------------ -dist_data(DHandle, Packet) -> +dist_data(DHandle) -> case erlang:dist_ctrl_get_data(DHandle) of none -> erlang:dist_ctrl_get_data_notification(DHandle), []; - Data -> - %% This is encode_packet(4, Data) without Len check - %% since the emulator will always deliver a Data - %% smaller than 4 GB, and the distribution will - %% therefore always have to use {packet,4} + %% This is encode_packet(4, Data) without Len check + %% since the emulator will always deliver a Data + %% smaller than 4 GB, and the distribution will + %% therefore always have to use {packet,4} + Data when is_binary(Data) -> + Len = byte_size(Data), + [[<<Len:32>>,Data]|dist_data(DHandle)]; + [BA,BB] = Data -> + Len = byte_size(BA) + byte_size(BB), + [[<<Len:32>>|Data]|dist_data(DHandle)]; + Data when is_list(Data) -> Len = iolist_size(Data), - [<<Len:32>>,Data|dist_data(DHandle, Packet)] + [[<<Len:32>>|Data]|dist_data(DHandle)] end. + +%% Empty the inbox from distribution ticks - do not let them accumulate consume_ticks() -> receive tick -> consume_ticks() |