From 4b568c6f180af5648218db95e331a39b90b8609a Mon Sep 17 00:00:00 2001 From: Raimo Niskanen Date: Sat, 30 Mar 2019 16:58:02 +0100 Subject: Use Erlang cookie as shared secret * Remove all configuration possibilities, so use the cookie as secret * Clean up error handling to make the module a more complete dist module * Change the init message to use length fields instead of zero termination * Remove the dependency towards modern crypto so it should run on maint --- lib/ssl/test/inet_crypto_dist.erl | 609 +++++++++++++++++++++----------------- 1 file changed, 341 insertions(+), 268 deletions(-) (limited to 'lib/ssl/test/inet_crypto_dist.erl') diff --git a/lib/ssl/test/inet_crypto_dist.erl b/lib/ssl/test/inet_crypto_dist.erl index 5aafaac983..2cf9e3569e 100644 --- a/lib/ssl/test/inet_crypto_dist.erl +++ b/lib/ssl/test/inet_crypto_dist.erl @@ -30,13 +30,15 @@ -define(FAMILY, inet). -define(PROTOCOL, inet_crypto_dist_v1). --define(DEFAULT_BLOCK_CRYPTO, aes_128_gcm). --define(DEFAULT_HASH_ALGORITHM, sha256). --define(DEFAULT_REKEY_INTERVAL, 32768). +-define(HASH_ALGORITHM, sha256). +-define(BLOCK_CRYPTO, aes_gcm). +-define(IV_LEN, 12). +-define(KEY_LEN, 16). +-define(TAG_LEN, 16). +-define(REKEY_INTERVAL, 32768). -export([listen/1, accept/1, accept_connection/5, setup/5, close/1, select/1, is_node_name/1]). --export([is_supported/0]). %% Generalized dist API, for sibling IPv6 module inet6_crypto_dist -export([gen_listen/2, gen_accept/2, gen_accept_connection/6, @@ -52,20 +54,6 @@ -include_lib("kernel/include/dist.hrl"). -include_lib("kernel/include/dist_util.hrl"). -%% Test if crypto has got enough capabilities for this module to run -%% -is_supported() -> - try {crypto:cipher_info(?DEFAULT_BLOCK_CRYPTO), - crypto:hash_info(?DEFAULT_HASH_ALGORITHM)} - of - {#{block_size := _, iv_length := _, key_length := _}, - #{size := _}} -> - true - catch - error:undef -> - false - end. - %% ------------------------------------------------------------------------- %% Erlang distribution plugin structure explained to myself %% ------- @@ -217,24 +205,23 @@ gen_accept(Listen, Driver) -> %% %% Spawn Acceptor process %% - Config = config(), monitor_dist_proc( spawn_opt( fun () -> - accept_loop(Listen, Driver, NetKernel, Config) + accept_loop(Listen, Driver, NetKernel) end, [link, {priority, max}])). -accept_loop(Listen, Driver, NetKernel, Config) -> - case Driver:accept(Listen) of +accept_loop(Listen, Driver, NetKernel) -> + case Driver:accept(trace(Listen)) of {ok, Socket} -> wait_for_code_server(), Timeout = net_kernel:connecttime(), - DistCtrl = start_dist_ctrl(Socket, Config, Timeout), + DistCtrl = start_dist_ctrl(trace(Socket), Timeout), %% DistCtrl is a "socket" NetKernel ! - {accept, - self(), DistCtrl, Driver:family(), ?DIST_PROTO}, + trace({accept, + self(), DistCtrl, Driver:family(), ?DIST_PROTO}), receive {NetKernel, controller, Controller} -> call_dist_ctrl(DistCtrl, {controller, Controller, self()}), @@ -242,7 +229,7 @@ accept_loop(Listen, Driver, NetKernel, Config) -> {NetKernel, unsupported_protocol} -> exit(unsupported_protocol) end, - accept_loop(Listen, Driver, NetKernel, Config); + accept_loop(Listen, Driver, NetKernel); AcceptError -> exit({accept, AcceptError}) end. @@ -292,12 +279,13 @@ gen_accept_connection( fun() -> do_accept( Acceptor, DistCtrl, - MyNode, Allowed, SetupTime, Driver, NetKernel) + trace(MyNode), Allowed, SetupTime, Driver, NetKernel) end, [link, {priority, max}])). do_accept( Acceptor, DistCtrl, MyNode, Allowed, SetupTime, Driver, NetKernel) -> + %% receive {Acceptor, controller, Socket} -> Timer = dist_util:start_timer(SetupTime), @@ -337,40 +325,42 @@ gen_setup(Node, Type, MyNode, LongOrShortNames, SetupTime, Driver) -> -spec setup_fun(_,_,_,_,_,_,_) -> fun(() -> no_return()). setup_fun( Node, Type, MyNode, LongOrShortNames, SetupTime, Driver, NetKernel) -> + %% fun() -> do_setup( - Node, Type, MyNode, LongOrShortNames, SetupTime, + trace(Node), Type, MyNode, LongOrShortNames, SetupTime, Driver, NetKernel) end. -spec do_setup(_,_,_,_,_,_,_) -> no_return(). do_setup( Node, Type, MyNode, LongOrShortNames, SetupTime, Driver, NetKernel) -> + %% {Name, Address} = split_node(Driver, Node, LongOrShortNames), ErlEpmd = net_kernel:epmd_module(), {ARMod, ARFun} = get_address_resolver(ErlEpmd, Driver), Timer = trace(dist_util:start_timer(SetupTime)), case ARMod:ARFun(Name, Address, Driver:family()) of - {ok, Ip, TcpPort, Version} -> - do_setup_connect( - Node, Type, MyNode, Timer, Driver, NetKernel, - Ip, TcpPort, Version); + {ok, Ip, TcpPort, Version} -> + do_setup_connect( + Node, Type, MyNode, Timer, Driver, NetKernel, + Ip, TcpPort, Version); {ok, Ip} -> case ErlEpmd:port_please(Name, Ip) of {port, TcpPort, Version} -> do_setup_connect( Node, Type, MyNode, Timer, Driver, NetKernel, - Ip, TcpPort, Version); + Ip, TcpPort, trace(Version)); Other -> - ?shutdown2( - Node, - trace( - {port_please_failed, ErlEpmd, Name, Ip, Other})) + _ = trace( + {ErlEpmd, port_please, [Name, Ip], Other}), + ?shutdown(Node) end; Other -> - ?shutdown2( - Node, - trace({getaddr_failed, Driver, Address, Other})) + _ = trace( + {ARMod, ARFun, [Name, Address, Driver:family()], + Other}), + ?shutdown(Node) end. -spec do_setup_connect(_,_,_,_,_,_,_,_,_) -> no_return(). @@ -385,9 +375,12 @@ do_setup_connect( [binary, {active, false}, {packet, 2}, {nodelay, true}])), case Driver:connect(Ip, TcpPort, ConnectOpts) of {ok, Socket} -> - Config = config(), DistCtrl = - start_dist_ctrl(Socket, Config, net_kernel:connecttime()), + try start_dist_ctrl(Socket, net_kernel:connecttime()) + catch error : {dist_ctrl, _} = DistCtrlError -> + _ = trace(DistCtrlError), + ?shutdown(Node) + end, %% DistCtrl is a "socket" HSData = hs_data_common( @@ -401,8 +394,10 @@ do_setup_connect( request_type = Type}, dist_util:handshake_we_started(trace(HSData_1)); ConnectError -> - ?shutdown2(Node, - trace({connect_failed, Ip, TcpPort, ConnectError})) + _ = trace( + {Driver, connect, [Ip, TcpPort, ConnectOpts], + ConnectError}), + ?shutdown(Node) end. %% ------------------------------------------------------------------------- @@ -412,7 +407,7 @@ close(Socket) -> gen_close(Socket, ?DRIVER). gen_close(Socket, Driver) -> - trace(Driver:close(Socket)). + Driver:close(trace(Socket)). %% ------------------------------------------------------------------------- @@ -428,17 +423,23 @@ hs_data_common(NetKernel, MyNode, DistCtrl, Timer, Socket, Family) -> socket = DistCtrl, timer = Timer, %% - f_send = + f_send = % -> ok | {error, closed}=>?shutdown() fun (S, Packet) when S =:= DistCtrl -> - call_dist_ctrl(S, {send, Packet}) + try call_dist_ctrl(S, {send, Packet}) + catch error : {dist_ctrl, Reason} -> + _ = trace(Reason), + {error, closed} + end end, - f_recv = + f_recv = % -> {ok, List} | Other=>?shutdown() fun (S, 0, infinity) when S =:= DistCtrl -> - case call_dist_ctrl(S, recv) of + try call_dist_ctrl(S, recv) of {ok, Bin} when is_binary(Bin) -> {ok, binary_to_list(Bin)}; Error -> Error + catch error : {dist_ctrl, Reason} -> + {error, trace(Reason)} end end, f_setopts_pre_nodeup = @@ -453,9 +454,9 @@ hs_data_common(NetKernel, MyNode, DistCtrl, Timer, Socket, Family) -> fun (S) when S =:= DistCtrl -> {ok, S} %% DistCtrl is the distribution port end, - f_address = + f_address = % -> #net_address{} | ?shutdown() fun (S, Node) when S =:= DistCtrl -> - case call_dist_ctrl(S, peername) of + try call_dist_ctrl(S, peername) of {ok, Address} -> case dist_util:split_node(Node) of {node, _, Host} -> @@ -465,13 +466,23 @@ hs_data_common(NetKernel, MyNode, DistCtrl, Timer, Socket, Family) -> protocol = ?DIST_PROTO, family = Family}; _ -> - {error, no_node} - end + ?shutdown(Node) + end; + Error -> + _ = trace(Error), + ?shutdown(Node) + catch error : {dist_ctrl, Reason} -> + _ = trace(Reason), + ?shutdown(Node) end end, - f_handshake_complete = - fun (S, _Node, DistHandle) when S =:= DistCtrl -> - call_dist_ctrl(S, {handshake_complete, DistHandle}) + f_handshake_complete = % -> ok | ?shutdown() + fun (S, Node, DistHandle) when S =:= DistCtrl -> + try call_dist_ctrl(S, {handshake_complete, DistHandle}) + catch error : {dist_ctrl, Reason} -> + _ = trace(Reason), + ?shutdown(Node) + end end, %% %% mf_tick/1, mf_getstat/1, mf_setopts/2 and mf_getopts/2 @@ -481,7 +492,7 @@ hs_data_common(NetKernel, MyNode, DistCtrl, Timer, Socket, Family) -> fun (S) when S =:= DistCtrl -> S ! dist_tick end, - mf_getstat = + mf_getstat = % -> {ok, RecvCnt, SendCnt, SendPend} | Other=>ignore_it fun (S) when S =:= DistCtrl -> case inet:getstat(Socket, [recv_cnt, send_cnt, send_pend]) @@ -489,7 +500,7 @@ hs_data_common(NetKernel, MyNode, DistCtrl, Timer, Socket, Family) -> {ok, Stat} -> split_stat(Stat, 0, 0, 0); Error -> - Error + trace(Error) end end, mf_setopts = @@ -596,25 +607,6 @@ nodelay() -> {nodelay, true} end. -config() -> - case init:get_argument(?DIST_NAME) of - error -> - error({missing_argument, ?DIST_NAME}); - {ok, [[String]]} -> - {ok, Tokens, _} = erl_scan:string(String ++ "."), - case erl_parse:parse_term(Tokens) of - {ok, #{secret := Secret} = Config} - when is_binary(Secret); is_list(Secret) -> - Config; - {ok, #{} = Config} -> - error({missing_secret, [{?DIST_NAME,Config}]}); - _ -> - error({bad_argument_value, [{?DIST_NAME,String}]}) - end; - {ok, Value} -> - error({malformed_argument, [{?DIST_NAME,Value}]}) - end. - %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% %% The DistCtrl process(es). @@ -625,19 +617,13 @@ config() -> %% %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%% XXX Missing to "productified": -%%% * Cryptoanalysis by experts -%%% * Proof of usefulness -%%% * Unifying exit reasons using a death_row() function -%%% * Verification (and rejection) of other end's crypto parameters -%%% * OTP:ification (proc_lib?) -%%% * An application to belong to (crypto|kernel?) -%%% * Secret on file (cookie as default?), parameter handling -%%% * Restart and/or code reload policy +%%% * Cryptoanalysis by experts, e.g Forward Secrecy is missing, right? +%%% * Is it useful? +%%% * An application to belong to (kernel) +%%% * Restart and/or code reload policy (not needed in kernel) -%% Debug client and server -test_config() -> - #{secret => <<"Secret Cluster Password 123456">>}. +%% Debug client and server test_server() -> {ok, Listen} = gen_tcp:listen(0, [{packet, 2}, {active, false}, binary]), @@ -653,12 +639,12 @@ test_client(Port) -> test(Socket). test(Socket) -> - start_dist_ctrl(Socket, test_config(), 10000). + start_dist_ctrl(Socket, 10000). %% ------------------------------------------------------------------------- -start_dist_ctrl(Socket, Config, Timeout) -> - Protocol = ?PROTOCOL, +start_dist_ctrl(Socket, Timeout) -> + Secret = atom_to_binary(auth:get_cookie(), latin1), Controller = self(), Server = monitor_dist_proc( @@ -667,7 +653,7 @@ start_dist_ctrl(Socket, Config, Timeout) -> receive {?MODULE, From, start} -> {SendParams, RecvParams} = - init(Socket, Config, Protocol), + init(Socket, Secret), reply(From, self()), handshake(SendParams, 0, RecvParams, 0, Controller) end @@ -691,12 +677,13 @@ call_dist_ctrl(Server, Msg, Timeout) -> erlang:demonitor(Ref, [flush]), Res; {'DOWN', Ref, process, Server, Reason} -> - exit({?PROTOCOL, Reason}) - after Timeout -> - exit(Server, timeout), + error({dist_ctrl, Reason}) + after Timeout -> % Timeout < infinity is only used by start_dist_ctrl/2 receive {'DOWN', Ref, process, Server, _} -> - exit({?PROTOCOL, timeout}) + receive {Ref, _} -> ok after 0 -> ok end, + error({dist_ctrl, timeout}) + %% Server will be killed by link end end. @@ -717,10 +704,11 @@ reply({Ref, Pid}, Msg) -> key, tag_len}). --define(TCP_ACTIVE, 64). +-define(TCP_ACTIVE, 16). -define(CHUNK_SIZE, (65536 - 512)). -%% The start chunk starts with zeros, so it seems logical to not have -%% a chunk type with value 0 + +%% The start chunk starts with zeros, so it seems logical to +%% not have a chunk type with value 0 -define(HANDSHAKE_CHUNK, 1). -define(DATA_CHUNK, 2). -define(TICK_CHUNK, 3). @@ -759,85 +747,49 @@ reply({Ref, Pid}, Msg) -> %% in the local log, but all the other end sees is a closed connection. %% ------------------------------------------------------------------------- -init(Socket, Config, Protocol) -> - Secret = maps:get(secret, Config), - HashAlgorithm = - maps:get(hash_algorithm, Config, ?DEFAULT_HASH_ALGORITHM), - BlockCrypto = - maps:get(block_crypto, Config, ?DEFAULT_BLOCK_CRYPTO), - RekeyInterval = - maps:get(rekey_interval, Config, ?DEFAULT_REKEY_INTERVAL), - %% - SendParams = - init_params( - Socket, Protocol, HashAlgorithm, BlockCrypto, RekeyInterval), - send_init(SendParams, Secret). - -send_init( - #params{ - protocol = Protocol, - socket = Socket, - block_crypto = BlockCrypto, - iv = IVLen, - key = KeyLen, - hash_algorithm = HashAlgorithm} = SendParams, - Secret) -> - %% - ProtocolString = atom_to_binary(Protocol, utf8), - BlockCryptoString = atom_to_binary(BlockCrypto, utf8), - HashAlgorithmString = atom_to_binary(HashAlgorithm, utf8), - SendHeader = - <>, - <> = IV_KeySalt = - crypto:strong_rand_bytes(IVLen + KeyLen), - InitPacket = [SendHeader, IV_KeySalt], - ok = gen_tcp:send(Socket, InitPacket), - recv_init(SendParams#params{iv = IV, key = KeySalt}, Secret, SendHeader). +init(Socket, Secret) -> + {SendParams, Header, Data} = init_params(Socket), + ok = gen_tcp:send(Socket, [Header, Data]), + init( + SendParams, Secret, + [<<(byte_size(Header) + byte_size(Data)):32>>, Header]). -recv_init( +init( #params{ socket = Socket, hash_algorithm = SendHashAlgorithm, - key = SendKeySalt} = SendParams, Secret, SendHeader) -> + key = {SendKeySalt, SendOtherSalt}} = SendParams, Secret, SendAAD) -> %% - {ok, InitPacket} = gen_tcp:recv(Socket, 0), - [ProtocolString, Rest_1] = binary:split(InitPacket, <<0>>), - Protocol = binary_to_existing_atom(ProtocolString, utf8), - case Protocol of - ?PROTOCOL -> - [HashAlgorithmString, Rest_2] = binary:split(Rest_1, <<0>>), - HashAlgorithm = binary_to_existing_atom(HashAlgorithmString, utf8), - [BlockCryptoString, Rest_3] = binary:split(Rest_2, <<0>>), - BlockCrypto = binary_to_existing_atom(BlockCryptoString, utf8), - #params{ - hash_algorithm = RecvHashAlgorithm, - iv = RecvIVLen, - key = RecvKeyLen} = RecvParams = - init_params( - Socket, Protocol, HashAlgorithm, BlockCrypto, undefined), - <> = Rest_3, + {ok, InitMsg} = gen_tcp:recv(Socket, 0), + try + init_params(Socket, InitMsg) + of + {#params{ + hash_algorithm = RecvHashAlgorithm, + key = {RecvKeySalt, RecvOtherSalt}} = RecvParams, + RecvHeader} -> + RecvAAD = [<<(byte_size(InitMsg)):32>>, RecvHeader], SendKey = - hash_key(SendHashAlgorithm, SendKeySalt, [RecvKeySalt, Secret]), + hash_key( + SendHashAlgorithm, SendKeySalt, [RecvOtherSalt, Secret]), RecvKey = - hash_key(RecvHashAlgorithm, RecvKeySalt, [SendKeySalt, Secret]), + hash_key( + RecvHashAlgorithm, RecvKeySalt, [SendOtherSalt, Secret]), SendParams_1 = SendParams#params{key = SendKey}, - RecvParams_1 = RecvParams#params{iv = RecvIV, key = RecvKey}, - RecvHeaderLen = byte_size(InitPacket) - RecvIVLen - RecvKeyLen, - <> = InitPacket, - send_start(SendParams_1, SendHeader), - RecvRekeyInterval = recv_start(RecvParams_1, RecvHeader), - {SendParams_1, - RecvParams_1#params{rekey_interval = RecvRekeyInterval}} + RecvParams_1 = RecvParams#params{key = RecvKey}, + send_start(SendParams_1, SendAAD), + recv_start(RecvParams_1, RecvAAD), + {SendParams_1, RecvParams_1} + catch + error : Reason -> + _ = trace(Reason), + exit(connection_closed) end. send_start( #params{ socket = Socket, block_crypto = BlockCrypto, - rekey_interval= RekeyInterval, iv = IV, key = Key, tag_len = TagLen}, AAD) -> @@ -845,9 +797,7 @@ send_start( KeyLen = byte_size(Key), Zeros = binary:copy(<<0>>, KeyLen), {Ciphertext, CipherTag} = - crypto:block_encrypt( - crypto_cipher_name(BlockCrypto), - Key, IV, {AAD, [Zeros, <>], TagLen}), + crypto:block_encrypt(BlockCrypto, Key, IV, {AAD, Zeros, TagLen}), ok = gen_tcp:send(Socket, [Ciphertext, CipherTag]). recv_start( @@ -857,51 +807,100 @@ recv_start( iv = IV, key = Key, tag_len = TagLen}, AAD) -> + %% {ok, Packet} = gen_tcp:recv(Socket, 0), KeyLen = byte_size(Key), - PacketLen = KeyLen + 4, + PacketLen = KeyLen, <> = Packet, Zeros = binary:copy(<<0>>, KeyLen), case crypto:block_decrypt( - crypto_cipher_name(BlockCrypto), - Key, IV, {AAD, Ciphertext, CipherTag}) + BlockCrypto, Key, IV, {AAD, Ciphertext, CipherTag}) of - <> - when 1 =< RekeyInterval -> - RekeyInterval; + Zeros -> + ok; _ -> - error(decrypt_error) + _ = trace(decrypt_error), + exit(connection_closed) end. -init_params(Socket, Protocol, HashAlgorithm, BlockCrypto, RekeyInterval) -> - #{block_size := 1, - iv_length := IVLen, - key_length := KeyLen} = crypto:cipher_info(BlockCrypto), - case crypto:hash_info(HashAlgorithm) of - #{size := HashSize} when HashSize >= KeyLen -> - #params{ - socket = Socket, - protocol = Protocol, - hash_algorithm = HashAlgorithm, - block_crypto = BlockCrypto, - rekey_interval = RekeyInterval, - iv = IVLen, - key = KeyLen, - tag_len = 16} - end. -crypto_cipher_name(BlockCrypto) -> - case BlockCrypto of - aes_128_gcm -> aes_gcm; - aes_192_gcm -> aes_gcm; - aes_256_gcm -> aes_gcm + +init_params(Socket) -> + Protocol = ?PROTOCOL, + HashAlgorithm = ?HASH_ALGORITHM, + BlockCrypto = ?BLOCK_CRYPTO, + IVLen = ?IV_LEN, + KeyLen = ?KEY_LEN, + TagLen = ?TAG_LEN, + RekeyInterval = ?REKEY_INTERVAL, + %% + ProtocolBin = atom_to_binary(Protocol, utf8), + HashAlgorithmBin = atom_to_binary(HashAlgorithm, utf8), + BlockCryptoBin = atom_to_binary(BlockCrypto, utf8), + Header = + <<(byte_size(ProtocolBin)), ProtocolBin/binary, + (byte_size(HashAlgorithmBin)), HashAlgorithmBin/binary, + (byte_size(BlockCryptoBin)), BlockCryptoBin/binary, + KeyLen, TagLen>>, + <> = + Data = crypto:strong_rand_bytes(IVLen + KeyLen + KeyLen), + {#params{ + socket = Socket, + protocol = Protocol, + hash_algorithm = HashAlgorithm, + block_crypto = BlockCrypto, + iv = IV, + key = {KeySalt, OtherSalt}, + tag_len = TagLen, + rekey_interval = RekeyInterval}, Header, Data}. + +init_params(Socket, Msg) -> + Protocol = ?PROTOCOL, + ProtocolBin = atom_to_binary(Protocol, utf8), + ProtocolBinLen = byte_size(ProtocolBin), + case Msg of + <> -> + HashAlgorithm = ?HASH_ALGORITHM, + BlockCrypto = ?BLOCK_CRYPTO, + HashAlgorithmBin = atom_to_binary(HashAlgorithm, utf8), + HashAlgorithmBinLen = byte_size(HashAlgorithmBin), + BlockCryptoBin = atom_to_binary(BlockCrypto, utf8), + BlockCryptoBinLen = byte_size(BlockCryptoBin), + case Rest_1 of + <> -> + IVLen = ?IV_LEN, KeyLen = ?KEY_LEN, + TagLen = ?TAG_LEN, RekeyInterval = ?REKEY_INTERVAL, + <> = Rest_2, + HeaderLen = byte_size(Msg) - (byte_size(Rest_2) - 2), + <> = Msg, + {#params{ + socket = Socket, + protocol = Protocol, + hash_algorithm = HashAlgorithm, + block_crypto = BlockCrypto, + iv = IV, + key = {KeySalt, OtherSalt}, + tag_len = TagLen, + rekey_interval = RekeyInterval}, + Header} + end end. -hash_key(HashAlgorithm, KeySalt, OtherSalt) -> + + +hash_key(HashAlgorithm, KeySalt, KeyData) -> KeyLen = byte_size(KeySalt), <> = - crypto:hash(HashAlgorithm, [KeySalt, OtherSalt]), + crypto:hash(HashAlgorithm, [KeySalt, KeyData]), Key. %% ------------------------------------------------------------------------- @@ -918,7 +917,6 @@ handshake( reply(From, Result), handshake(SendParams, SendSeq, RecvParams, RecvSeq, Controller_1); {?MODULE, From, {handshake_complete, DistHandle}} -> - reply(From, ok), InputHandler = monitor_dist_proc( spawn_opt( @@ -945,25 +943,38 @@ handshake( ok = gen_tcp:controlling_process(Socket, InputHandler), ok = erlang:dist_ctrl_input_handler(DistHandle, InputHandler), InputHandler ! DistHandle, + crypto:rand_seed_alg(crypto_cache), + reply(From, ok), process_flag(priority, normal), erlang:dist_ctrl_get_data_notification(DistHandle), - crypto:rand_seed_alg(crypto_cache), output_handler( SendParams#params{dist_handle = DistHandle}, SendSeq); %% {?MODULE, From, {send, Data}} -> - {SendParams_1, SendSeq_1} = + case encrypt_and_send_chunk( - SendParams, SendSeq, [?HANDSHAKE_CHUNK, Data]), - reply(From, ok), - handshake( - SendParams_1, SendSeq_1, RecvParams, RecvSeq, Controller); + SendParams, SendSeq, [?HANDSHAKE_CHUNK, Data]) + of + {SendParams_1, SendSeq_1, ok} -> + reply(From, ok), + handshake( + SendParams_1, SendSeq_1, RecvParams, RecvSeq, + Controller); + {_, _, Error} -> + reply(From, {error, closed}), + death_row({send, trace(Error)}) + end; {?MODULE, From, recv} -> - {RecvParams_1, RecvSeq_1, Reply} = - recv_and_decrypt_chunk(RecvParams, RecvSeq), - reply(From, Reply), - handshake( - SendParams, SendSeq, RecvParams_1, RecvSeq_1, Controller); + case recv_and_decrypt_chunk(RecvParams, RecvSeq) of + {RecvParams_1, RecvSeq_1, {ok, _} = Reply} -> + reply(From, Reply), + handshake( + SendParams, SendSeq, RecvParams_1, RecvSeq_1, + Controller); + {_, _, Error} -> + reply(From, Error), + death_row({recv, trace(Error)}) + end; {?MODULE, From, peername} -> reply(From, inet:peername(Socket)), handshake(SendParams, SendSeq, RecvParams, RecvSeq, Controller); @@ -978,10 +989,12 @@ recv_and_decrypt_chunk(#params{socket = Socket} = RecvParams, RecvSeq) -> case decrypt_chunk(RecvParams, RecvSeq, Chunk) of <> -> {RecvParams, RecvSeq + 1, {ok, Cleartext}}; + OtherChunk when is_binary(OtherChunk) -> + {RecvParams, RecvSeq + 1, {error, decrypt_error}}; #params{} = RecvParams_1 -> recv_and_decrypt_chunk(RecvParams_1, 0); - _ -> - error(decrypt_error) + error -> + {RecvParams, RecvSeq, {error, decrypt_error}} end; Error -> {RecvParams, RecvSeq, Error} @@ -1001,8 +1014,9 @@ output_handler(Params, Seq) -> output_handler_data(Params, Seq); dist_tick -> output_handler_tick(Params, Seq); - _Other -> + Other -> %% Ignore + _ = trace(Other), output_handler(Params, Seq) end end. @@ -1015,14 +1029,15 @@ output_handler_data(Params, Seq) -> output_handler_data(Params, Seq); dist_tick -> output_handler_data(Params, Seq); - _Other -> + Other -> %% Ignore + _ = trace(Other), output_handler_data(Params, Seq) end after 0 -> DistHandle = Params#params.dist_handle, Q = get_data(DistHandle, empty_q()), - {Params_1, Seq_1} = output_handler_send(Params, Seq, Q, true), + {Params_1, Seq_1} = output_handler_send(Params, Seq, Q), erlang:dist_ctrl_get_data_notification(DistHandle), output_handler(Params_1, Seq_1) end. @@ -1035,39 +1050,56 @@ output_handler_tick(Params, Seq) -> output_handler_data(Params, Seq); dist_tick -> output_handler_tick(Params, Seq); - _Other -> + Other -> %% Ignore + _ = trace(Other), output_handler_tick(Params, Seq) end after 0 -> - TickSize = 8 + rand:uniform(56), + TickSize = 7 + rand:uniform(56), TickData = binary:copy(<<0>>, TickSize), - {Params_1, Seq_1} = - encrypt_and_send_chunk(Params, Seq, [?TICK_CHUNK, TickData]), - output_handler(Params_1, Seq_1) + case + encrypt_and_send_chunk(Params, Seq, [?TICK_CHUNK, TickData]) + of + {Params_1, Seq_1, ok} -> + output_handler(Params_1, Seq_1); + {_, _, Error} -> + _ = trace(Error), + death_row() + end end. -output_handler_send( - #params{dist_handle = DistHandle} = Params, Seq, {_, Size, _} = Q, Retry) -> - %% +output_handler_send(Params, Seq, {_, Size, _} = Q) -> if ?CHUNK_SIZE < Size -> - {Cleartext, Q_1} = deq_iovec(?CHUNK_SIZE, Q), - {Params_1, Seq_1} = - encrypt_and_send_chunk(Params, Seq, [?DATA_CHUNK, Cleartext]), - output_handler_send(Params_1, Seq_1, Q_1, Retry); - Retry -> - Q_1 = get_data(DistHandle, Q), - output_handler_send(Params, Seq, Q_1, false); + output_handler_send(Params, Seq, Q, ?CHUNK_SIZE); true -> - {Cleartext, _} = deq_iovec(Size, Q), - encrypt_and_send_chunk(Params, Seq, [?DATA_CHUNK, Cleartext]) + case get_data(Params#params.dist_handle, Q) of + {_, 0, _} -> + {Params, Seq}; + {_, Size, _} = Q_1 -> % Got no more + output_handler_send(Params, Seq, Q_1, Size); + Q_1 -> + output_handler_send(Params, Seq, Q_1) + end + end. + +output_handler_send(Params, Seq, Q, Size) -> + {Cleartext, Q_1} = deq_iovec(Size, Q), + case + encrypt_and_send_chunk(Params, Seq, [?DATA_CHUNK, Cleartext]) + of + {Params_1, Seq_1, ok} -> + output_handler_send(Params_1, Seq_1, Q_1); + {_, _, Error} -> + _ = trace(Error), + death_row() end. %% ------------------------------------------------------------------------- %% Input handler process %% -%% Here is T 0 or infinity to steer if we should try to receive +%% Here is T = 0|infinity to steer if we should try to receive %% more data or not; start with infinity, and when we get some %% data try with 0 to see if more is waiting @@ -1086,11 +1118,12 @@ input_handler(#params{socket = Socket} = Params, Seq, Q, T) -> end, input_handler(Params, Seq, Q_1, infinity); {tcp, Socket, Chunk} -> - input_chunk(Params, Seq, Q, Chunk); + input_chunk(Params, Seq, Q, T, Chunk); {tcp_closed, Socket} -> - error(connection_closed); - _Other -> + exit(connection_closed); + Other -> %% Ignore... + _ = trace(Other), input_handler(Params, Seq, Q, T) end after T -> @@ -1098,16 +1131,20 @@ input_handler(#params{socket = Socket} = Params, Seq, Q, T) -> input_handler(Params, Seq, Q_1, infinity) end. -input_chunk(Params, Seq, Q, Chunk) -> +input_chunk(Params, Seq, Q, T, Chunk) -> case decrypt_chunk(Params, Seq, Chunk) of <> -> input_handler(Params, Seq + 1, enq_binary(Cleartext, Q), 0); <> -> - input_handler(Params, Seq + 1, Q, 0); + input_handler(Params, Seq + 1, Q, T); + OtherChunk when is_binary(OtherChunk) -> + _ = trace(invalid_chunk), + exit(connection_closed); #params{} = Params_1 -> - input_handler(Params_1, 0, Q, 0); - _ -> - error(decrypt_error) + input_handler(Params_1, 0, Q, T); + error -> + _ = trace(decrypt_error), + exit(connection_closed) end. %% ------------------------------------------------------------------------- @@ -1206,14 +1243,22 @@ encrypt_and_send_chunk( IVLen = byte_size(IV), Chunk = <> = crypto:strong_rand_bytes(IVLen + KeyLen), - ok = gen_tcp:send(Socket, encrypt_chunk(Params, Seq, [?REKEY_CHUNK, Chunk])), - Key_1 = hash_key(HashAlgorithm, Key, KeySalt), - Params_1 = Params#params{key = Key_1, iv = IV_1}, - ok = gen_tcp:send(Socket, encrypt_chunk(Params_1, 0, Cleartext)), - {Params_1, 1}; + case + gen_tcp:send( + Socket, encrypt_chunk(Params, Seq, [?REKEY_CHUNK, Chunk])) + of + ok -> + Key_1 = hash_key(HashAlgorithm, Key, KeySalt), + Params_1 = Params#params{key = Key_1, iv = IV_1}, + Result = + gen_tcp:send(Socket, encrypt_chunk(Params_1, 0, Cleartext)), + {Params_1, 1, Result}; + SendError -> + {Params, Seq + 1, SendError} + end; encrypt_and_send_chunk(#params{socket = Socket} = Params, Seq, Cleartext) -> - ok = gen_tcp:send(Socket, encrypt_chunk(Params, Seq, Cleartext)), - {Params, Seq + 1}. + Result = gen_tcp:send(Socket, encrypt_chunk(Params, Seq, Cleartext)), + {Params, Seq + 1, Result}. encrypt_chunk( #params{ @@ -1223,8 +1268,7 @@ encrypt_chunk( ChunkLen = iolist_size(Cleartext) + TagLen, AAD = <>, {Ciphertext, CipherTag} = - crypto:block_encrypt( - crypto_cipher_name(BlockCrypto), Key, IV, {AAD, Cleartext, TagLen}), + crypto:block_encrypt(BlockCrypto, Key, IV, {AAD, Cleartext, TagLen}), Chunk = [Ciphertext,CipherTag], Chunk. @@ -1234,28 +1278,50 @@ decrypt_chunk( iv = IV, key = Key, tag_len = TagLen} = Params, Seq, Chunk) -> %% ChunkLen = byte_size(Chunk), - true = TagLen =< ChunkLen, % Assert - AAD = <>, - CiphertextLen = ChunkLen - TagLen, - <> = Chunk, - block_decrypt( - Params, Seq, crypto_cipher_name(BlockCrypto), - Key, IV, {AAD, Ciphertext, CipherTag}). + if + ChunkLen < TagLen -> + error; + true -> + AAD = <>, + CiphertextLen = ChunkLen - TagLen, + case Chunk of + <> -> + block_decrypt( + Params, Seq, BlockCrypto, Key, IV, + {AAD, Ciphertext, CipherTag}); + _ -> + error + end + end. block_decrypt( - #params{rekey_interval = Seq} = Params, Seq, CipherName, Key, IV, Data) -> + #params{rekey_interval = RekeyInterval} = Params, + Seq, CipherName, Key, IV, Data) -> %% - KeyLen = byte_size(Key), - IVLen = byte_size(IV), case crypto:block_decrypt(CipherName, Key, IV, Data) of - <> -> - Key_1 = hash_key(Params#params.hash_algorithm, Key, KeySalt), - Params#params{iv = IV_1, key = Key_1}; - _ -> - error(decrypt_error) - end; -block_decrypt(_Params, _Seq, CipherName, Key, IV, Data) -> - crypto:block_decrypt(CipherName, Key, IV, Data). + <> -> + KeyLen = byte_size(Key), + IVLen = byte_size(IV), + case Rest of + <> -> + Key_1 = + hash_key( + Params#params.hash_algorithm, Key, KeySalt), + Params#params{iv = IV_1, key = Key_1}; + _ -> + error + end; + Chunk when is_binary(Chunk) -> + case Seq of + RekeyInterval -> + error; + _ -> + Chunk + end; + error -> + error + end. %% ------------------------------------------------------------------------- %% Queue of binaries i.e an iovec queue @@ -1289,6 +1355,13 @@ deq_iovec(GetSize, [Bin|Front], Size, Rear, Acc) -> %% ------------------------------------------------------------------------- +death_row() -> death_row(connection_closed). +%% +death_row(normal) -> death_row(connection_closed); +death_row(Reason) -> receive after 5000 -> exit(Reason) end. + +%% ------------------------------------------------------------------------- + %% Trace point trace(Term) -> Term. @@ -1316,7 +1389,7 @@ dbg() -> dbg:stop(), dbg:tracer(), dbg:p(all, c), - dbg:tpl(?MODULE, cx), + dbg:tpl(?MODULE, trace, cx), dbg:tpl(erlang, dist_ctrl_get_data_notification, cx), dbg:tpl(erlang, dist_ctrl_get_data, cx), dbg:tpl(erlang, dist_ctrl_put_data, cx), -- cgit v1.2.3