diff options
Diffstat (limited to 'lib/ssh/test/ssh_trpt_test_lib.erl')
-rw-r--r-- | lib/ssh/test/ssh_trpt_test_lib.erl | 193 |
1 files changed, 98 insertions, 95 deletions
diff --git a/lib/ssh/test/ssh_trpt_test_lib.erl b/lib/ssh/test/ssh_trpt_test_lib.erl index bc86000d81..3f4df2c986 100644 --- a/lib/ssh/test/ssh_trpt_test_lib.erl +++ b/lib/ssh/test/ssh_trpt_test_lib.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2004-2016. All Rights Reserved. +%% Copyright Ericsson AB 2004-2017. All Rights Reserved. %% %% The contents of this file are subject to the Erlang Public License, %% Version 1.1, (the "License"); you may not use this file except in @@ -41,15 +41,20 @@ opts = [], timeout = 5000, % ms seen_hello = false, - enc = <<>>, ssh = #ssh{}, % #ssh{} alg_neg = {undefined,undefined}, % {own_kexinit, peer_kexinit} alg, % #alg{} vars = dict:new(), reply = [], % Some repy msgs are generated hidden in ssh_transport :[ prints = [], - return_value - }). + return_value, + + %% Packet retrival and decryption + decrypted_data_buffer = <<>>, + encrypted_data_buffer = <<>>, + aead_data = <<>>, + undecrypted_packet_length + }). -define(role(S), ((S#s.ssh)#ssh.role) ). @@ -85,15 +90,18 @@ exec(Op, S0=#s{}) -> throw:Term -> report_trace(throw, Term, S1), - throw(Term); + throw({Term,Op}); error:Error -> report_trace(error, Error, S1), - error(Error); + error({Error,Op}); exit:Exit -> report_trace(exit, Exit, S1), - exit(Exit) + exit({Exit,Op}); + Cls:Err -> + ct:pal("Class=~p, Error=~p", [Cls,Err]), + error({"fooooooO",Op}) end; exec(Op, {ok,S=#s{}}) -> exec(Op, S); exec(_, Error) -> Error. @@ -111,20 +119,20 @@ op({accept,Opts}, S) when ?role(S) == server -> {ok,Socket} = gen_tcp:accept(S#s.listen_socket, S#s.timeout), {Host,_Port} = ok(inet:sockname(Socket)), S#s{socket = Socket, - ssh = init_ssh(server,Socket,[{host,host(Host)}|Opts]), + ssh = init_ssh(server, Socket, host(Host), Opts), return_value = ok}; %%%---- Client ops op({connect,Host,Port,Opts}, S) when ?role(S) == undefined -> Socket = ok(gen_tcp:connect(host(Host), Port, mangle_opts([]))), S#s{socket = Socket, - ssh = init_ssh(client, Socket, [{host,host(Host)}|Opts]), + ssh = init_ssh(client, Socket, host(Host), Opts), return_value = ok}; %%%---- ops for both client and server op(close_socket, S) -> - catch tcp_gen:close(S#s.socket), - catch tcp_gen:close(S#s.listen_socket), + catch gen_tcp:close(S#s.socket), + catch gen_tcp:close(S#s.listen_socket), S#s{socket = undefined, listen_socket = undefined, return_value = ok}; @@ -293,12 +301,14 @@ instantiate(X, _S) -> %%%================================================================ %%% -init_ssh(Role, Socket, Options0) -> - Options = [{user_interaction, false}, - {vsn, {2,0}}, - {id_string, "ErlangTestLib"} - | Options0], - ssh_connection_handler:init_ssh_record(Role, Socket, Options). +init_ssh(Role, Socket, Host, UserOptions0) -> + UserOptions = [{user_interaction, false}, + {vsn, {2,0}}, + {id_string, "ErlangTestLib"} + | UserOptions0], + Opts = ?PUT_INTERNAL_OPT({host,Host}, + ssh_options:handle_options(Role, UserOptions)), + ssh_connection_handler:init_ssh_record(Role, Socket, Opts). mangle_opts(Options) -> SysOpts = [{reuseaddr, true}, @@ -309,8 +319,7 @@ mangle_opts(Options) -> lists:keydelete(K,1,Opts) end, Options, SysOpts). -host({0,0,0,0}) -> "localhost"; -host(H) -> H. +host(H) -> ssh_test_lib:ntoa(ssh_test_lib:mangle_connect_address(H)). %%%---------------------------------------------------------------- send(S=#s{ssh=C}, hello) -> @@ -393,6 +402,12 @@ send(S0, {special,Msg,PacketFun}) when is_tuple(Msg), send_bytes(Packet, S#s{ssh = C, %%inc_send_seq_num(C), return_value = Msg}); +send(S0, #ssh_msg_newkeys{} = Msg) -> + S = opt(print_messages, S0, + fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end), + {ok, Packet, C} = ssh_transport:new_keys_message(S#s.ssh), + send_bytes(Packet, S#s{ssh = C}); + send(S0, Msg) when is_tuple(Msg) -> S = opt(print_messages, S0, fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end), @@ -451,7 +466,10 @@ recv(S0 = #s{}) -> }; #ssh_msg_kexdh_reply{} -> {ok, _NewKeys, C} = ssh_transport:handle_kexdh_reply(PeerMsg, S#s.ssh), - S#s{ssh=C#ssh{send_sequence=S#s.ssh#ssh.send_sequence}}; % Back the number + S#s{ssh = (S#s.ssh)#ssh{shared_secret = C#ssh.shared_secret, + exchanged_hash = C#ssh.exchanged_hash, + session_id = C#ssh.session_id}}; + %%%S#s{ssh=C#ssh{send_sequence=S#s.ssh#ssh.send_sequence}}; % Back the number #ssh_msg_newkeys{} -> {ok, C} = ssh_transport:handle_new_keys(PeerMsg, S#s.ssh), S#s{ssh=C}; @@ -462,11 +480,11 @@ recv(S0 = #s{}) -> %%%================================================================ try_find_crlf(Seen, S0) -> - case erlang:decode_packet(line,S0#s.enc,[]) of + case erlang:decode_packet(line,S0#s.encrypted_data_buffer,[]) of {more,_} -> - Line = <<Seen/binary,(S0#s.enc)/binary>>, + Line = <<Seen/binary,(S0#s.encrypted_data_buffer)/binary>>, S0#s{seen_hello = {more,Line}, - enc = <<>>, % didn't find a complete line + encrypted_data_buffer = <<>>, % didn't find a complete line % -> no more characters to test return_value = {more,Line} }; @@ -477,13 +495,13 @@ try_find_crlf(Seen, S0) -> S = opt(print_messages, S0, fun(X) when X==true;X==detail -> {"Recv info~n~p~n",[Line]} end), S#s{seen_hello = false, - enc = Rest, + encrypted_data_buffer = Rest, return_value = {info,Line}}; S1=#s{} -> S = opt(print_messages, S1, fun(X) when X==true;X==detail -> {"Recv hello~n~p~n",[Line]} end), S#s{seen_hello = true, - enc = Rest, + encrypted_data_buffer = Rest, return_value = {hello,Line}} end end. @@ -498,73 +516,58 @@ handle_hello(Bin, S=#s{ssh=C}) -> {{Vp,Vs}, server} -> S#s{ssh = C#ssh{c_vsn=Vp, c_version=Vs}} end. -receive_binary_msg(S0=#s{ssh=C0=#ssh{decrypt_block_size = BlockSize, - recv_mac_size = MacSize - } - }) -> - case size(S0#s.enc) >= max(8,BlockSize) of - false -> - %% Need more bytes to decode the packet_length field - Remaining = max(8,BlockSize) - size(S0#s.enc), - receive_binary_msg( receive_wait(Remaining, S0) ); - true -> - %% Has enough bytes to decode the packet_length field - {_, <<?UINT32(PacketLen), _/binary>>, _} = - ssh_transport:decrypt_blocks(S0#s.enc, BlockSize, C0), % FIXME: BlockSize should be at least 4 - - %% FIXME: Check that ((4+PacketLen) rem BlockSize) == 0 ? - - S1 = if - PacketLen > ?SSH_MAX_PACKET_SIZE -> - fail({too_large_message,PacketLen},S0); % FIXME: disconnect - - ((4+PacketLen) rem BlockSize) =/= 0 -> - fail(bad_packet_length_modulo, S0); % FIXME: disconnect - - size(S0#s.enc) >= (4 + PacketLen + MacSize) -> - %% has the whole packet - S0; - - true -> - %% need more bytes to get have the whole packet - Remaining = (4 + PacketLen + MacSize) - size(S0#s.enc), - receive_wait(Remaining, S0) - end, - - %% Decrypt all, including the packet_length part (re-use the initial #ssh{}) - {C1, SshPacket = <<?UINT32(_),?BYTE(PadLen),Tail/binary>>, EncRest} = - ssh_transport:decrypt_blocks(S1#s.enc, PacketLen+4, C0), - - PayloadLen = PacketLen - 1 - PadLen, - <<CompressedPayload:PayloadLen/binary, _Padding:PadLen/binary>> = Tail, - - {C2, Payload} = ssh_transport:decompress(C1, CompressedPayload), - - <<Mac:MacSize/binary, Rest/binary>> = EncRest, - - case {ssh_transport:is_valid_mac(Mac, SshPacket, C2), - catch ssh_message:decode(set_prefix_if_trouble(Payload,S1))} - of - {false, _} -> fail(bad_mac,S1); - {_, {'EXIT',_}} -> fail(decode_failed,S1); - - {true, Msg} -> - C3 = case Msg of - #ssh_msg_kexinit{} -> - ssh_transport:key_init(opposite_role(C2), C2, Payload); - _ -> - C2 +receive_binary_msg(S0=#s{}) -> + case ssh_transport:handle_packet_part( + S0#s.decrypted_data_buffer, + S0#s.encrypted_data_buffer, + S0#s.aead_data, + S0#s.undecrypted_packet_length, + S0#s.ssh) + of + {packet_decrypted, DecryptedBytes, EncryptedDataRest, Ssh1} -> + S1 = S0#s{ssh = Ssh1#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh1#ssh.recv_sequence)}, + decrypted_data_buffer = <<>>, + undecrypted_packet_length = undefined, + aead_data = <<>>, + encrypted_data_buffer = EncryptedDataRest}, + case + catch ssh_message:decode(set_prefix_if_trouble(DecryptedBytes,S1)) + of + {'EXIT',_} -> fail(decode_failed,S1); + + Msg -> + Ssh2 = case Msg of + #ssh_msg_kexinit{} -> + ssh_transport:key_init(opposite_role(Ssh1), Ssh1, DecryptedBytes); + _ -> + Ssh1 end, - S2 = opt(print_messages, S1, - fun(X) when X==true;X==detail -> {"Recv~n~s~n",[format_msg(Msg)]} end), - S3 = opt(print_messages, S2, - fun(detail) -> {"decrypted bytes ~p~n",[SshPacket]} end), - S3#s{ssh = inc_recv_seq_num(C3), - enc = Rest, - return_value = Msg - } - end - end. + S2 = opt(print_messages, S1, + fun(X) when X==true;X==detail -> {"Recv~n~s~n",[format_msg(Msg)]} end), + S3 = opt(print_messages, S2, + fun(detail) -> {"decrypted bytes ~p~n",[DecryptedBytes]} end), + S3#s{ssh = inc_recv_seq_num(Ssh2), + return_value = Msg + } + end; + + {get_more, DecryptedBytes, EncryptedDataRest, AeadData, TotalNeeded, Ssh1} -> + %% Here we know that there are not enough bytes in + %% EncryptedDataRest to use. We must wait for more. + Remaining = case TotalNeeded of + undefined -> 8; + _ -> TotalNeeded - size(DecryptedBytes) - size(EncryptedDataRest) + end, + receive_binary_msg( + receive_wait(Remaining, + S0#s{encrypted_data_buffer = EncryptedDataRest, + decrypted_data_buffer = DecryptedBytes, + undecrypted_packet_length = TotalNeeded, + aead_data = AeadData, + ssh = Ssh1} + )) + end. + set_prefix_if_trouble(Msg = <<?BYTE(Op),_/binary>>, #s{alg=#alg{kex=Kex}}) @@ -589,7 +592,7 @@ receive_poll(S=#s{socket=Sock}) -> inet:setopts(Sock, [{active,once}]), receive {tcp,Sock,Data} -> - receive_poll( S#s{enc = <<(S#s.enc)/binary,Data/binary>>} ); + receive_poll( S#s{encrypted_data_buffer = <<(S#s.encrypted_data_buffer)/binary,Data/binary>>} ); {tcp_closed,Sock} -> throw({tcp,tcp_closed}); {tcp_error, Sock, Reason} -> @@ -603,7 +606,7 @@ receive_wait(S=#s{socket=Sock, inet:setopts(Sock, [{active,once}]), receive {tcp,Sock,Data} -> - S#s{enc = <<(S#s.enc)/binary,Data/binary>>}; + S#s{encrypted_data_buffer = <<(S#s.encrypted_data_buffer)/binary,Data/binary>>}; {tcp_closed,Sock} -> throw({tcp,tcp_closed}); {tcp_error, Sock, Reason} -> @@ -614,11 +617,11 @@ receive_wait(S=#s{socket=Sock, receive_wait(N, S=#s{socket=Sock, timeout=Timeout, - enc=Enc0}) when N>0 -> + encrypted_data_buffer=Enc0}) when N>0 -> inet:setopts(Sock, [{active,once}]), receive {tcp,Sock,Data} -> - receive_wait(N-size(Data), S#s{enc = <<Enc0/binary,Data/binary>>}); + receive_wait(N-size(Data), S#s{encrypted_data_buffer = <<Enc0/binary,Data/binary>>}); {tcp_closed,Sock} -> throw({tcp,tcp_closed}); {tcp_error, Sock, Reason} -> |