aboutsummaryrefslogtreecommitdiffstats
path: root/lib/ssh/test/ssh_trpt_test_lib.erl
diff options
context:
space:
mode:
Diffstat (limited to 'lib/ssh/test/ssh_trpt_test_lib.erl')
-rw-r--r--lib/ssh/test/ssh_trpt_test_lib.erl99
1 files changed, 79 insertions, 20 deletions
diff --git a/lib/ssh/test/ssh_trpt_test_lib.erl b/lib/ssh/test/ssh_trpt_test_lib.erl
index 8de550af15..f2c9892f95 100644
--- a/lib/ssh/test/ssh_trpt_test_lib.erl
+++ b/lib/ssh/test/ssh_trpt_test_lib.erl
@@ -41,15 +41,20 @@
opts = [],
timeout = 5000, % ms
seen_hello = false,
- enc = <<>>,
ssh = #ssh{}, % #ssh{}
alg_neg = {undefined,undefined}, % {own_kexinit, peer_kexinit}
alg, % #alg{}
vars = dict:new(),
reply = [], % Some repy msgs are generated hidden in ssh_transport :[
prints = [],
- return_value
- }).
+ return_value,
+
+ %% Packet retrival and decryption
+ decrypted_data_buffer = <<>>,
+ encrypted_data_buffer = <<>>,
+ aead_data = <<>>,
+ undecrypted_packet_length
+ }).
-define(role(S), ((S#s.ssh)#ssh.role) ).
@@ -475,11 +480,11 @@ recv(S0 = #s{}) ->
%%%================================================================
try_find_crlf(Seen, S0) ->
- case erlang:decode_packet(line,S0#s.enc,[]) of
+ case erlang:decode_packet(line,S0#s.encrypted_data_buffer,[]) of
{more,_} ->
- Line = <<Seen/binary,(S0#s.enc)/binary>>,
+ Line = <<Seen/binary,(S0#s.encrypted_data_buffer)/binary>>,
S0#s{seen_hello = {more,Line},
- enc = <<>>, % didn't find a complete line
+ encrypted_data_buffer = <<>>, % didn't find a complete line
% -> no more characters to test
return_value = {more,Line}
};
@@ -490,13 +495,13 @@ try_find_crlf(Seen, S0) ->
S = opt(print_messages, S0,
fun(X) when X==true;X==detail -> {"Recv info~n~p~n",[Line]} end),
S#s{seen_hello = false,
- enc = Rest,
+ encrypted_data_buffer = Rest,
return_value = {info,Line}};
S1=#s{} ->
S = opt(print_messages, S1,
fun(X) when X==true;X==detail -> {"Recv hello~n~p~n",[Line]} end),
S#s{seen_hello = true,
- enc = Rest,
+ encrypted_data_buffer = Rest,
return_value = {hello,Line}}
end
end.
@@ -511,19 +516,73 @@ handle_hello(Bin, S=#s{ssh=C}) ->
{{Vp,Vs}, server} -> S#s{ssh = C#ssh{c_vsn=Vp, c_version=Vs}}
end.
-receive_binary_msg(S0=#s{ssh=C0=#ssh{decrypt_block_size = BlockSize,
+receive_binary_msg(S0=#s{}) ->
+ case ssh_transport:handle_packet_part(
+ S0#s.decrypted_data_buffer,
+ S0#s.encrypted_data_buffer,
+ S0#s.aead_data,
+ S0#s.undecrypted_packet_length,
+ S0#s.ssh)
+ of
+ {packet_decrypted, DecryptedBytes, EncryptedDataRest, Ssh1} ->
+ S1 = S0#s{ssh = Ssh1#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh1#ssh.recv_sequence)},
+ decrypted_data_buffer = <<>>,
+ undecrypted_packet_length = undefined,
+ aead_data = <<>>,
+ encrypted_data_buffer = EncryptedDataRest},
+ case
+ catch ssh_message:decode(set_prefix_if_trouble(DecryptedBytes,S1))
+ of
+ {'EXIT',_} -> fail(decode_failed,S1);
+
+ Msg ->
+ Ssh2 = case Msg of
+ #ssh_msg_kexinit{} ->
+ ssh_transport:key_init(opposite_role(Ssh1), Ssh1, DecryptedBytes);
+ _ ->
+ Ssh1
+ end,
+ S2 = opt(print_messages, S1,
+ fun(X) when X==true;X==detail -> {"Recv~n~s~n",[format_msg(Msg)]} end),
+ S3 = opt(print_messages, S2,
+ fun(detail) -> {"decrypted bytes ~p~n",[DecryptedBytes]} end),
+ S3#s{ssh = inc_recv_seq_num(Ssh2),
+ return_value = Msg
+ }
+ end;
+
+ {get_more, DecryptedBytes, EncryptedDataRest, AeadData, TotalNeeded, Ssh1} ->
+ %% Here we know that there are not enough bytes in
+ %% EncryptedDataRest to use. We must wait for more.
+ Remaining = case TotalNeeded of
+ undefined -> 8;
+ _ -> TotalNeeded - size(DecryptedBytes) - size(EncryptedDataRest)
+ end,
+ receive_binary_msg(
+ receive_wait(Remaining,
+ S0#s{encrypted_data_buffer = EncryptedDataRest,
+ decrypted_data_buffer = DecryptedBytes,
+ undecrypted_packet_length = TotalNeeded,
+ aead_data = AeadData,
+ ssh = Ssh1}
+ ))
+ end.
+
+
+
+old_receive_binary_msg(S0=#s{ssh=C0=#ssh{decrypt_block_size = BlockSize,
recv_mac_size = MacSize
}
}) ->
- case size(S0#s.enc) >= max(8,BlockSize) of
+ case size(S0#s.encrypted_data_buffer) >= max(8,BlockSize) of
false ->
%% Need more bytes to decode the packet_length field
- Remaining = max(8,BlockSize) - size(S0#s.enc),
+ Remaining = max(8,BlockSize) - size(S0#s.encrypted_data_buffer),
receive_binary_msg( receive_wait(Remaining, S0) );
true ->
%% Has enough bytes to decode the packet_length field
{_, <<?UINT32(PacketLen), _/binary>>, _} =
- ssh_transport:decrypt_blocks(S0#s.enc, BlockSize, C0), % FIXME: BlockSize should be at least 4
+ ssh_transport:decrypt_blocks(S0#s.encrypted_data_buffer, BlockSize, C0), % FIXME: BlockSize should be at least 4
%% FIXME: Check that ((4+PacketLen) rem BlockSize) == 0 ?
@@ -534,19 +593,19 @@ receive_binary_msg(S0=#s{ssh=C0=#ssh{decrypt_block_size = BlockSize,
((4+PacketLen) rem BlockSize) =/= 0 ->
fail(bad_packet_length_modulo, S0); % FIXME: disconnect
- size(S0#s.enc) >= (4 + PacketLen + MacSize) ->
+ size(S0#s.encrypted_data_buffer) >= (4 + PacketLen + MacSize) ->
%% has the whole packet
S0;
true ->
%% need more bytes to get have the whole packet
- Remaining = (4 + PacketLen + MacSize) - size(S0#s.enc),
+ Remaining = (4 + PacketLen + MacSize) - size(S0#s.encrypted_data_buffer),
receive_wait(Remaining, S0)
end,
%% Decrypt all, including the packet_length part (re-use the initial #ssh{})
{C1, SshPacket = <<?UINT32(_),?BYTE(PadLen),Tail/binary>>, EncRest} =
- ssh_transport:decrypt_blocks(S1#s.enc, PacketLen+4, C0),
+ ssh_transport:decrypt_blocks(S1#s.encrypted_data_buffer, PacketLen+4, C0),
PayloadLen = PacketLen - 1 - PadLen,
<<CompressedPayload:PayloadLen/binary, _Padding:PadLen/binary>> = Tail,
@@ -573,7 +632,7 @@ receive_binary_msg(S0=#s{ssh=C0=#ssh{decrypt_block_size = BlockSize,
S3 = opt(print_messages, S2,
fun(detail) -> {"decrypted bytes ~p~n",[SshPacket]} end),
S3#s{ssh = inc_recv_seq_num(C3),
- enc = Rest,
+ encrypted_data_buffer = Rest,
return_value = Msg
}
end
@@ -602,7 +661,7 @@ receive_poll(S=#s{socket=Sock}) ->
inet:setopts(Sock, [{active,once}]),
receive
{tcp,Sock,Data} ->
- receive_poll( S#s{enc = <<(S#s.enc)/binary,Data/binary>>} );
+ receive_poll( S#s{encrypted_data_buffer = <<(S#s.encrypted_data_buffer)/binary,Data/binary>>} );
{tcp_closed,Sock} ->
throw({tcp,tcp_closed});
{tcp_error, Sock, Reason} ->
@@ -616,7 +675,7 @@ receive_wait(S=#s{socket=Sock,
inet:setopts(Sock, [{active,once}]),
receive
{tcp,Sock,Data} ->
- S#s{enc = <<(S#s.enc)/binary,Data/binary>>};
+ S#s{encrypted_data_buffer = <<(S#s.encrypted_data_buffer)/binary,Data/binary>>};
{tcp_closed,Sock} ->
throw({tcp,tcp_closed});
{tcp_error, Sock, Reason} ->
@@ -627,11 +686,11 @@ receive_wait(S=#s{socket=Sock,
receive_wait(N, S=#s{socket=Sock,
timeout=Timeout,
- enc=Enc0}) when N>0 ->
+ encrypted_data_buffer=Enc0}) when N>0 ->
inet:setopts(Sock, [{active,once}]),
receive
{tcp,Sock,Data} ->
- receive_wait(N-size(Data), S#s{enc = <<Enc0/binary,Data/binary>>});
+ receive_wait(N-size(Data), S#s{encrypted_data_buffer = <<Enc0/binary,Data/binary>>});
{tcp_closed,Sock} ->
throw({tcp,tcp_closed});
{tcp_error, Sock, Reason} ->