aboutsummaryrefslogtreecommitdiffstats
path: root/lib/ssh/test/ssh_trpt_test_lib.erl
diff options
context:
space:
mode:
authorHans Nilsson <[email protected]>2015-06-15 22:09:53 +0200
committerHans Nilsson <[email protected]>2015-07-02 12:01:48 +0200
commit941ddfbeab3357177ce6eac709456fd881ac2429 (patch)
treebd5facee29c211e3bd2f1730f1467c6b7daa6326 /lib/ssh/test/ssh_trpt_test_lib.erl
parent2e36a5f144ff1c996f02d5a55b07f3c840869ebf (diff)
downloadotp-941ddfbeab3357177ce6eac709456fd881ac2429.tar.gz
otp-941ddfbeab3357177ce6eac709456fd881ac2429.tar.bz2
otp-941ddfbeab3357177ce6eac709456fd881ac2429.zip
ssh: Initial ssh_tprt_test_lib.erl and ssh_protocol_SUITE
This test lib is intended for deeper testing of the SSH application. It makes it possible to do exact steps in the message exchange to test "corner cases"
Diffstat (limited to 'lib/ssh/test/ssh_trpt_test_lib.erl')
-rw-r--r--lib/ssh/test/ssh_trpt_test_lib.erl691
1 files changed, 691 insertions, 0 deletions
diff --git a/lib/ssh/test/ssh_trpt_test_lib.erl b/lib/ssh/test/ssh_trpt_test_lib.erl
new file mode 100644
index 0000000000..8623020a31
--- /dev/null
+++ b/lib/ssh/test/ssh_trpt_test_lib.erl
@@ -0,0 +1,691 @@
+%%
+%% %CopyrightBegin%
+%%
+%% Copyright Ericsson AB 2004-2015. All Rights Reserved.
+%%
+%% The contents of this file are subject to the Erlang Public License,
+%% Version 1.1, (the "License"); you may not use this file except in
+%% compliance with the License. You should have received a copy of the
+%% Erlang Public License along with this software. If not, it can be
+%% retrieved online at http://www.erlang.org/.
+%%
+%% Software distributed under the License is distributed on an "AS IS"
+%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
+%% the License for the specific language governing rights and limitations
+%% under the License.
+%%
+%% %CopyrightEnd%
+%%
+%%
+
+-module(ssh_trpt_test_lib).
+
+%%-compile(export_all).
+
+-export([exec/1, exec/2,
+ format_msg/1,
+ server_host_port/1
+ ]
+ ).
+
+-include_lib("common_test/include/ct.hrl").
+-include_lib("ssh/src/ssh.hrl"). % ?UINT32, ?BYTE, #ssh{} ...
+-include_lib("ssh/src/ssh_transport.hrl").
+-include_lib("ssh/src/ssh_auth.hrl").
+
+%%%----------------------------------------------------------------
+-record(s, {
+ socket,
+ listen_socket,
+ opts = [],
+ timeout = 5000, % ms
+ seen_hello = false,
+ enc = <<>>,
+ ssh = #ssh{}, % #ssh{}
+ own_kexinit,
+ peer_kexinit,
+ vars = dict:new(),
+ reply = [], % Some repy msgs are generated hidden in ssh_transport :[
+ prints = [],
+ return_value
+ }).
+
+-define(role(S), ((S#s.ssh)#ssh.role) ).
+
+
+server_host_port(S=#s{}) ->
+ {Host,Port} = ok(inet:sockname(S#s.listen_socket)),
+ {host(Host), Port}.
+
+
+%%% Options: {print_messages, false} true|detail
+%%% {print_seqnums,false} true
+%%% {print_ops,false} true
+
+exec(L) -> exec(L, #s{}).
+
+exec(L, S) when is_list(L) -> lists:foldl(fun exec/2, S, L);
+
+exec(Op, S0=#s{}) ->
+ S1 = init_op_traces(Op, S0),
+ try seqnum_trace(
+ op(Op, S1))
+ of
+ S = #s{} ->
+ print_traces(S),
+ {ok,S}
+ catch
+ {fail,Reason,Se} ->
+ report_trace('', Reason, Se),
+ {error,{Op,Reason,Se}};
+
+ throw:Term ->
+ report_trace(throw, Term, S1),
+ throw(Term);
+
+ error:Error ->
+ report_trace(error, Error, S1),
+ error(Error);
+
+ exit:Exit ->
+ report_trace(exit, Exit, S1),
+ exit(Exit)
+ end;
+exec(Op, {ok,S=#s{}}) -> exec(Op, S);
+exec(_, Error) -> Error.
+
+
+%%%---- Server ops
+op(listen, S) when ?role(S) == undefined -> op({listen,0}, S);
+
+op({listen,Port}, S) when ?role(S) == undefined ->
+ S#s{listen_socket = ok(gen_tcp:listen(Port, mangle_opts([]))),
+ ssh = (S#s.ssh)#ssh{role=server}
+ };
+
+op({accept,Opts}, S) when ?role(S) == server ->
+ {ok,Socket} = gen_tcp:accept(S#s.listen_socket, S#s.timeout),
+ {Host,_Port} = ok(inet:sockname(Socket)),
+ S#s{socket = Socket,
+ ssh = init_ssh(server,Socket,[{host,host(Host)}|Opts]),
+ return_value = ok};
+
+%%%---- Client ops
+op({connect,Host,Port,Opts}, S) when ?role(S) == undefined ->
+ Socket = ok(gen_tcp:connect(host(Host), Port, mangle_opts([]))),
+ S#s{socket = Socket,
+ ssh = init_ssh(client, Socket, [{host,host(Host)}|Opts]),
+ return_value = ok};
+
+%%%---- ops for both client and server
+op(close_socket, S) ->
+ catch tcp_gen:close(S#s.socket),
+ catch tcp_gen:close(S#s.listen_socket),
+ S#s{socket = undefined,
+ listen_socket = undefined,
+ return_value = ok};
+
+op({set_options,Opts}, S) ->
+ S#s{opts = Opts};
+
+op({send,X}, S) ->
+ send(S, instantiate(X,S));
+
+op(receive_hello, S0) when S0#s.seen_hello =/= true ->
+ case recv(S0) of
+ S1=#s{return_value={hello,_}} -> S1;
+ S1=#s{} -> op(receive_hello, receive_wait(S1))
+ end;
+
+op(receive_msg, S) when S#s.seen_hello == true ->
+ try recv(S)
+ catch
+ {tcp,Exc} -> S#s{return_value=Exc}
+ end;
+
+
+op({expect,timeout,E}, S0) ->
+ try op(E, S0)
+ of
+ S=#s{} -> fail({expected,timeout,S#s.return_value}, S)
+ catch
+ {receive_timeout,_} -> S0#s{return_value=timeout}
+ end;
+
+op({match,M,E}, S0) ->
+ {Val,S2} = op_val(E, S0),
+ case match(M, Val, S2) of
+ {true,S3} ->
+ opt(print_ops,S3,
+ fun(true) ->
+ case dict:fold(
+ fun(K,V,Acc) ->
+ case dict:find(K,S0#s.vars) of
+ error -> [{K,V}|Acc];
+ _ -> Acc
+ end
+ end, [], S3#s.vars)
+ of
+ [] -> {"Matches! No new bindings.",[]};
+ New ->
+ Width = lists:max([length(atom_to_list(K)) || {K,_} <- New]),
+ {lists:flatten(
+ ["Matches! New bindings:~n" |
+ [io_lib:format(" ~*s = ~p~n",[Width,K,V]) || {K,V}<-New]]),
+ []}
+ end
+ end);
+ false ->
+ fail({expected,M,Val},
+ opt(print_ops,S2,fun(true) -> {"nomatch!!~n",[]} end)
+ )
+ end;
+
+op({print,E}, S0) ->
+ {Val,S} = op_val(E, S0),
+ io:format("Result of ~p ~p =~n~s~n",[?role(S0),E,format_msg(Val)]),
+ S;
+
+op(print_state, S) ->
+ io:format("State(~p)=~n~s~n",[?role(S), format_msg(S)]),
+ S;
+
+op('$$', S) ->
+ %% For matching etc
+ S.
+
+
+op_val(E, S0) ->
+ case catch op(E, S0) of
+ {'EXIT',{function_clause,[{ssh_trpt_test_lib,op,[E,S0],_}|_]}} ->
+ {instantiate(E,S0), S0};
+ S=#s{} ->
+ {S#s.return_value, S}
+ end.
+
+
+fail(Reason, S) ->
+ throw({fail, Reason, S}).
+
+%%%----------------------------------------------------------------
+%% No optimizations :)
+
+match('$$', V, S) ->
+ match(S#s.return_value, V, S);
+
+match('_', _, S) ->
+ {true, S};
+
+match(P, V, S) when is_atom(P) ->
+ case atom_to_list(P) of
+ "$"++_ ->
+ %% Variable
+ case dict:find(P,S#s.vars) of
+ {ok,Val} -> match(Val, V, S);
+ error -> {true,S#s{vars = dict:store(P,V,S#s.vars)}}
+ end;
+ _ when P==V ->
+ {true,S};
+ _ ->
+ false
+ end;
+
+match(P, V, S) when P==V ->
+ {true, S};
+
+match(P, V, S) when is_tuple(P),
+ is_tuple(V) ->
+ match(tuple_to_list(P), tuple_to_list(V), S);
+
+match([Hp|Tp], [Hv|Tv], S0) ->
+ case match(Hp, Hv, S0) of
+ {true,S} -> match(Tp, Tv, S);
+ false -> false
+ end;
+
+match(_, _, _) ->
+ false.
+
+
+
+instantiate('$$', S) ->
+ S#s.return_value; % FIXME: What if $$ or $... in return_value?
+
+instantiate(A, S) when is_atom(A) ->
+ case atom_to_list(A) of
+ "$"++_ ->
+ %% Variable
+ case dict:find(A,S#s.vars) of
+ {ok,Val} -> Val; % FIXME: What if $$ or $... in Val?
+ error -> throw({unbound,A})
+ end;
+ _ ->
+ A
+ end;
+
+instantiate(T, S) when is_tuple(T) ->
+ list_to_tuple( instantiate(tuple_to_list(T),S) );
+
+instantiate([H|T], S) ->
+ [instantiate(H,S) | instantiate(T,S)];
+
+instantiate(X, _S) ->
+ X.
+
+%%%================================================================
+%%%
+init_ssh(Role, Socket, Options0) ->
+ Options = [{user_interaction,false}
+ | Options0],
+ ssh_connection_handler:init_ssh(Role,
+ {2,0},
+ lists:concat(["SSH-2.0-ErlangTestLib ",Role]),
+ Options, Socket).
+
+mangle_opts(Options) ->
+ SysOpts = [{reuseaddr, true},
+ {active, false},
+ {mode, binary}
+ ],
+ SysOpts ++ lists:foldl(fun({K,_},Opts) ->
+ lists:keydelete(K,1,Opts)
+ end, Options, SysOpts).
+
+host({0,0,0,0}) -> "localhost";
+host(H) -> H.
+
+%%%----------------------------------------------------------------
+send(S=#s{ssh=C}, hello) ->
+ Hello = case ?role(S) of
+ client -> C#ssh.c_version;
+ server -> C#ssh.s_version
+ end ++ "\r\n",
+ send(S, list_to_binary(Hello));
+
+send(S0, ssh_msg_kexinit) ->
+ {Msg, Bytes, C0} = ssh_transport:key_exchange_init_msg(S0#s.ssh),
+ S1 = opt(print_messages, S0,
+ fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
+ S = case ?role(S1) of
+ server when is_record(S1#s.peer_kexinit, ssh_msg_kexinit) ->
+ {ok, C} =
+ ssh_transport:handle_kexinit_msg(S1#s.peer_kexinit, Msg, C0),
+ S1#s{peer_kexinit = used,
+ own_kexinit = used,
+ ssh = C};
+ _ ->
+ S1#s{ssh = C0,
+ own_kexinit = Msg}
+ end,
+ send_bytes(Bytes, S#s{return_value = Msg});
+
+send(S0, ssh_msg_kexdh_init) when ?role(S0) == client,
+ is_record(S0#s.peer_kexinit, ssh_msg_kexinit),
+ is_record(S0#s.own_kexinit, ssh_msg_kexinit) ->
+ {ok, NextKexMsgBin, C} =
+ ssh_transport:handle_kexinit_msg(S0#s.peer_kexinit, S0#s.own_kexinit, S0#s.ssh),
+
+ S = opt(print_messages, S0,
+ fun(X) when X==true;X==detail ->
+ #ssh{keyex_key = {{_Private, Public}, {_G, _P}}} = C,
+ Msg = #ssh_msg_kexdh_init{e = Public},
+ {"Send (reconstructed)~n~s~n",[format_msg(Msg)]}
+ end),
+
+ send_bytes(NextKexMsgBin, S#s{ssh = C,
+ peer_kexinit = used,
+ own_kexinit = used});
+
+send(S0, ssh_msg_kexdh_reply) ->
+ Bytes = proplists:get_value(ssh_msg_kexdh_reply, S0#s.reply),
+ S = opt(print_messages, S0,
+ fun(X) when X==true;X==detail ->
+ {{_Private, Public}, _} = (S0#s.ssh)#ssh.keyex_key,
+ Msg = #ssh_msg_kexdh_reply{public_host_key = 'Key',
+ f = Public,
+ h_sig = 'H_SIG'
+ },
+ {"Send (reconstructed)~n~s~n",[format_msg(Msg)]}
+ end),
+ send_bytes(Bytes, S#s{return_value = Bytes});
+
+send(S0, Line) when is_binary(Line) ->
+ S = opt(print_messages, S0,
+ fun(X) when X==true;X==detail -> {"Send line~n~p~n",[Line]} end),
+ send_bytes(Line, S#s{return_value = Line});
+
+%%% Msg = #ssh_msg_*{}
+send(S0, Msg) when is_tuple(Msg) ->
+ S = opt(print_messages, S0,
+ fun(X) when X==true;X==detail -> {"Send~n~s~n",[format_msg(Msg)]} end),
+ {Packet, C} = ssh_transport:ssh_packet(Msg, S#s.ssh),
+ send_bytes(Packet, S#s{ssh = C, %%inc_send_seq_num(C),
+ return_value = Msg}).
+
+send_bytes(B, S0) ->
+ S = opt(print_messages, S0, fun(detail) -> {"Send bytes~n~p~n",[B]} end),
+ ok(gen_tcp:send(S#s.socket, B)),
+ S.
+
+%%%----------------------------------------------------------------
+recv(S0 = #s{}) ->
+ S1 = receive_poll(S0),
+ case S1#s.seen_hello of
+ {more,Seen} ->
+ %% Has received parts of a line. Has not seen a complete hello.
+ try_find_crlf(Seen, S1);
+ false ->
+ %% Must see hello before binary messages
+ try_find_crlf(<<>>, S1);
+ true ->
+ %% Has seen hello, therefore no more crlf-messages are alowed.
+ S = receive_binary_msg(S1),
+ case M=S#s.return_value of
+ #ssh_msg_kexinit{} when ?role(S) == server,
+ S#s.own_kexinit =/= undefined ->
+ {ok, C} =
+ ssh_transport:handle_kexinit_msg(M, S#s.own_kexinit, S#s.ssh),
+ S#s{peer_kexinit = used,
+ own_kexinit = used,
+ ssh = C};
+ #ssh_msg_kexinit{} ->
+ S#s{peer_kexinit = M};
+ #ssh_msg_kexdh_init{} -> % Always the server
+ {ok, Reply, C} = ssh_transport:handle_kexdh_init(M, S#s.ssh),
+ S#s{ssh = C,
+ reply = [{ssh_msg_kexdh_reply,Reply} | S#s.reply]
+ };
+ #ssh_msg_kexdh_reply{} ->
+ {ok, _NewKeys, C} = ssh_transport:handle_kexdh_reply(M, S#s.ssh),
+ S#s{ssh=C#ssh{send_sequence=S#s.ssh#ssh.send_sequence}}; % Back the number
+ #ssh_msg_newkeys{} ->
+ {ok, C} = ssh_transport:handle_new_keys(M, S#s.ssh),
+ S#s{ssh=C};
+ _ ->
+ S
+ end
+ end.
+
+%%%================================================================
+try_find_crlf(Seen, S0) ->
+ case erlang:decode_packet(line,S0#s.enc,[]) of
+ {more,_} ->
+ Line = <<Seen/binary,(S0#s.enc)/binary>>,
+ S0#s{seen_hello = {more,Line},
+ enc = <<>>, % didn't find a complete line
+ % -> no more characters to test
+ return_value = {more,Line}
+ };
+ {ok,Used,Rest} ->
+ Line = <<Seen/binary,Used/binary>>,
+ case handle_hello(Line, S0) of
+ false ->
+ S = opt(print_messages, S0,
+ fun(X) when X==true;X==detail -> {"Recv info~n~p~n",[Line]} end),
+ S#s{seen_hello = false,
+ enc = Rest,
+ return_value = {info,Line}};
+ S1=#s{} ->
+ S = opt(print_messages, S1,
+ fun(X) when X==true;X==detail -> {"Recv hello~n~p~n",[Line]} end),
+ S#s{seen_hello = true,
+ enc = Rest,
+ return_value = {hello,Line}}
+ end
+ end.
+
+
+handle_hello(Bin, S=#s{ssh=C}) ->
+ case {ssh_transport:handle_hello_version(binary_to_list(Bin)),
+ ?role(S)}
+ of
+ {{undefined,_}, _} -> false;
+ {{Vp,Vs}, client} -> S#s{ssh = C#ssh{s_vsn=Vp, s_version=Vs}};
+ {{Vp,Vs}, server} -> S#s{ssh = C#ssh{c_vsn=Vp, c_version=Vs}}
+ end.
+
+receive_binary_msg(S0=#s{ssh=C0=#ssh{decrypt_block_size = BlockSize,
+ recv_mac_size = MacSize
+ }
+ }) ->
+ case size(S0#s.enc) >= max(8,BlockSize) of
+ false ->
+ %% Need more bytes to decode the packet_length field
+ Remaining = max(8,BlockSize) - size(S0#s.enc),
+ receive_binary_msg( receive_wait(Remaining, S0) );
+ true ->
+ %% Has enough bytes to decode the packet_length field
+ {_, <<?UINT32(PacketLen), _/binary>>, _} =
+ ssh_transport:decrypt_blocks(S0#s.enc, BlockSize, C0), % FIXME: BlockSize should be at least 4
+
+ %% FIXME: Check that ((4+PacketLen) rem BlockSize) == 0 ?
+
+ S1 = if
+ PacketLen > ?SSH_MAX_PACKET_SIZE ->
+ fail({too_large_message,PacketLen},S0); % FIXME: disconnect
+
+ ((4+PacketLen) rem BlockSize) =/= 0 ->
+ fail(bad_packet_length_modulo, S0); % FIXME: disconnect
+
+ size(S0#s.enc) >= (4 + PacketLen + MacSize) ->
+ %% has the whole packet
+ S0;
+
+ true ->
+ %% need more bytes to get have the whole packet
+ Remaining = (4 + PacketLen + MacSize) - size(S0#s.enc),
+ receive_wait(Remaining, S0)
+ end,
+
+ %% Decrypt all, including the packet_length part (re-use the initial #ssh{})
+ {C1, SshPacket = <<?UINT32(_),?BYTE(PadLen),Tail/binary>>, EncRest} =
+ ssh_transport:decrypt_blocks(S1#s.enc, PacketLen+4, C0),
+
+ PayloadLen = PacketLen - 1 - PadLen,
+ <<CompressedPayload:PayloadLen/binary, _Padding:PadLen/binary>> = Tail,
+
+ {C2, Payload} = ssh_transport:decompress(C1, CompressedPayload),
+
+ <<Mac:MacSize/binary, Rest/binary>> = EncRest,
+
+ case {ssh_transport:is_valid_mac(Mac, SshPacket, C2),
+ catch ssh_message:decode(Payload)}
+ of
+ {false, _} -> fail(bad_mac,S1);
+ {_, {'EXIT',_}} -> fail(decode_failed,S1);
+
+ {true, Msg} ->
+ C3 = case Msg of
+ #ssh_msg_kexinit{} ->
+ ssh_transport:key_init(opposite_role(C2), C2, Payload);
+ _ ->
+ C2
+ end,
+ S2 = opt(print_messages, S1,
+ fun(X) when X==true;X==detail -> {"Recv~n~s~n",[format_msg(Msg)]} end),
+ S3 = opt(print_messages, S2,
+ fun(detail) -> {"decrypted bytes ~p~n",[SshPacket]} end),
+ S3#s{ssh = inc_recv_seq_num(C3),
+ enc = Rest,
+ return_value = Msg
+ }
+ end
+ end.
+
+
+receive_poll(S=#s{socket=Sock}) ->
+ inet:setopts(Sock, [{active,once}]),
+ receive
+ {tcp,Sock,Data} ->
+ receive_poll( S#s{enc = <<(S#s.enc)/binary,Data/binary>>} );
+ {tcp_closed,Sock} ->
+ throw({tcp,tcp_closed});
+ {tcp_error, Sock, Reason} ->
+ throw({tcp,{tcp_error,Reason}})
+ after 0 ->
+ S
+ end.
+
+receive_wait(S=#s{socket=Sock,
+ timeout=Timeout}) ->
+ inet:setopts(Sock, [{active,once}]),
+ receive
+ {tcp,Sock,Data} ->
+ S#s{enc = <<(S#s.enc)/binary,Data/binary>>};
+ {tcp_closed,Sock} ->
+ throw({tcp,tcp_closed});
+ {tcp_error, Sock, Reason} ->
+ throw({tcp,{tcp_error,Reason}})
+ after Timeout ->
+ fail(receive_timeout,S)
+ end.
+
+receive_wait(N, S=#s{socket=Sock,
+ timeout=Timeout,
+ enc=Enc0}) when N>0 ->
+ inet:setopts(Sock, [{active,once}]),
+ receive
+ {tcp,Sock,Data} ->
+ receive_wait(N-size(Data), S#s{enc = <<Enc0/binary,Data/binary>>});
+ {tcp_closed,Sock} ->
+ throw({tcp,tcp_closed});
+ {tcp_error, Sock, Reason} ->
+ throw({tcp,{tcp_error,Reason}})
+ after Timeout ->
+ fail(receive_timeout, S)
+ end;
+receive_wait(_N, S) ->
+ S.
+
+%% random_padding_len(PaddingLen1, ChunkSize) ->
+%% MaxAdditionalRandomPaddingLen = % max 255 bytes padding totaƶ
+%% (255 - PaddingLen1) - ((255 - PaddingLen1) rem ChunkSize),
+%% AddLen0 = crypto:rand_uniform(0,MaxAdditionalRandomPaddingLen),
+%% AddLen0 - (AddLen0 rem ChunkSize). % preserve the blocking
+
+inc_recv_seq_num(C=#ssh{recv_sequence=N}) -> C#ssh{recv_sequence=(N+1) band 16#ffffffff}.
+%%%inc_send_seq_num(C=#ssh{send_sequence=N}) -> C#ssh{send_sequence=(N+1) band 16#ffffffff}.
+
+opposite_role(#ssh{role=R}) -> opposite_role(R);
+opposite_role(client) -> server;
+opposite_role(server) -> client.
+
+ok(ok) -> ok;
+ok({ok,R}) -> R;
+ok({error,E}) -> erlang:error(E).
+
+
+%%%================================================================
+%%%
+%%% Formating of records
+%%%
+
+format_msg(M) -> format_msg(M, 0).
+
+format_msg(M, I0) ->
+ case fields(M) of
+ undefined -> io_lib:format('~p',[M]);
+ Fields ->
+ [Name|Args] = tuple_to_list(M),
+ Head = io_lib:format('#~p{',[Name]),
+ I = lists:flatlength(Head)+I0,
+ NL = io_lib:format('~n~*c',[I,$ ]),
+ Sep = io_lib:format(',~n~*c',[I,$ ]),
+ Tail = [begin
+ S0 = io_lib:format('~p = ',[F]),
+ I1 = I + lists:flatlength(S0),
+ [S0,format_msg(A,I1)]
+ end
+ || {F,A} <- lists:zip(Fields,Args)],
+ [[Head|string:join(Tail,Sep)],NL,"}"]
+ end.
+
+fields(M) ->
+ case M of
+ #ssh_msg_debug{} -> record_info(fields, ssh_msg_debug);
+ #ssh_msg_disconnect{} -> record_info(fields, ssh_msg_disconnect);
+ #ssh_msg_ignore{} -> record_info(fields, ssh_msg_ignore);
+ #ssh_msg_kex_dh_gex_group{} -> record_info(fields, ssh_msg_kex_dh_gex_group);
+ #ssh_msg_kex_dh_gex_init{} -> record_info(fields, ssh_msg_kex_dh_gex_init);
+ #ssh_msg_kex_dh_gex_reply{} -> record_info(fields, ssh_msg_kex_dh_gex_reply);
+ #ssh_msg_kex_dh_gex_request{} -> record_info(fields, ssh_msg_kex_dh_gex_request);
+ #ssh_msg_kex_dh_gex_request_old{} -> record_info(fields, ssh_msg_kex_dh_gex_request_old);
+ #ssh_msg_kexdh_init{} -> record_info(fields, ssh_msg_kexdh_init);
+ #ssh_msg_kexdh_reply{} -> record_info(fields, ssh_msg_kexdh_reply);
+ #ssh_msg_kexinit{} -> record_info(fields, ssh_msg_kexinit);
+ #ssh_msg_newkeys{} -> record_info(fields, ssh_msg_newkeys);
+ #ssh_msg_service_accept{} -> record_info(fields, ssh_msg_service_accept);
+ #ssh_msg_service_request{} -> record_info(fields, ssh_msg_service_request);
+ #ssh_msg_unimplemented{} -> record_info(fields, ssh_msg_unimplemented);
+ #ssh_msg_userauth_request{} -> record_info(fields, ssh_msg_userauth_request);
+ #ssh_msg_userauth_failure{} -> record_info(fields, ssh_msg_userauth_failure);
+ #ssh_msg_userauth_success{} -> record_info(fields, ssh_msg_userauth_success);
+ #ssh_msg_userauth_banner{} -> record_info(fields, ssh_msg_userauth_banner);
+ #ssh_msg_userauth_passwd_changereq{} -> record_info(fields, ssh_msg_userauth_passwd_changereq);
+ #ssh_msg_userauth_pk_ok{} -> record_info(fields, ssh_msg_userauth_pk_ok);
+ #ssh_msg_userauth_info_request{} -> record_info(fields, ssh_msg_userauth_info_request);
+ #ssh_msg_userauth_info_response{} -> record_info(fields, ssh_msg_userauth_info_response);
+ #s{} -> record_info(fields, s);
+ #ssh{} -> record_info(fields, ssh);
+ #alg{} -> record_info(fields, alg);
+ _ -> undefined
+ end.
+
+%%%================================================================
+%%%
+%%% Trace handling
+%%%
+
+init_op_traces(Op, S0) ->
+ opt(print_ops, S0#s{prints=[]},
+ fun(true) ->
+ case ?role(S0) of
+ undefined -> {"-- ~p~n",[Op]};
+ Role -> {"-- ~p ~p~n",[Role,Op]}
+ end
+ end
+ ).
+
+report_trace(Class, Term, S) ->
+ print_traces(
+ opt(print_ops, S,
+ fun(true) -> {"~s ~p",[Class,Term]} end)
+ ).
+
+seqnum_trace(S) ->
+ opt(print_seqnums, S,
+ fun(true) when S#s.ssh#ssh.send_sequence =/= S#s.ssh#ssh.send_sequence,
+ S#s.ssh#ssh.recv_sequence =/= S#s.ssh#ssh.recv_sequence ->
+ {"~p seq num: send ~p->~p, recv ~p->~p~n",
+ [?role(S),
+ S#s.ssh#ssh.send_sequence, S#s.ssh#ssh.send_sequence,
+ S#s.ssh#ssh.recv_sequence, S#s.ssh#ssh.recv_sequence
+ ]};
+ (true) when S#s.ssh#ssh.send_sequence =/= S#s.ssh#ssh.send_sequence ->
+ {"~p seq num: send ~p->~p~n",
+ [?role(S),
+ S#s.ssh#ssh.send_sequence, S#s.ssh#ssh.send_sequence]};
+ (true) when S#s.ssh#ssh.recv_sequence =/= S#s.ssh#ssh.recv_sequence ->
+ {"~p seq num: recv ~p->~p~n",
+ [?role(S),
+ S#s.ssh#ssh.recv_sequence, S#s.ssh#ssh.recv_sequence]}
+ end).
+
+print_traces(S) when S#s.prints == [] -> S;
+print_traces(S) ->
+ ct:log("~s",
+ [lists:foldl(fun({Fmt,Args}, Acc) ->
+ [io_lib:format(Fmt,Args) | Acc]
+ end, "", S#s.prints)]
+ ).
+
+opt(Flag, S, Fun) when is_function(Fun,1) ->
+ try Fun(proplists:get_value(Flag,S#s.opts))
+ of P={Fmt,Args} when is_list(Fmt), is_list(Args) ->
+ save_prints(P, S)
+ catch _:_ ->
+ S
+ end.
+
+save_prints({Fmt,Args}, S) ->
+ S#s{prints = [{Fmt,Args}|S#s.prints]}.