From 9dc46e8d58c9464c8a48b74342951265c3b43dc8 Mon Sep 17 00:00:00 2001 From: Hans Nilsson Date: Fri, 22 Jan 2016 19:28:16 +0100 Subject: ssh: Gen_statem rewrite of ssh_connection_handler Including misc fixes in surronding code as well as in test cases. --- lib/ssh/src/ssh_connection_handler.erl | 1760 +++++++++++++++----------------- 1 file changed, 817 insertions(+), 943 deletions(-) (limited to 'lib/ssh/src/ssh_connection_handler.erl') diff --git a/lib/ssh/src/ssh_connection_handler.erl b/lib/ssh/src/ssh_connection_handler.erl index 2bef6a41cd..d26c586c54 100644 --- a/lib/ssh/src/ssh_connection_handler.erl +++ b/lib/ssh/src/ssh_connection_handler.erl @@ -28,7 +28,7 @@ -module(ssh_connection_handler). --behaviour(gen_fsm). +-behaviour(gen_statem). -include("ssh.hrl"). -include("ssh_transport.hrl"). @@ -37,45 +37,37 @@ -compile(export_all). -export([start_link/3]). +%%-define(IO_FORMAT(F,A), io:format(F,A)). +-define(IO_FORMAT(F,A), ok). + %% Internal application API -export([open_channel/6, reply_request/3, request/6, request/7, global_request/4, send/5, send_eof/2, info/1, info/2, connection_info/2, channel_info/3, adjust_window/3, close/2, stop/1, renegotiate/1, renegotiate_data/1, + disconnect/1, disconnect/2, start_connection/4, get_print_info/1]). -%% gen_fsm callbacks --export([hello/2, kexinit/2, key_exchange/2, - key_exchange_dh_gex_init/2, key_exchange_dh_gex_reply/2, - new_keys/2, - service_request/2, connected/2, - userauth/2, - userauth_keyboard_interactive/2, - userauth_keyboard_interactive_info_response/2, - error/2]). - --export([init/1, handle_event/3, - handle_sync_event/4, handle_info/3, terminate/3, format_status/2, code_change/4]). +%% gen_statem callbacks +-export([init/1, handle_event/4, terminate/3, format_status/2, code_change/4]). -record(state, { - role, client, starter, auth_user, connection_state, latest_channel_id = 0, idle_timer_ref, - transport_protocol, % ex: tcp + transport_protocol, % ex: tcp transport_cb, transport_close_tag, - ssh_params, % #ssh{} - from ssh.hrl - socket, % socket() - decoded_data_buffer, % binary() - encoded_data_buffer, % binary() + ssh_params, % #ssh{} - from ssh.hrl + socket, % socket() + decoded_data_buffer, % binary() + encoded_data_buffer, % binary() undecoded_packet_length, % integer() - key_exchange_init_msg, % #ssh_msg_kexinit{} - renegotiate = false, % boolean() + key_exchange_init_msg, % #ssh_msg_kexinit{} last_size_rekey = 0, event_queue = [], connection_queue, @@ -83,30 +75,13 @@ port, opts, recbuf - }). - --type state_name() :: hello | kexinit | key_exchange | key_exchange_dh_gex_init | - key_exchange_dh_gex_reply | new_keys | service_request | - userauth | userauth_keyboard_interactive | - userauth_keyboard_interactive_info_response | - connection. - --type gen_fsm_state_return() :: {next_state, state_name(), term()} | - {next_state, state_name(), term(), timeout()} | - {stop, term(), term()}. - --type gen_fsm_sync_return() :: {next_state, state_name(), term()} | - {next_state, state_name(), term(), timeout()} | - {reply, term(), state_name(), term()} | - {stop, term(), term(), term()}. + }). %%==================================================================== %% Internal application API %%==================================================================== %%-------------------------------------------------------------------- --spec start_connection(client| server, port(), proplists:proplist(), - timeout()) -> {ok, pid()} | {error, term()}. %%-------------------------------------------------------------------- start_connection(client = Role, Socket, Options, Timeout) -> try @@ -128,8 +103,8 @@ start_connection(server = Role, Socket, Options, Timeout) -> try case proplists:get_value(parallel_login, SSH_Opts, false) of true -> - HandshakerPid = - spawn_link(fun() -> + HandshakerPid = + spawn_link(fun() -> receive {do_handshake, Pid} -> handshake(Pid, erlang:monitor(process,Pid), Timeout) @@ -164,11 +139,10 @@ start_link(Role, Socket, Options) -> init([Role, Socket, SshOpts]) -> process_flag(trap_exit, true), {NumVsn, StrVsn} = ssh_transport:versions(Role, SshOpts), - {Protocol, Callback, CloseTag} = + {Protocol, Callback, CloseTag} = proplists:get_value(transport, SshOpts, {tcp, gen_tcp, tcp_closed}), Cache = ssh_channel:cache_create(), State0 = #state{ - role = Role, connection_state = #connection{channel_cache = Cache, channel_id_seed = 0, port_bindings = [], @@ -183,142 +157,118 @@ init([Role, Socket, SshOpts]) -> opts = SshOpts }, - State = init_role(State0), + State = init_role(Role, State0), try init_ssh(Role, NumVsn, StrVsn, SshOpts, Socket) of Ssh -> - gen_fsm:enter_loop(?MODULE, [], hello, - State#state{ssh_params = Ssh}) + gen_statem:enter_loop(?MODULE, + [], %%[{debug,[trace,log,statistics,debug]} || Role==server], + handle_event_function, + {hello,Role}, + State#state{ssh_params = Ssh}, + []) catch _:Error -> - gen_fsm:enter_loop(?MODULE, [], error, {Error, State}) + gen_statem:enter_loop(?MODULE, + [], + handle_event_function, + {init_error,Error}, + State, + []) end. -%% Temporary fix for the Nessus error. SYN-> <-SYNACK ACK-> RST-> ? -error(_Event, {Error,State=#state{}}) -> - case Error of - {badmatch,{error,enotconn}} -> - %% {error,enotconn} probably from inet:peername in - %% init_ssh(server,..)/5 called from init/1 - {stop, {shutdown,"TCP connenction to server was prematurely closed by the client"}, State}; - _ -> - {stop, {shutdown,{init,Error}}, State} - end; -error(Event, State) -> - %% State deliberately not checked beeing #state. This is a panic-clause... - {stop, {shutdown,{init,{spurious_error,Event}}}, State}. - %%-------------------------------------------------------------------- --spec open_channel(pid(), string(), iodata(), integer(), integer(), - timeout()) -> {open, channel_id()} | {error, term()}. %%-------------------------------------------------------------------- open_channel(ConnectionHandler, ChannelType, ChannelSpecificData, InitialWindowSize, MaxPacketSize, Timeout) -> - sync_send_all_state_event(ConnectionHandler, {open, self(), ChannelType, + call(ConnectionHandler, {open, self(), ChannelType, InitialWindowSize, MaxPacketSize, ChannelSpecificData, Timeout}). %%-------------------------------------------------------------------- --spec request(pid(), pid(), channel_id(), string(), boolean(), iodata(), - timeout()) -> success | failure | ok | {error, term()}. %%-------------------------------------------------------------------- request(ConnectionHandler, ChannelPid, ChannelId, Type, true, Data, Timeout) -> - sync_send_all_state_event(ConnectionHandler, {request, ChannelPid, ChannelId, Type, Data, + call(ConnectionHandler, {request, ChannelPid, ChannelId, Type, Data, Timeout}); request(ConnectionHandler, ChannelPid, ChannelId, Type, false, Data, _) -> - send_all_state_event(ConnectionHandler, {request, ChannelPid, ChannelId, Type, Data}). + cast(ConnectionHandler, {request, ChannelPid, ChannelId, Type, Data}). %%-------------------------------------------------------------------- --spec request(pid(), channel_id(), string(), boolean(), iodata(), - timeout()) -> success | failure | {error, timeout}. %%-------------------------------------------------------------------- request(ConnectionHandler, ChannelId, Type, true, Data, Timeout) -> - sync_send_all_state_event(ConnectionHandler, {request, ChannelId, Type, Data, Timeout}); + call(ConnectionHandler, {request, ChannelId, Type, Data, Timeout}); request(ConnectionHandler, ChannelId, Type, false, Data, _) -> - send_all_state_event(ConnectionHandler, {request, ChannelId, Type, Data}). + cast(ConnectionHandler, {request, ChannelId, Type, Data}). %%-------------------------------------------------------------------- --spec reply_request(pid(), success | failure, channel_id()) -> ok. %%-------------------------------------------------------------------- reply_request(ConnectionHandler, Status, ChannelId) -> - send_all_state_event(ConnectionHandler, {reply_request, Status, ChannelId}). + cast(ConnectionHandler, {reply_request, Status, ChannelId}). %%-------------------------------------------------------------------- --spec global_request(pid(), string(), boolean(), iolist()) -> ok | error. %%-------------------------------------------------------------------- global_request(ConnectionHandler, Type, true = Reply, Data) -> - case sync_send_all_state_event(ConnectionHandler, - {global_request, self(), Type, Reply, Data}) of + case call(ConnectionHandler, {global_request, self(), Type, Reply, Data}) of {ssh_cm, ConnectionHandler, {success, _}} -> ok; {ssh_cm, ConnectionHandler, {failure, _}} -> error end; global_request(ConnectionHandler, Type, false = Reply, Data) -> - send_all_state_event(ConnectionHandler, {global_request, self(), Type, Reply, Data}). + cast(ConnectionHandler, {global_request, self(), Type, Reply, Data}). %%-------------------------------------------------------------------- --spec send(pid(), channel_id(), integer(), iodata(), timeout()) -> - ok | {error, timeout} | {error, closed}. %%-------------------------------------------------------------------- send(ConnectionHandler, ChannelId, Type, Data, Timeout) -> - sync_send_all_state_event(ConnectionHandler, {data, ChannelId, Type, Data, Timeout}). + call(ConnectionHandler, {data, ChannelId, Type, Data, Timeout}). %%-------------------------------------------------------------------- --spec send_eof(pid(), channel_id()) -> ok | {error, closed}. %%-------------------------------------------------------------------- send_eof(ConnectionHandler, ChannelId) -> - sync_send_all_state_event(ConnectionHandler, {eof, ChannelId}). + call(ConnectionHandler, {eof, ChannelId}). %%-------------------------------------------------------------------- --spec connection_info(pid(), [atom()]) -> proplists:proplist(). %%-------------------------------------------------------------------- get_print_info(ConnectionHandler) -> - sync_send_all_state_event(ConnectionHandler, get_print_info, 1000). + call(ConnectionHandler, get_print_info, 1000). connection_info(ConnectionHandler, Options) -> - sync_send_all_state_event(ConnectionHandler, {connection_info, Options}). + call(ConnectionHandler, {connection_info, Options}). %%-------------------------------------------------------------------- --spec channel_info(pid(), channel_id(), [atom()]) -> proplists:proplist(). %%-------------------------------------------------------------------- channel_info(ConnectionHandler, ChannelId, Options) -> - sync_send_all_state_event(ConnectionHandler, {channel_info, ChannelId, Options}). + call(ConnectionHandler, {channel_info, ChannelId, Options}). %%-------------------------------------------------------------------- --spec adjust_window(pid(), channel_id(), integer()) -> ok. %%-------------------------------------------------------------------- adjust_window(ConnectionHandler, Channel, Bytes) -> - send_all_state_event(ConnectionHandler, {adjust_window, Channel, Bytes}). + cast(ConnectionHandler, {adjust_window, Channel, Bytes}). %%-------------------------------------------------------------------- --spec renegotiate(pid()) -> ok. %%-------------------------------------------------------------------- renegotiate(ConnectionHandler) -> - send_all_state_event(ConnectionHandler, renegotiate). + cast(ConnectionHandler, renegotiate). %%-------------------------------------------------------------------- --spec renegotiate_data(pid()) -> ok. %%-------------------------------------------------------------------- renegotiate_data(ConnectionHandler) -> - send_all_state_event(ConnectionHandler, data_size). + cast(ConnectionHandler, data_size). %%-------------------------------------------------------------------- --spec close(pid(), channel_id()) -> ok. %%-------------------------------------------------------------------- close(ConnectionHandler, ChannelId) -> - case sync_send_all_state_event(ConnectionHandler, {close, ChannelId}) of + case call(ConnectionHandler, {close, ChannelId}) of ok -> ok; - {error, closed} -> + {error, closed} -> ok - end. - + end. + %%-------------------------------------------------------------------- --spec stop(pid()) -> ok | {error, term()}. %%-------------------------------------------------------------------- stop(ConnectionHandler)-> - case sync_send_all_state_event(ConnectionHandler, stop) of + case call(ConnectionHandler, stop) of {error, closed} -> ok; Other -> @@ -329,484 +279,492 @@ info(ConnectionHandler) -> info(ConnectionHandler, {info, all}). info(ConnectionHandler, ChannelProcess) -> - sync_send_all_state_event(ConnectionHandler, {info, ChannelProcess}). - + call(ConnectionHandler, {info, ChannelProcess}). %%==================================================================== -%% gen_fsm callbacks +%% gen_statem callbacks %%==================================================================== -%%-------------------------------------------------------------------- --spec hello(socket_control | {info_line, list()} | {version_exchange, list()}, - #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- +%% Temporary fix for the Nessus error. SYN-> <-SYNACK ACK-> RST-> ? +handle_event(_, _Event, {init_error,Error}, _State) -> + case Error of + {badmatch,{error,enotconn}} -> + %% {error,enotconn} probably from inet:peername in + %% init_ssh(server,..)/5 called from init/1 + {stop, {shutdown,"TCP connenction to server was prematurely closed by the client"}}; + _ -> + {stop, {shutdown,{init,Error}}} + end; + -hello(socket_control, #state{socket = Socket, ssh_params = Ssh} = State) -> +%%% ######## {hello, client|server} #### + +handle_event(_, socket_control, StateName={hello,_}, S=#state{socket=Socket, + ssh_params=Ssh}) -> VsnMsg = ssh_transport:hello_version_msg(string_version(Ssh)), - send_msg(VsnMsg, State), + send_bytes(VsnMsg, S), case getopt(recbuf, Socket) of {ok, Size} -> - inet:setopts(Socket, [{packet, line}, {active, once}, {recbuf, ?MAX_PROTO_VERSION}]), - {next_state, hello, State#state{recbuf = Size}}; + inet:setopts(Socket, [{packet, line}, {active, once}, {recbuf, ?MAX_PROTO_VERSION}, {nodelay,true}]), + {next_state, StateName, S#state{recbuf=Size}}; {error, Reason} -> - {stop, {shutdown, Reason}, State} + {stop, {shutdown,Reason}} end; -hello({info_line, _Line},#state{role = client, socket = Socket} = State) -> +handle_event(_, {info_line,_Line}, StateName={hello,client}, S=#state{socket=Socket}) -> %% The server may send info lines before the version_exchange inet:setopts(Socket, [{active, once}]), - {next_state, hello, State}; + {next_state, StateName, S}; -hello({info_line, _Line},#state{role = server, - socket = Socket, - transport_cb = Transport } = State) -> +handle_event(_, {info_line,_Line}, {hello,server}, S) -> %% as openssh - Transport:send(Socket, "Protocol mismatch."), - {stop, {shutdown,"Protocol mismatch in version exchange."}, State}; + send_bytes("Protocol mismatch.", S), + {stop, {shutdown,"Protocol mismatch in version exchange."}}; -hello({version_exchange, Version}, #state{ssh_params = Ssh0, - socket = Socket, - recbuf = Size} = State) -> +handle_event(_, {version_exchange,Version}, {hello,Role}, S=#state{ssh_params = Ssh0, + socket = Socket, + recbuf = Size}) -> {NumVsn, StrVsn} = ssh_transport:handle_hello_version(Version), case handle_version(NumVsn, StrVsn, Ssh0) of {ok, Ssh1} -> inet:setopts(Socket, [{packet,0}, {mode,binary}, {active, once}, {recbuf, Size}]), {KeyInitMsg, SshPacket, Ssh} = ssh_transport:key_exchange_init_msg(Ssh1), - send_msg(SshPacket, State), - {next_state, kexinit, next_packet(State#state{ssh_params = Ssh, - key_exchange_init_msg = - KeyInitMsg})}; + send_bytes(SshPacket, S), + {next_state, {kexinit,Role,init}, S#state{ssh_params = Ssh, + key_exchange_init_msg = KeyInitMsg}}; not_supported -> - DisconnectMsg = - #ssh_msg_disconnect{code = - ?SSH_DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED, - description = "Protocol version " ++ StrVsn - ++ " not supported", - language = "en"}, - handle_disconnect(DisconnectMsg, State) - end. - -%%-------------------------------------------------------------------- --spec kexinit({#ssh_msg_kexinit{}, binary()}, #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- -kexinit({#ssh_msg_kexinit{} = Kex, Payload}, - #state{ssh_params = #ssh{role = Role} = Ssh0, - key_exchange_init_msg = OwnKex} = - State) -> - Ssh1 = ssh_transport:key_init(opposite_role(Role), Ssh0, Payload), - case ssh_transport:handle_kexinit_msg(Kex, OwnKex, Ssh1) of - {ok, NextKexMsg, Ssh} when Role == client -> - send_msg(NextKexMsg, State), - {next_state, key_exchange, - next_packet(State#state{ssh_params = Ssh})}; - {ok, Ssh} when Role == server -> - {next_state, key_exchange, - next_packet(State#state{ssh_params = Ssh})} - end. - -%%-------------------------------------------------------------------- --spec key_exchange(#ssh_msg_kexdh_init{} | #ssh_msg_kexdh_reply{} | - #ssh_msg_kex_dh_gex_group{} | #ssh_msg_kex_dh_gex_request{} | - #ssh_msg_kex_dh_gex_request{} | #ssh_msg_kex_dh_gex_reply{}, #state{}) - -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- - -key_exchange(#ssh_msg_kexdh_init{} = Msg, - #state{ssh_params = #ssh{role = server} = Ssh0} = State) -> - case ssh_transport:handle_kexdh_init(Msg, Ssh0) of - {ok, KexdhReply, Ssh1} -> - send_msg(KexdhReply, State), - {ok, NewKeys, Ssh} = ssh_transport:new_keys_message(Ssh1), - send_msg(NewKeys, State), - {next_state, new_keys, next_packet(State#state{ssh_params = Ssh})} + disconnect( + #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED, + description = ["Protocol version ",StrVsn," not supported"]}, + {next_state, {hello,Role}, S}) end; + +%%% ######## {kexinit, client|server, init|renegotiate} #### + +handle_event(_, {#ssh_msg_kexinit{} = Kex, Payload}, {kexinit,client,ReNeg}, + S = #state{ssh_params = Ssh0, + key_exchange_init_msg = OwnKex}) -> + Ssh1 = ssh_transport:key_init(server, Ssh0, Payload), % Yes, *server* + {ok, NextKexMsg, Ssh} = ssh_transport:handle_kexinit_msg(Kex, OwnKex, Ssh1), + send_bytes(NextKexMsg, S), + {next_state, {key_exchange,client,ReNeg}, S#state{ssh_params = Ssh}}; + +handle_event(_, {#ssh_msg_kexinit{} = Kex, Payload}, {kexinit,server,ReNeg}, + S = #state{ssh_params = Ssh0, + key_exchange_init_msg = OwnKex}) -> + Ssh1 = ssh_transport:key_init(client, Ssh0, Payload), % Yes, *client* + {ok, Ssh} = ssh_transport:handle_kexinit_msg(Kex, OwnKex, Ssh1), + {next_state, {key_exchange,server,ReNeg}, S#state{ssh_params = Ssh}}; + +%%% ######## {key_exchange, client|server, init|renegotiate} #### + +handle_event(_, #ssh_msg_kexdh_init{} = Msg, {key_exchange,server,ReNeg}, + S = #state{ssh_params = Ssh0}) -> + {ok, KexdhReply, Ssh1} = ssh_transport:handle_kexdh_init(Msg, Ssh0), + send_bytes(KexdhReply, S), + {ok, NewKeys, Ssh} = ssh_transport:new_keys_message(Ssh1), + send_bytes(NewKeys, S), + {next_state, {new_keys,server,ReNeg}, S#state{ssh_params = Ssh}}; -key_exchange(#ssh_msg_kexdh_reply{} = Msg, - #state{ssh_params = #ssh{role = client} = Ssh0} = State) -> +handle_event(_, #ssh_msg_kexdh_reply{} = Msg, {key_exchange,client,ReNeg}, + #state{ssh_params=Ssh0} = State) -> {ok, NewKeys, Ssh} = ssh_transport:handle_kexdh_reply(Msg, Ssh0), - send_msg(NewKeys, State), - {next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(NewKeys, State), + {next_state, {new_keys,client,ReNeg}, State#state{ssh_params = Ssh}}; -key_exchange(#ssh_msg_kex_dh_gex_request{} = Msg, - #state{ssh_params = #ssh{role = server} = Ssh0} = State) -> +handle_event(_, #ssh_msg_kex_dh_gex_request{} = Msg, {key_exchange,server,ReNeg}, + #state{ssh_params=Ssh0} = State) -> {ok, GexGroup, Ssh} = ssh_transport:handle_kex_dh_gex_request(Msg, Ssh0), - send_msg(GexGroup, State), - {next_state, key_exchange_dh_gex_init, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(GexGroup, State), + {next_state, {key_exchange_dh_gex_init,server,ReNeg}, State#state{ssh_params = Ssh}}; -key_exchange(#ssh_msg_kex_dh_gex_request_old{} = Msg, - #state{ssh_params = #ssh{role = server} = Ssh0} = State) -> +handle_event(_, #ssh_msg_kex_dh_gex_request_old{} = Msg, {key_exchange,server,ReNeg}, + #state{ssh_params=Ssh0} = State) -> {ok, GexGroup, Ssh} = ssh_transport:handle_kex_dh_gex_request(Msg, Ssh0), - send_msg(GexGroup, State), - {next_state, key_exchange_dh_gex_init, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(GexGroup, State), + {next_state, {key_exchange_dh_gex_init,server,ReNeg}, State#state{ssh_params = Ssh}}; -key_exchange(#ssh_msg_kex_dh_gex_group{} = Msg, - #state{ssh_params = #ssh{role = client} = Ssh0} = State) -> +handle_event(_, #ssh_msg_kex_dh_gex_group{} = Msg, {key_exchange,client,ReNeg}, + #state{ssh_params=Ssh0} = State) -> {ok, KexGexInit, Ssh} = ssh_transport:handle_kex_dh_gex_group(Msg, Ssh0), - send_msg(KexGexInit, State), - {next_state, key_exchange_dh_gex_reply, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(KexGexInit, State), + {next_state, {key_exchange_dh_gex_reply,client,ReNeg}, State#state{ssh_params = Ssh}}; -key_exchange(#ssh_msg_kex_ecdh_init{} = Msg, - #state{ssh_params = #ssh{role = server} = Ssh0} = State) -> +handle_event(_, #ssh_msg_kex_ecdh_init{} = Msg, {key_exchange,server,ReNeg}, + #state{ssh_params=Ssh0} = State) -> {ok, KexEcdhReply, Ssh1} = ssh_transport:handle_kex_ecdh_init(Msg, Ssh0), - send_msg(KexEcdhReply, State), + send_bytes(KexEcdhReply, State), {ok, NewKeys, Ssh} = ssh_transport:new_keys_message(Ssh1), - send_msg(NewKeys, State), - {next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(NewKeys, State), + {next_state, {new_keys,server,ReNeg}, State#state{ssh_params = Ssh}}; -key_exchange(#ssh_msg_kex_ecdh_reply{} = Msg, - #state{ssh_params = #ssh{role = client} = Ssh0} = State) -> +handle_event(_, #ssh_msg_kex_ecdh_reply{} = Msg, {key_exchange,client,ReNeg}, + #state{ssh_params=Ssh0} = State) -> {ok, NewKeys, Ssh} = ssh_transport:handle_kex_ecdh_reply(Msg, Ssh0), - send_msg(NewKeys, State), - {next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}. + send_bytes(NewKeys, State), + {next_state, {new_keys,client,ReNeg}, State#state{ssh_params = Ssh}}; -%%-------------------------------------------------------------------- --spec key_exchange_dh_gex_init(#ssh_msg_kex_dh_gex_init{}, #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- -key_exchange_dh_gex_init(#ssh_msg_kex_dh_gex_init{} = Msg, - #state{ssh_params = #ssh{role = server} = Ssh0} = State) -> +%%% ######## {key_exchange_dh_gex_init, server, init|renegotiate} #### + +handle_event(_, #ssh_msg_kex_dh_gex_init{} = Msg, {key_exchange_dh_gex_init,server,ReNeg}, + #state{ssh_params=Ssh0} = State) -> {ok, KexGexReply, Ssh1} = ssh_transport:handle_kex_dh_gex_init(Msg, Ssh0), - send_msg(KexGexReply, State), + send_bytes(KexGexReply, State), {ok, NewKeys, Ssh} = ssh_transport:new_keys_message(Ssh1), - send_msg(NewKeys, State), - {next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}. + send_bytes(NewKeys, State), + {next_state, {new_keys,server,ReNeg}, State#state{ssh_params = Ssh}}; -%%-------------------------------------------------------------------- --spec key_exchange_dh_gex_reply(#ssh_msg_kex_dh_gex_reply{}, #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- -key_exchange_dh_gex_reply(#ssh_msg_kex_dh_gex_reply{} = Msg, - #state{ssh_params = #ssh{role = client} = Ssh0} = State) -> - {ok, NewKeys, Ssh1} = ssh_transport:handle_kex_dh_gex_reply(Msg, Ssh0), - send_msg(NewKeys, State), - {next_state, new_keys, next_packet(State#state{ssh_params = Ssh1})}. +%%% ######## {key_exchange_dh_gex_reply, client, init|renegotiate} #### -%%-------------------------------------------------------------------- --spec new_keys(#ssh_msg_newkeys{}, #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- +handle_event(_, #ssh_msg_kex_dh_gex_reply{} = Msg, {key_exchange_dh_gex_reply,client,ReNeg}, + #state{ssh_params=Ssh0} = State) -> + {ok, NewKeys, Ssh1} = ssh_transport:handle_kex_dh_gex_reply(Msg, Ssh0), + send_bytes(NewKeys, State), + {next_state, {new_keys,client,ReNeg}, State#state{ssh_params = Ssh1}}; + +%%% ######## {new_keys, client|server} #### + +handle_event(_, #ssh_msg_newkeys{} = Msg, {new_keys,client,init}, + #state{ssh_params = Ssh0} = State) -> + {ok, Ssh1} = ssh_transport:handle_new_keys(Msg, Ssh0), + {MsgReq, Ssh} = ssh_auth:service_request_msg(Ssh1), + send_bytes(MsgReq, State), + {next_state, {service_request,client}, State#state{ssh_params=Ssh}}; -new_keys(#ssh_msg_newkeys{} = Msg, #state{ssh_params = Ssh0} = State0) -> +handle_event(_, #ssh_msg_newkeys{} = Msg, {new_keys,server,init}, + S = #state{ssh_params = Ssh0}) -> {ok, Ssh} = ssh_transport:handle_new_keys(Msg, Ssh0), - after_new_keys(next_packet(State0#state{ssh_params = Ssh})). + {next_state, {service_request,server}, S#state{ssh_params = Ssh}}; -%%-------------------------------------------------------------------- --spec service_request(#ssh_msg_service_request{} | #ssh_msg_service_accept{}, - #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- -service_request(#ssh_msg_service_request{name = "ssh-userauth"} = Msg, - #state{ssh_params = #ssh{role = server, - session_id = SessionId} = Ssh0} = State) -> +handle_event(_, #ssh_msg_newkeys{}, {new_keys,Role,renegotiate}, S) -> + {next_state, {connected,Role}, S}; + +%%% ######## {service_request, client|server} + +handle_event(_, #ssh_msg_service_request{name = "ssh-userauth"} = Msg, {service_request,server}, + #state{ssh_params = #ssh{session_id=SessionId} = Ssh0} = State) -> {ok, {Reply, Ssh}} = ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0), - send_msg(Reply, State), - {next_state, userauth, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(Reply, State), + {next_state, {userauth,server}, State#state{ssh_params = Ssh}}; + +handle_event(_, #ssh_msg_service_request{}, {service_request,server}=StateName, State) -> + Msg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE, + description = "Unknown service"}, + disconnect(Msg, StateName, State); -service_request(#ssh_msg_service_accept{name = "ssh-userauth"}, - #state{ssh_params = #ssh{role = client, - service = "ssh-userauth"} = Ssh0} = - State) -> +handle_event(_, #ssh_msg_service_accept{name = "ssh-userauth"}, {service_request,client}, + #state{ssh_params = #ssh{service="ssh-userauth"} = Ssh0} = State) -> {Msg, Ssh} = ssh_auth:init_userauth_request_msg(Ssh0), - send_msg(Msg, State), - {next_state, userauth, next_packet(State#state{auth_user = Ssh#ssh.user, ssh_params = Ssh})}. + send_bytes(Msg, State), + {next_state, {userauth,client}, State#state{auth_user = Ssh#ssh.user, ssh_params = Ssh}}; -%%-------------------------------------------------------------------- --spec userauth(#ssh_msg_userauth_request{} | #ssh_msg_userauth_info_request{} | - #ssh_msg_userauth_info_response{} | #ssh_msg_userauth_success{} | - #ssh_msg_userauth_failure{} | #ssh_msg_userauth_banner{}, - #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- -userauth(#ssh_msg_userauth_request{service = "ssh-connection", - method = "none"} = Msg, - #state{ssh_params = #ssh{session_id = SessionId, role = server, +%%% ######## {userauth, client|server} #### + +handle_event(_, #ssh_msg_userauth_request{service = "ssh-connection", + method = "none"} = Msg, StateName={userauth,server}, + #state{ssh_params = #ssh{session_id = SessionId, service = "ssh-connection"} = Ssh0 - } = State) -> + } = State) -> +?IO_FORMAT('~p #ssh_msg_userauth_request{ssh-connection,~p}~n',[self(),Msg#ssh_msg_userauth_request.method]), {not_authorized, {_User, _Reason}, {Reply, Ssh}} = ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0), - send_msg(Reply, State), - {next_state, userauth, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(Reply, State), + {next_state, StateName, State#state{ssh_params = Ssh}}; -userauth(#ssh_msg_userauth_request{service = "ssh-connection", - method = Method} = Msg, - #state{ssh_params = #ssh{session_id = SessionId, role = server, +handle_event(_, #ssh_msg_userauth_request{service = "ssh-connection", + method = Method} = Msg, StateName={userauth,server}, + #state{ssh_params = #ssh{session_id = SessionId, service = "ssh-connection", peer = {_, Address}} = Ssh0, opts = Opts, starter = Pid} = State) -> +?IO_FORMAT('~p #ssh_msg_userauth_request{ssh-connection,~p}~n',[self(),Msg#ssh_msg_userauth_request.method]), case lists:member(Method, Ssh0#ssh.userauth_methods) of true -> case ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0) of {authorized, User, {Reply, Ssh}} -> - send_msg(Reply, State), + send_bytes(Reply, State), Pid ! ssh_connected, connected_fun(User, Address, Method, Opts), - {next_state, connected, - next_packet(State#state{auth_user = User, ssh_params = Ssh#ssh{authenticated = true}})}; +?IO_FORMAT('~p CONNECTED!~n',[self()]), + {next_state, {connected,server}, + State#state{auth_user = User, ssh_params = Ssh#ssh{authenticated = true}}}; {not_authorized, {User, Reason}, {Reply, Ssh}} when Method == "keyboard-interactive" -> retry_fun(User, Address, Reason, Opts), - send_msg(Reply, State), - {next_state, userauth_keyboard_interactive, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(Reply, State), +?IO_FORMAT('~p not_authorized (1)~n',[self()]), + {next_state, {userauth_keyboard_interactive,server}, State#state{ssh_params = Ssh}}; {not_authorized, {User, Reason}, {Reply, Ssh}} -> retry_fun(User, Address, Reason, Opts), - send_msg(Reply, State), - {next_state, userauth, next_packet(State#state{ssh_params = Ssh})} + send_bytes(Reply, State), +?IO_FORMAT('~p not_authorized (2)~n',[self()]), + {next_state, StateName, State#state{ssh_params = Ssh}} end; false -> - userauth(Msg#ssh_msg_userauth_request{method="none"}, State) - end; + %% At least one non-erlang client does like this. Retry as the next event +?IO_FORMAT('~p bug-fix~n',[self()]), + {next_state, StateName, State, + [{next_event, internal, Msg#ssh_msg_userauth_request{method="none"}}] + } + end; -userauth(#ssh_msg_userauth_success{}, #state{ssh_params = #ssh{role = client} = Ssh, - starter = Pid} = State) -> +handle_event(_, #ssh_msg_userauth_request{service = Service}, {userauth,server}=StateName, State) + when Service =/= "ssh-connection" -> +?IO_FORMAT('~p #ssh_msg_userauth_request{~p,...}~n',[self(),Service]), + Msg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE, + description = "Unknown service"}, + disconnect(Msg, StateName, State); + +handle_event(_, #ssh_msg_userauth_success{}, {userauth,client}, #state{ssh_params = Ssh, + starter = Pid} = State) -> Pid ! ssh_connected, - {next_state, connected, next_packet(State#state{ssh_params = - Ssh#ssh{authenticated = true}})}; -userauth(#ssh_msg_userauth_failure{}, - #state{ssh_params = #ssh{role = client, - userauth_methods = []}} - = State) -> - Msg = #ssh_msg_disconnect{code = - ?SSH_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, + {next_state, {connected,client}, State#state{ssh_params=Ssh#ssh{authenticated = true}}}; + +handle_event(_, #ssh_msg_userauth_failure{}, {userauth,client}=StateName, + #state{ssh_params = #ssh{userauth_methods = []}} = State) -> + Msg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE, description = "Unable to connect using the available" - " authentication methods", - language = "en"}, - handle_disconnect(Msg, State); - -%% Server tells us which authentication methods that are allowed -userauth(#ssh_msg_userauth_failure{authentications = Methodes}, - #state{ssh_params = #ssh{role = client, - userauth_methods = none} = Ssh0} = State) -> - AuthMethods = string:tokens(Methodes, ","), - Ssh1 = Ssh0#ssh{userauth_methods = AuthMethods}, + " authentication methods"}, + disconnect(Msg, StateName, State); + + +handle_event(_, #ssh_msg_userauth_failure{authentications = Methods}, StateName={userauth,client}, + #state{ssh_params = Ssh0 = #ssh{userauth_methods=AuthMthds}} = State) -> + %% The prefered authentication method failed try next method + Ssh1 = case AuthMthds of + none -> + %% Server tells us which authentication methods that are allowed + Ssh0#ssh{userauth_methods = string:tokens(Methods, ",")}; + _ -> + %% We already know... + Ssh0 + end, case ssh_auth:userauth_request_msg(Ssh1) of {disconnect, DisconnectMsg, {Msg, Ssh}} -> - send_msg(Msg, State), - handle_disconnect(DisconnectMsg, State#state{ssh_params = Ssh}); + send_bytes(Msg, State), + disconnect(DisconnectMsg, StateName, State#state{ssh_params = Ssh}); {"keyboard-interactive", {Msg, Ssh}} -> - send_msg(Msg, State), - {next_state, userauth_keyboard_interactive, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(Msg, State), + {next_state, {userauth_keyboard_interactive,client}, State#state{ssh_params = Ssh}}; {_Method, {Msg, Ssh}} -> - send_msg(Msg, State), - {next_state, userauth, next_packet(State#state{ssh_params = Ssh})} + send_bytes(Msg, State), + {next_state, StateName, State#state{ssh_params = Ssh}} end; -%% The prefered authentication method failed try next method -userauth(#ssh_msg_userauth_failure{}, - #state{ssh_params = #ssh{role = client} = Ssh0} = State) -> - case ssh_auth:userauth_request_msg(Ssh0) of - {disconnect, DisconnectMsg,{Msg, Ssh}} -> - send_msg(Msg, State), - handle_disconnect(DisconnectMsg, State#state{ssh_params = Ssh}); - {"keyboard-interactive", {Msg, Ssh}} -> - send_msg(Msg, State), - {next_state, userauth_keyboard_interactive, next_packet(State#state{ssh_params = Ssh})}; - {_Method, {Msg, Ssh}} -> - send_msg(Msg, State), - {next_state, userauth, next_packet(State#state{ssh_params = Ssh})} - end; +handle_event(_, #ssh_msg_userauth_banner{}, StateName={userauth,client}, + #state{ssh_params = #ssh{userauth_quiet_mode=true}} = State) -> + {next_state, StateName, State}; -userauth(#ssh_msg_userauth_banner{}, - #state{ssh_params = #ssh{userauth_quiet_mode = true, - role = client}} = State) -> - {next_state, userauth, next_packet(State)}; -userauth(#ssh_msg_userauth_banner{message = Msg}, - #state{ssh_params = - #ssh{userauth_quiet_mode = false, role = client}} = State) -> +handle_event(_, #ssh_msg_userauth_banner{message = Msg}, StateName={userauth,client}, + #state{ssh_params = #ssh{userauth_quiet_mode=false}} = State) -> io:format("~s", [Msg]), - {next_state, userauth, next_packet(State)}. - + {next_state, StateName, State}; +%%% ######## {userauth_keyboard_interactive, client|server} -userauth_keyboard_interactive(#ssh_msg_userauth_info_request{} = Msg, - #state{ssh_params = #ssh{role = client, - io_cb = IoCb} = Ssh0} = State) -> +handle_event(_, #ssh_msg_userauth_info_request{} = Msg, {userauth_keyboard_interactive, client}, + #state{ssh_params = #ssh{io_cb=IoCb} = Ssh0} = State) -> {ok, {Reply, Ssh}} = ssh_auth:handle_userauth_info_request(Msg, IoCb, Ssh0), - send_msg(Reply, State), - {next_state, userauth_keyboard_interactive_info_response, next_packet(State#state{ssh_params = Ssh})}; + send_bytes(Reply, State), + {next_state, {userauth_keyboard_interactive_info_response,client}, State#state{ssh_params = Ssh}}; -userauth_keyboard_interactive(#ssh_msg_userauth_info_response{} = Msg, - #state{ssh_params = #ssh{role = server, - peer = {_, Address}} = Ssh0, - opts = Opts, starter = Pid} = State) -> +handle_event(_, #ssh_msg_userauth_info_response{} = Msg, {userauth_keyboard_interactive, server}, + #state{ssh_params = #ssh{peer = {_,Address}} = Ssh0, + opts = Opts, + starter = Pid} = State) -> case ssh_auth:handle_userauth_info_response(Msg, Ssh0) of {authorized, User, {Reply, Ssh}} -> - send_msg(Reply, State), + send_bytes(Reply, State), Pid ! ssh_connected, connected_fun(User, Address, "keyboard-interactive", Opts), - {next_state, connected, - next_packet(State#state{auth_user = User, ssh_params = Ssh#ssh{authenticated = true}})}; + {next_state, {connected,server}, State#state{auth_user = User, + ssh_params = Ssh#ssh{authenticated = true}}}; {not_authorized, {User, Reason}, {Reply, Ssh}} -> retry_fun(User, Address, Reason, Opts), - send_msg(Reply, State), - {next_state, userauth, next_packet(State#state{ssh_params = Ssh})} + send_bytes(Reply, State), + {next_state, {userauth,server}, State#state{ssh_params = Ssh}} end; -userauth_keyboard_interactive(Msg = #ssh_msg_userauth_failure{}, - #state{ssh_params = Ssh0 = - #ssh{role = client, - userauth_preference = Prefs0}} - = State) -> +handle_event(_, Msg = #ssh_msg_userauth_failure{}, {userauth_keyboard_interactive, client}, + #state{ssh_params = Ssh0 = #ssh{userauth_preference=Prefs0}} = State) -> Prefs = [{Method,M,F,A} || {Method,M,F,A} <- Prefs0, Method =/= "keyboard-interactive"], - userauth(Msg, State#state{ssh_params = Ssh0#ssh{userauth_preference=Prefs}}). + {next_state, {userauth,client}, + State#state{ssh_params = Ssh0#ssh{userauth_preference=Prefs}}, + [{next_event, internal, Msg}]}; +handle_event(_, Msg=#ssh_msg_userauth_failure{}, {userauth_keyboard_interactive_info_response, client}, S) -> + {next_state, {userauth,client}, S, [{next_event, internal, Msg}]}; +handle_event(_, Msg=#ssh_msg_userauth_success{}, {userauth_keyboard_interactive_info_response, client}, S) -> + {next_state, {userauth,client}, S, [{next_event, internal, Msg}]}; -userauth_keyboard_interactive_info_response(Msg=#ssh_msg_userauth_failure{}, - #state{ssh_params = #ssh{role = client}} = State) -> - userauth(Msg, State); -userauth_keyboard_interactive_info_response(Msg=#ssh_msg_userauth_success{}, - #state{ssh_params = #ssh{role = client}} = State) -> - userauth(Msg, State); -userauth_keyboard_interactive_info_response(Msg=#ssh_msg_userauth_info_request{}, - #state{ssh_params = #ssh{role = client}} = State) -> - userauth_keyboard_interactive(Msg, State). +handle_event(_, Msg=#ssh_msg_userauth_info_request{}, {userauth_keyboard_interactive_info_response, client}, S) -> + {next_state, {userauth_keyboard_interactive,client}, S, [{next_event, internal, Msg}]}; -%%-------------------------------------------------------------------- --spec connected({#ssh_msg_kexinit{}, binary()}, %%| %% #ssh_msg_kexdh_init{}, - #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- -connected({#ssh_msg_kexinit{}, _Payload} = Event, #state{ssh_params = Ssh0} = State0) -> +%%% ######## {connected, client|server} #### + +handle_event(_, {#ssh_msg_kexinit{},_} = Event, {connected,Role}, #state{ssh_params = Ssh0} = State0) -> {KeyInitMsg, SshPacket, Ssh} = ssh_transport:key_exchange_init_msg(Ssh0), State = State0#state{ssh_params = Ssh, - key_exchange_init_msg = KeyInitMsg, - renegotiate = true}, - send_msg(SshPacket, State), - kexinit(Event, State). - -%%-------------------------------------------------------------------- --spec handle_event(#ssh_msg_disconnect{} | #ssh_msg_ignore{} | #ssh_msg_debug{} | - #ssh_msg_unimplemented{} | {adjust_window, integer(), integer()} | - {reply_request, success | failure, integer()} | renegotiate | - data_size | {request, pid(), integer(), integer(), iolist()} | - {request, integer(), integer(), iolist()}, state_name(), - #state{}) -> gen_fsm_state_return(). - -%%-------------------------------------------------------------------- -handle_event(#ssh_msg_disconnect{description = Desc} = DisconnectMsg, _StateName, #state{} = State) -> - handle_disconnect(peer, DisconnectMsg, State), - {stop, {shutdown, Desc}, State}; + key_exchange_init_msg = KeyInitMsg}, + send_bytes(SshPacket, State), + {next_state, {kexinit,Role,renegotiate}, State, [{next_event, internal, Event}]}; + +handle_event(_, #ssh_msg_disconnect{description=Desc} = Msg, StateName, + State0 = #state{connection_state = Connection0}) -> + {disconnect, _, {{replies, Replies}, _Connection}} = + ssh_connection:handle_msg(Msg, Connection0, role(StateName)), + {Repls,State} = send_replies(Replies, State0), + disconnect_fun(Desc, State#state.opts), + {stop_and_reply, {shutdown,Desc}, Repls, State}; -handle_event(#ssh_msg_ignore{}, StateName, State) -> - {next_state, StateName, next_packet(State)}; +handle_event(_, #ssh_msg_ignore{}, StateName, State) -> + {next_state, StateName, State}; -handle_event(#ssh_msg_debug{always_display = Display, message = DbgMsg, language=Lang}, - StateName, #state{opts = Opts} = State) -> - F = proplists:get_value(ssh_msg_debug_fun, Opts, +handle_event(_, #ssh_msg_debug{always_display = Display, + message = DbgMsg, + language = Lang}, StateName, #state{opts = Opts} = State) -> + F = proplists:get_value(ssh_msg_debug_fun, Opts, fun(_ConnRef, _AlwaysDisplay, _Msg, _Language) -> ok end ), catch F(self(), Display, DbgMsg, Lang), - {next_state, StateName, next_packet(State)}; + {next_state, StateName, State}; + +handle_event(_, #ssh_msg_unimplemented{}, StateName, State) -> + {next_state, StateName, State}; + +handle_event(internal, Msg=#ssh_msg_global_request{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_request_success{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); -handle_event(#ssh_msg_unimplemented{}, StateName, State) -> - {next_state, StateName, next_packet(State)}; +handle_event(internal, Msg=#ssh_msg_request_failure{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); -handle_event(renegotiate, connected, #state{ssh_params = Ssh0} - = State) -> +handle_event(internal, Msg=#ssh_msg_channel_open{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_open_confirmation{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_open_failure{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_window_adjust{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_data{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_extended_data{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_eof{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_close{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_request{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_success{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(internal, Msg=#ssh_msg_channel_failure{}, StateName, State) -> + handle_connection_msg(Msg, StateName, State); + +handle_event(cast, renegotiate, {connected,Role}, #state{ssh_params=Ssh0} = State) -> {KeyInitMsg, SshPacket, Ssh} = ssh_transport:key_exchange_init_msg(Ssh0), - send_msg(SshPacket, State), - timer:apply_after(?REKEY_TIMOUT, gen_fsm, send_all_state_event, [self(), renegotiate]), - {next_state, kexinit, - next_packet(State#state{ssh_params = Ssh, - key_exchange_init_msg = KeyInitMsg, - renegotiate = true})}; - -handle_event(renegotiate, StateName, State) -> + send_bytes(SshPacket, State), +%%% FIXME: timer + timer:apply_after(?REKEY_TIMOUT, gen_statem, cast, [self(), renegotiate]), + {next_state, {kexinit,Role,renegotiate}, State#state{ssh_params = Ssh, + key_exchange_init_msg = KeyInitMsg}}; + +handle_event(cast, renegotiate, StateName, State) -> %% Already in key-exchange so safe to ignore {next_state, StateName, State}; %% Rekey due to sent data limit reached? -handle_event(data_size, connected, #state{ssh_params = Ssh0} = State) -> +handle_event(cast, data_size, {connected,Role}, #state{ssh_params=Ssh0} = State) -> {ok, [{send_oct,Sent0}]} = inet:getstat(State#state.socket, [send_oct]), Sent = Sent0 - State#state.last_size_rekey, MaxSent = proplists:get_value(rekey_limit, State#state.opts, 1024000000), - timer:apply_after(?REKEY_DATA_TIMOUT, gen_fsm, send_all_state_event, [self(), data_size]), +%%% FIXME: timer + timer:apply_after(?REKEY_DATA_TIMOUT, gen_statem, cast, [self(), data_size]), case Sent >= MaxSent of true -> {KeyInitMsg, SshPacket, Ssh} = ssh_transport:key_exchange_init_msg(Ssh0), - send_msg(SshPacket, State), - {next_state, kexinit, - next_packet(State#state{ssh_params = Ssh, - key_exchange_init_msg = KeyInitMsg, - renegotiate = true, - last_size_rekey = Sent0})}; + send_bytes(SshPacket, State), + {next_state, {kexinit,Role,renegotiate}, State#state{ssh_params = Ssh, + key_exchange_init_msg = KeyInitMsg, + last_size_rekey = Sent0}}; _ -> - {next_state, connected, next_packet(State)} + {next_state, {connected,Role}, State} end; -handle_event(data_size, StateName, State) -> +handle_event(cast, data_size, StateName, State) -> %% Already in key-exchange so safe to ignore {next_state, StateName, State}; -handle_event(Event, StateName, State) when StateName /= connected -> - Events = [{event, Event} | State#state.event_queue], - {next_state, StateName, State#state{event_queue = Events}}; +handle_event(cast, _, StateName, State) when StateName /= {connected,server}, + StateName /= {connected,client} -> + {next_state, StateName, State, [postpone]}; -handle_event({adjust_window, ChannelId, Bytes}, StateName, +handle_event(cast, {adjust_window,ChannelId,Bytes}, StateName={connected,_Role}, #state{connection_state = #connection{channel_cache = Cache}} = State0) -> - State = - case ssh_channel:cache_lookup(Cache, ChannelId) of - #channel{recv_window_size = WinSize, - recv_window_pending = Pending, - recv_packet_size = PktSize} = Channel - when (WinSize-Bytes) >= 2*PktSize -> - %% The peer can send at least two more *full* packet, no hurry. - ssh_channel:cache_update(Cache, - Channel#channel{recv_window_pending = Pending + Bytes}), - State0; - - #channel{recv_window_size = WinSize, - recv_window_pending = Pending, - remote_id = Id} = Channel -> - %% Now we have to update the window - we can't receive so many more pkts - ssh_channel:cache_update(Cache, - Channel#channel{recv_window_size = - WinSize + Bytes + Pending, - recv_window_pending = 0}), - Msg = ssh_connection:channel_adjust_window_msg(Id, Bytes + Pending), - send_replies([{connection_reply, Msg}], State0); + case ssh_channel:cache_lookup(Cache, ChannelId) of + #channel{recv_window_size = WinSize, + recv_window_pending = Pending, + recv_packet_size = PktSize} = Channel + when (WinSize-Bytes) >= 2*PktSize -> + %% The peer can send at least two more *full* packet, no hurry. + ssh_channel:cache_update(Cache, + Channel#channel{recv_window_pending = Pending + Bytes}), + {next_state, StateName, State0}; + + #channel{recv_window_size = WinSize, + recv_window_pending = Pending, + remote_id = Id} = Channel -> + %% Now we have to update the window - we can't receive so many more pkts + ssh_channel:cache_update(Cache, + Channel#channel{recv_window_size = + WinSize + Bytes + Pending, + recv_window_pending = 0}), + Msg = ssh_connection:channel_adjust_window_msg(Id, Bytes + Pending), + {next_state, StateName, send_msg(Msg,State0)}; + + undefined -> + {next_state, StateName, State0} + end; - undefined -> - State0 - end, - {next_state, StateName, next_packet(State)}; - -handle_event({reply_request, success, ChannelId}, StateName, - #state{connection_state = - #connection{channel_cache = Cache}} = State0) -> - State = case ssh_channel:cache_lookup(Cache, ChannelId) of - #channel{remote_id = RemoteId} -> - Msg = ssh_connection:channel_success_msg(RemoteId), - send_replies([{connection_reply, Msg}], State0); - undefined -> - State0 - end, - {next_state, StateName, State}; +handle_event(cast, {reply_request,success,ChannelId}, StateName={connected,_}, + #state{connection_state = + #connection{channel_cache = Cache}} = State0) -> + case ssh_channel:cache_lookup(Cache, ChannelId) of + #channel{remote_id = RemoteId} -> + Msg = ssh_connection:channel_success_msg(RemoteId), + {next_state, StateName, send_msg(Msg,State0)}; + + undefined -> + {next_state, StateName, State0} + end; -handle_event({request, ChannelPid, ChannelId, Type, Data}, StateName, State0) -> - {{replies, Replies}, State1} = handle_request(ChannelPid, ChannelId, - Type, Data, - false, none, State0), - State = send_replies(Replies, State1), - {next_state, StateName, next_packet(State)}; +handle_event(cast, {request,ChannelPid,ChannelId,Type,Data}, StateName={connected,_}, State0) -> + State = handle_request(ChannelPid, ChannelId, Type, Data, false, none, State0), + {next_state, StateName, State}; -handle_event({request, ChannelId, Type, Data}, StateName, State0) -> - {{replies, Replies}, State1} = handle_request(ChannelId, Type, Data, - false, none, State0), - State = send_replies(Replies, State1), - {next_state, StateName, next_packet(State)}; +handle_event(cast, {request,ChannelId,Type,Data}, StateName={connected,_}, State0) -> + State = handle_request(ChannelId, Type, Data, false, none, State0), + {next_state, StateName, State}; -handle_event({unknown, Data}, StateName, State) -> +handle_event(cast, {unknown,Data}, StateName={connected,_}, State) -> Msg = #ssh_msg_unimplemented{sequence = Data}, - send_msg(Msg, State), - {next_state, StateName, next_packet(State)}. + {next_state, StateName, send_msg(Msg,State)}; -%%-------------------------------------------------------------------- --spec handle_sync_event({request, pid(), channel_id(), integer(), binary(), timeout()} | - {request, channel_id(), integer(), binary(), timeout()} | - {global_request, pid(), integer(), boolean(), binary()} | {eof, integer()} | - {open, pid(), integer(), channel_id(), integer(), binary(), _} | - {send_window, channel_id()} | {recv_window, channel_id()} | - {connection_info, [client_version | server_version | peer | - sockname]} | {channel_info, channel_id(), [recv_window | - send_window]} | - {close, channel_id()} | stop, term(), state_name(), #state{}) - -> gen_fsm_sync_return(). -%%-------------------------------------------------------------------- -handle_sync_event(get_print_info, _From, StateName, State) -> +%%% Previously handle_sync_event began here +handle_event({call,From}, get_print_info, StateName, State) -> Reply = try {inet:sockname(State#state.socket), @@ -818,25 +776,24 @@ handle_sync_event(get_print_info, _From, StateName, State) -> catch _:_ -> {{"?",0},"?"} end, - {reply, Reply, StateName, State}; + {next_state, StateName, State, [{reply,From,Reply}]}; -handle_sync_event({connection_info, Options}, _From, StateName, State) -> +handle_event({call,From}, {connection_info, Options}, StateName, State) -> Info = ssh_info(Options, State, []), - {reply, Info, StateName, State}; + {next_state, StateName, State, [{reply,From,Info}]}; -handle_sync_event({channel_info, ChannelId, Options}, _From, StateName, - #state{connection_state = #connection{channel_cache = Cache}} = State) -> +handle_event({call,From}, {channel_info,ChannelId,Options}, StateName, + State=#state{connection_state = #connection{channel_cache = Cache}}) -> case ssh_channel:cache_lookup(Cache, ChannelId) of #channel{} = Channel -> Info = ssh_channel_info(Options, Channel, []), - {reply, Info, StateName, State}; + {next_state, StateName, State, [{reply,From,Info}]}; undefined -> - {reply, [], StateName, State} + {next_state, StateName, State, [{reply,From,[]}]} end; -handle_sync_event({info, ChannelPid}, _From, StateName, - #state{connection_state = - #connection{channel_cache = Cache}} = State) -> +handle_event({call,From}, {info, ChannelPid}, StateName, State = #state{connection_state = + #connection{channel_cache = Cache}}) -> Result = ssh_channel:cache_foldl( fun(Channel, Acc) when ChannelPid == all; Channel#channel.user == ChannelPid -> @@ -844,86 +801,74 @@ handle_sync_event({info, ChannelPid}, _From, StateName, (_, Acc) -> Acc end, [], Cache), - {reply, {ok, Result}, StateName, State}; + {next_state, StateName, State, [{reply, From, {ok,Result}}]}; -handle_sync_event(stop, _, _StateName, #state{connection_state = Connection0, - role = Role} = State0) -> +handle_event({call,From}, stop, StateName, #state{connection_state = Connection0} = State0) -> {disconnect, _Reason, {{replies, Replies}, Connection}} = ssh_connection:handle_msg(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, - description = "User closed down connection", - language = "en"}, Connection0, Role), - State = send_replies(Replies, State0), - {stop, normal, ok, State#state{connection_state = Connection}}; - + description = "User closed down connection"}, + Connection0, role(StateName)), + {Repls,State} = send_replies(Replies, State0), + {stop_and_reply, normal, [{reply,From,ok}|Repls], State#state{connection_state=Connection}}; -handle_sync_event(Event, From, StateName, State) when StateName /= connected -> - Events = [{sync, Event, From} | State#state.event_queue], - {next_state, StateName, State#state{event_queue = Events}}; +handle_event({call,_}, _, StateName, State) when StateName /= {connected,server}, + StateName /= {connected,client} -> + {next_state, StateName, State, [postpone]}; -handle_sync_event({request, ChannelPid, ChannelId, Type, Data, Timeout}, From, StateName, State0) -> - {{replies, Replies}, State1} = handle_request(ChannelPid, - ChannelId, Type, Data, - true, From, State0), +handle_event({call,From}, {request, ChannelPid, ChannelId, Type, Data, Timeout}, StateName={connected,_}, State0) -> + State = handle_request(ChannelPid, ChannelId, Type, Data, true, From, State0), %% Note reply to channel will happen later when %% reply is recived from peer on the socket - State = send_replies(Replies, State1), start_timeout(ChannelId, From, Timeout), handle_idle_timeout(State), - {next_state, StateName, next_packet(State)}; + {next_state, StateName, State}; -handle_sync_event({request, ChannelId, Type, Data, Timeout}, From, StateName, State0) -> - {{replies, Replies}, State1} = handle_request(ChannelId, Type, Data, - true, From, State0), +handle_event({call,From}, {request, ChannelId, Type, Data, Timeout}, StateName={connected,_}, State0) -> + State = handle_request(ChannelId, Type, Data, true, From, State0), %% Note reply to channel will happen later when %% reply is recived from peer on the socket - State = send_replies(Replies, State1), start_timeout(ChannelId, From, Timeout), handle_idle_timeout(State), - {next_state, StateName, next_packet(State)}; + {next_state, StateName, State}; -handle_sync_event({global_request, Pid, _, _, _} = Request, From, StateName, - #state{connection_state = - #connection{channel_cache = Cache}} = State0) -> +handle_event({call,From}, {global_request, Pid, _, _, _} = Request, StateName={connected,_}, + #state{connection_state = #connection{channel_cache = Cache}} = State0) -> State1 = handle_global_request(Request, State0), Channel = ssh_channel:cache_find(Pid, Cache), State = add_request(true, Channel#channel.local_id, From, State1), - {next_state, StateName, next_packet(State)}; - -handle_sync_event({data, ChannelId, Type, Data, Timeout}, From, StateName, - #state{connection_state = #connection{channel_cache = _Cache} - = Connection0} = State0) -> + {next_state, StateName, State}; +handle_event({call,From}, {data, ChannelId, Type, Data, Timeout}, StateName={connected,_}, + #state{connection_state = #connection{channel_cache=_Cache} = Connection0} = State0) -> case ssh_connection:channel_data(ChannelId, Type, Data, Connection0, From) of {{replies, Replies}, Connection} -> - State = send_replies(Replies, State0#state{connection_state = Connection}), + {Repls,State} = send_replies(Replies, State0#state{connection_state = Connection}), start_timeout(ChannelId, From, Timeout), - {next_state, StateName, next_packet(State)}; + {next_state, StateName, State, Repls}; {noreply, Connection} -> start_timeout(ChannelId, From, Timeout), - {next_state, StateName, next_packet(State0#state{connection_state = Connection})} + {next_state, StateName, State0#state{connection_state = Connection}} end; -handle_sync_event({eof, ChannelId}, _From, StateName, - #state{connection_state = - #connection{channel_cache = Cache}} = State0) -> +handle_event({call,From}, {eof, ChannelId}, StateName={connected,_}, + #state{connection_state = #connection{channel_cache=Cache}} = State0) -> case ssh_channel:cache_lookup(Cache, ChannelId) of #channel{remote_id = Id, sent_close = false} -> - State = send_replies([{connection_reply, - ssh_connection:channel_eof_msg(Id)}], State0), - {reply, ok, StateName, next_packet(State)}; + State = send_msg(ssh_connection:channel_eof_msg(Id), State0), + {next_state, StateName, State, [{reply,From,ok}]}; _ -> - {reply, {error,closed}, StateName, State0} + {next_state, StateName, State0, [{reply,From,{error,closed}}]} end; -handle_sync_event({open, ChannelPid, Type, InitialWindowSize, MaxPacketSize, Data, Timeout}, - From, StateName, #state{connection_state = - #connection{channel_cache = Cache}} = State0) -> +handle_event({call,From}, {open, ChannelPid, Type, InitialWindowSize, MaxPacketSize, Data, Timeout}, + StateName={connected,_}, + #state{connection_state = #connection{channel_cache = Cache}} = State0) -> erlang:monitor(process, ChannelPid), {ChannelId, State1} = new_channel_id(State0), Msg = ssh_connection:channel_open_msg(Type, ChannelId, InitialWindowSize, MaxPacketSize, Data), - State2 = send_replies([{connection_reply, Msg}], State1), + State2 = send_msg(Msg, State1), Channel = #channel{type = Type, sys = "none", user = ChannelPid, @@ -935,11 +880,10 @@ handle_sync_event({open, ChannelPid, Type, InitialWindowSize, MaxPacketSize, Dat ssh_channel:cache_update(Cache, Channel), State = add_request(true, ChannelId, From, State2), start_timeout(ChannelId, From, Timeout), - {next_state, StateName, next_packet(remove_timer_ref(State))}; + {next_state, StateName, remove_timer_ref(State)}; -handle_sync_event({send_window, ChannelId}, _From, StateName, - #state{connection_state = - #connection{channel_cache = Cache}} = State) -> +handle_event({call,From}, {send_window, ChannelId}, StateName={connected,_}, + #state{connection_state = #connection{channel_cache = Cache}} = State) -> Reply = case ssh_channel:cache_lookup(Cache, ChannelId) of #channel{send_window_size = WinSize, send_packet_size = Packsize} -> @@ -947,12 +891,10 @@ handle_sync_event({send_window, ChannelId}, _From, StateName, undefined -> {error, einval} end, - {reply, Reply, StateName, next_packet(State)}; - -handle_sync_event({recv_window, ChannelId}, _From, StateName, - #state{connection_state = #connection{channel_cache = Cache}} - = State) -> + {next_state, StateName, State, [{reply,From,Reply}]}; +handle_event({call,From}, {recv_window, ChannelId}, StateName={connected,_}, + #state{connection_state = #connection{channel_cache = Cache}} = State) -> Reply = case ssh_channel:cache_lookup(Cache, ChannelId) of #channel{recv_window_size = WinSize, recv_packet_size = Packsize} -> @@ -960,127 +902,145 @@ handle_sync_event({recv_window, ChannelId}, _From, StateName, undefined -> {error, einval} end, - {reply, Reply, StateName, next_packet(State)}; - -handle_sync_event({close, ChannelId}, _, StateName, - #state{connection_state = - #connection{channel_cache = Cache}} = State0) -> - State = - case ssh_channel:cache_lookup(Cache, ChannelId) of - #channel{remote_id = Id} = Channel -> - State1 = send_replies([{connection_reply, - ssh_connection:channel_close_msg(Id)}], State0), - ssh_channel:cache_update(Cache, Channel#channel{sent_close = true}), - handle_idle_timeout(State1), - State1; - undefined -> - State0 - end, - {reply, ok, StateName, next_packet(State)}. + {next_state, StateName, State, [{reply,From,Reply}]}; -%%-------------------------------------------------------------------- --spec handle_info({atom(), port(), binary()} | {atom(), port()} | - term (), state_name(), #state{}) -> gen_fsm_state_return(). -%%-------------------------------------------------------------------- +handle_event({call,From}, {close, ChannelId}, StateName={connected,_}, + #state{connection_state = + #connection{channel_cache = Cache}} = State0) -> + case ssh_channel:cache_lookup(Cache, ChannelId) of + #channel{remote_id = Id} = Channel -> + State1 = send_msg(ssh_connection:channel_close_msg(Id), State0), + ssh_channel:cache_update(Cache, Channel#channel{sent_close = true}), + handle_idle_timeout(State1), + {next_state, StateName, State1, [{reply,From,ok}]}; + undefined -> + {next_state, StateName, State0, [{reply,From,ok}]} + end; -handle_info({Protocol, Socket, "SSH-" ++ _ = Version}, hello, - #state{socket = Socket, - transport_protocol = Protocol} = State ) -> - event({version_exchange, Version}, hello, State); - -handle_info({Protocol, Socket, Info}, hello, - #state{socket = Socket, - transport_protocol = Protocol} = State) -> - event({info_line, Info}, hello, State); - -handle_info({Protocol, Socket, Data}, StateName, - #state{socket = Socket, - transport_protocol = Protocol, - ssh_params = Ssh0, - decoded_data_buffer = DecData0, - encoded_data_buffer = EncData0, - undecoded_packet_length = RemainingSshPacketLen0} = State0) -> +handle_event(info, {Protocol, Socket, "SSH-" ++ _ = Version}, StateName={hello,_}, + State=#state{socket = Socket, + transport_protocol = Protocol}) -> + {next_state, StateName, State, [{next_event, internal, {version_exchange,Version}}]}; + +handle_event(info, {Protocol, Socket, Info}, StateName={hello,_}, + State=#state{socket = Socket, + transport_protocol = Protocol}) -> + {next_state, StateName, State, [{next_event, internal, {info_line,Info}}]}; + +handle_event(info, {Protocol, Socket, Data}, StateName, State0 = + #state{socket = Socket, + transport_protocol = Protocol, + decoded_data_buffer = DecData0, + encoded_data_buffer = EncData0, + undecoded_packet_length = RemainingSshPacketLen0, + ssh_params = Ssh0}) -> +?IO_FORMAT('~p Recv tcp~n',[self()]), Encoded = <>, - try ssh_transport:handle_packet_part(DecData0, Encoded, RemainingSshPacketLen0, Ssh0) + try ssh_transport:handle_packet_part(DecData0, Encoded, RemainingSshPacketLen0, Ssh0) of + {decoded, Bytes, EncDataRest, Ssh1} -> + State = State0#state{ssh_params = + Ssh1#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh1#ssh.recv_sequence)}, + decoded_data_buffer = <<>>, + undecoded_packet_length = undefined, + encoded_data_buffer = EncDataRest}, + try + ssh_message:decode(set_prefix_if_trouble(Bytes,State)) + of + Msg = #ssh_msg_kexinit{} -> + {next_state, StateName, State, [{next_event, internal, {Msg,Bytes}}, + {next_event, internal, prepare_next_packet} + ]}; + Msg -> + {next_state, StateName, State, [{next_event, internal, Msg}, + {next_event, internal, prepare_next_packet} + ]} + catch + _C:_E -> + DisconnectMsg = + #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, + description = "Encountered unexpected input"}, + disconnect(DisconnectMsg, StateName, State) + end; + {get_more, DecBytes, EncDataRest, RemainingSshPacketLen, Ssh1} -> - {next_state, StateName, - next_packet(State0#state{encoded_data_buffer = EncDataRest, - decoded_data_buffer = DecBytes, - undecoded_packet_length = RemainingSshPacketLen, - ssh_params = Ssh1})}; - {decoded, MsgBytes, EncDataRest, Ssh1} -> - generate_event(MsgBytes, StateName, - State0#state{ssh_params = Ssh1, - %% Important to be set for - %% next_packet -%%% FIXME: the following three seem to always be set in generate_event! - decoded_data_buffer = <<>>, - undecoded_packet_length = undefined, - encoded_data_buffer = EncDataRest}, - EncDataRest); + %% Here we know that there are not enough bytes in EncDataRest to use. Must wait. + inet:setopts(Socket, [{active, once}]), + {next_state, StateName, State0#state{encoded_data_buffer = EncDataRest, + decoded_data_buffer = DecBytes, + undecoded_packet_length = RemainingSshPacketLen, + ssh_params = Ssh1}}; + {bad_mac, Ssh1} -> - DisconnectMsg = + DisconnectMsg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, - description = "Bad mac", - language = ""}, - handle_disconnect(DisconnectMsg, State0#state{ssh_params=Ssh1}); + description = "Bad mac"}, + disconnect(DisconnectMsg, StateName, State0#state{ssh_params=Ssh1}); {error, {exceeds_max_size,PacketLen}} -> - DisconnectMsg = + DisconnectMsg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, - description = "Bad packet length " - ++ integer_to_list(PacketLen), - language = ""}, - handle_disconnect(DisconnectMsg, State0) + description = "Bad packet length " + ++ integer_to_list(PacketLen)}, + disconnect(DisconnectMsg, StateName, State0) catch - _:_ -> - DisconnectMsg = + _C:_E -> + DisconnectMsg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, - description = "Bad packet", - language = ""}, - handle_disconnect(DisconnectMsg, State0) + description = "Bad packet"}, + disconnect(DisconnectMsg, StateName, State0) end; - -handle_info({CloseTag, _Socket}, _StateName, - #state{transport_close_tag = CloseTag, - ssh_params = #ssh{role = _Role, opts = _Opts}} = State) -> - DisconnectMsg = + +handle_event(internal, prepare_next_packet, StateName, State) -> + Enough = erlang:max(8, State#state.ssh_params#ssh.decrypt_block_size), + case size(State#state.encoded_data_buffer) of + Sz when Sz >= Enough -> +?IO_FORMAT('~p Send <<>> to self~n',[self()]), + self() ! {State#state.transport_protocol, State#state.socket, <<>>}; + _ -> +?IO_FORMAT('~p Set active_once~n',[self()]), + inet:setopts(State#state.socket, [{active, once}]) + end, + {next_state, StateName, State}; + +handle_event(info, {CloseTag,Socket}, StateName, + State=#state{socket = Socket, + transport_close_tag = CloseTag}) -> + DisconnectMsg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, - description = "Connection closed", - language = "en"}, - handle_disconnect(DisconnectMsg, State); + description = "Connection closed"}, + disconnect(DisconnectMsg, StateName, State); -handle_info({timeout, {_, From} = Request}, Statename, +handle_event(info, {timeout, {_, From} = Request}, StateName, #state{connection_state = #connection{requests = Requests} = Connection} = State) -> case lists:member(Request, Requests) of true -> - gen_fsm:reply(From, {error, timeout}), - {next_state, Statename, + {next_state, StateName, State#state{connection_state = Connection#connection{requests = - lists:delete(Request, Requests)}}}; + lists:delete(Request, Requests)}}, + [{reply,From,{error,timeout}}]}; false -> - {next_state, Statename, State} + {next_state, StateName, State} end; %%% Handle that ssh channels user process goes down -handle_info({'DOWN', _Ref, process, ChannelPid, _Reason}, Statename, State0) -> +handle_event(info, {'DOWN', _Ref, process, ChannelPid, _Reason}, StateName, State0) -> {{replies, Replies}, State1} = handle_channel_down(ChannelPid, State0), - State = send_replies(Replies, State1), - {next_state, Statename, next_packet(State)}; + {Repls, State} = send_replies(Replies, State1), + {next_state, StateName, State, Repls}; %%% So that terminate will be run when supervisor is shutdown -handle_info({'EXIT', _Sup, Reason}, _StateName, State) -> - {stop, {shutdown, Reason}, State}; +handle_event(info, {'EXIT', _Sup, Reason}, _, _) -> + {stop, {shutdown, Reason}}; -handle_info({check_cache, _ , _}, - StateName, #state{connection_state = - #connection{channel_cache = Cache}} = State) -> +handle_event(info, {check_cache, _ , _}, StateName, + #state{connection_state = #connection{channel_cache=Cache}} = State) -> {next_state, StateName, check_cache(State, Cache)}; -handle_info(UnexpectedMessage, StateName, #state{opts = Opts, - ssh_params = SshParams} = State) -> +handle_event(info, UnexpectedMessage, StateName, + State = #state{opts = Opts, + ssh_params = SshParams}) -> case unexpected_fun(UnexpectedMessage, Opts, SshParams) of report -> Msg = lists:flatten( @@ -1091,10 +1051,11 @@ handle_info(UnexpectedMessage, StateName, #state{opts = Opts, "Local Address: ~p\n", [UnexpectedMessage, StateName, SshParams#ssh.role, SshParams#ssh.peer, proplists:get_value(address, SshParams#ssh.opts)])), - error_logger:info_report(Msg); + error_logger:info_report(Msg), + {next_state, StateName, State}; skip -> - ok; + {next_state, StateName, State}; Other -> Msg = lists:flatten( @@ -1103,60 +1064,78 @@ handle_info(UnexpectedMessage, StateName, #state{opts = Opts, "Message: ~p\n" "Role: ~p\n" "Peer: ~p\n" - "Local Address: ~p\n", [Other, UnexpectedMessage, - SshParams#ssh.role, + "Local Address: ~p\n", [Other, UnexpectedMessage, + SshParams#ssh.role, element(2,SshParams#ssh.peer), proplists:get_value(address, SshParams#ssh.opts)] )), + error_logger:error_report(Msg), + {next_state, StateName, State} + end; - error_logger:error_report(Msg) - end, - {next_state, StateName, State}. +handle_event(internal, {disconnect,Msg,_Reason}, StateName, State) -> + disconnect(Msg, StateName, State); + +handle_event(Type, Ev, StateName, State) -> + case catch atom_to_list(element(1,Ev)) of + "ssh_msg_" ++_ when Type==internal -> + Msg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, + description = "Message in wrong state"}, + disconnect(Msg, StateName, State); + _ -> + Msg = #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, + description = "Internal error"}, + disconnect(Msg, StateName, State) + end. %%-------------------------------------------------------------------- --spec terminate(Reason::term(), state_name(), #state{}) -> _. -%%-------------------------------------------------------------------- -terminate(normal, _, #state{transport_cb = Transport, - connection_state = Connection, - socket = Socket}) -> - terminate_subsystem(Connection), - (catch Transport:close(Socket)), - ok; +terminate(normal, StateName, State) -> + ?IO_FORMAT('~p ~p:~p terminate ~p ~p~n',[self(),?MODULE,?LINE,normal,StateName]), + normal_termination(StateName, State); terminate({shutdown,{init,Reason}}, StateName, State) -> + ?IO_FORMAT('~p ~p:~p terminate ~p ~p~n',[self(),?MODULE,?LINE,{shutdown,{init,Reason}},StateName]), error_logger:info_report(io_lib:format("Erlang ssh in connection handler init: ~p~n",[Reason])), - terminate(normal, StateName, State); + normal_termination(StateName, State); + +terminate(shutdown, StateName, State) -> + %% Terminated by supervisor + ?IO_FORMAT('~p ~p:~p terminate ~p ~p~n',[self(),?MODULE,?LINE,shutdown,StateName]), + normal_termination(StateName, + #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, + description = "Application shutdown"}, + State); + +%% terminate({shutdown,Msg}, StateName, State) when is_record(Msg,ssh_msg_disconnect)-> +%% ?IO_FORMAT('~p ~p:~p terminate ~p ~p~n',[self(),?MODULE,?LINE,{shutdown,Msg},StateName]), +%% normal_termination(StateName, Msg, State); + +terminate({shutdown,_R}, StateName, State) -> + ?IO_FORMAT('~p ~p:~p terminate ~p ~p~n',[self(),?MODULE,?LINE,{shutdown,_R},StateName]), + normal_termination(StateName, State); + +terminate(Reason, StateName, State) -> + %% Others, e.g undef, {badmatch,_} + ?IO_FORMAT('~p ~p:~p terminate ~p ~p~n',[self(),?MODULE,?LINE,Reason,StateName]), + log_error(Reason), + normal_termination(StateName, + #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, + description = "Internal error"}, + State). -%% Terminated by supervisor -terminate(shutdown, StateName, #state{ssh_params = Ssh0} = State) -> - DisconnectMsg = - #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, - description = "Application shutdown", - language = "en"}, - {SshPacket, Ssh} = ssh_transport:ssh_packet(DisconnectMsg, Ssh0), - send_msg(SshPacket, State), - terminate(normal, StateName, State#state{ssh_params = Ssh}); - -terminate({shutdown, #ssh_msg_disconnect{} = Msg}, StateName, - #state{ssh_params = Ssh0} = State) -> - {SshPacket, Ssh} = ssh_transport:ssh_packet(Msg, Ssh0), - send_msg(SshPacket, State), - terminate(normal, StateName, State#state{ssh_params = Ssh}); - -terminate({shutdown, _}, StateName, State) -> - terminate(normal, StateName, State); - -terminate(Reason, StateName, #state{ssh_params = Ssh0, starter = _Pid, - connection_state = Connection} = State) -> + +normal_termination(StateName, Msg, State0) -> + State = send_msg(Msg,State0), +timer:sleep(400), %% FIXME!!! gen_tcp:shutdown instead + normal_termination(StateName, State). + +normal_termination(_StateName, #state{transport_cb = Transport, + connection_state = Connection, + socket = Socket}) -> + ?IO_FORMAT('~p ~p:~p normal_termination in state ~p~n',[self(),?MODULE,?LINE,_StateName]), terminate_subsystem(Connection), - log_error(Reason), - DisconnectMsg = - #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, - description = "Internal error", - language = "en"}, - {SshPacket, Ssh} = ssh_transport:ssh_packet(DisconnectMsg, Ssh0), - send_msg(SshPacket, State), - terminate(normal, StateName, State#state{ssh_params = Ssh}). + (catch Transport:close(Socket)), + ok. terminate_subsystem(#connection{system_supervisor = SysSup, @@ -1165,9 +1144,10 @@ terminate_subsystem(#connection{system_supervisor = SysSup, terminate_subsystem(_) -> ok. -format_status(normal, [_, State]) -> - [{data, [{"StateData", State}]}]; -format_status(terminate, [_, State]) -> + +format_status(normal, [_, _StateName, State]) -> + [{data, [{"State", State}]}]; +format_status(terminate, [_, _StateName, State]) -> SshParams0 = (State#state.ssh_params), SshParams = SshParams0#ssh{c_keyinit = "***", s_keyinit = "***", @@ -1183,37 +1163,44 @@ format_status(terminate, [_, State]) -> decompress_ctx = "***", shared_secret = "***", exchanged_hash = "***", - session_id = "***", - keyex_key = "***", - keyex_info = "***", + session_id = "***", + keyex_key = "***", + keyex_info = "***", available_host_keys = "***"}, - [{data, [{"StateData", State#state{decoded_data_buffer = "***", - encoded_data_buffer = "***", - key_exchange_init_msg = "***", - opts = "***", - recbuf = "***", - ssh_params = SshParams - }}]}]. + [{data, [{"State", State#state{decoded_data_buffer = "***", + encoded_data_buffer = "***", + key_exchange_init_msg = "***", + opts = "***", + recbuf = "***", + ssh_params = SshParams + }}]}]. + -%%-------------------------------------------------------------------- --spec code_change(OldVsn::term(), state_name(), Oldstate::term(), Extra::term()) -> - {ok, state_name(), #state{}}. -%%-------------------------------------------------------------------- code_change(_OldVsn, StateName, State, _Extra) -> {ok, StateName, State}. %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- -init_role(#state{role = client, opts = Opts} = State0) -> + +%% StateName to Role +role({_,Role}) -> Role; +role({_,Role,_}) -> Role. + +renegotiation({_,_,ReNeg}) -> ReNeg == renegotiation; +renegotiation(_) -> false. + + + +init_role(client, #state{opts = Opts} = State0) -> Pid = proplists:get_value(user_pid, Opts), TimerRef = get_idle_time(Opts), - timer:apply_after(?REKEY_TIMOUT, gen_fsm, send_all_state_event, [self(), renegotiate]), - timer:apply_after(?REKEY_DATA_TIMOUT, gen_fsm, send_all_state_event, + timer:apply_after(?REKEY_TIMOUT, gen_statem, cast, [self(), renegotiate]), + timer:apply_after(?REKEY_DATA_TIMOUT, gen_statem, cast, [self(), data_size]), State0#state{starter = Pid, idle_timer_ref = TimerRef}; -init_role(#state{role = server, opts = Opts, connection_state = Connection} = State) -> +init_role(server, #state{opts = Opts, connection_state = Connection} = State) -> Sups = proplists:get_value(supervisors, Opts), Pid = proplists:get_value(user_pid, Opts), SystemSup = proplists:get_value(system_sup, Sups), @@ -1240,16 +1227,16 @@ get_idle_time(SshOptions) -> init_ssh(client = Role, Vsn, Version, Options, Socket) -> IOCb = case proplists:get_value(user_interaction, Options, true) of - true -> + true -> ssh_io; - false -> + false -> ssh_no_io end, - AuthMethods = proplists:get_value(auth_methods, Options, + AuthMethods = proplists:get_value(auth_methods, Options, ?SUPPORTED_AUTH_METHODS), {ok, PeerAddr} = inet:peername(Socket), - + PeerName = proplists:get_value(host, Options), KeyCb = proplists:get_value(key_cb, Options, ssh_file), @@ -1263,13 +1250,13 @@ init_ssh(client = Role, Vsn, Version, Options, Socket) -> userauth_supported_methods = AuthMethods, peer = {PeerName, PeerAddr}, available_host_keys = supported_host_keys(Role, KeyCb, Options), - random_length_padding = proplists:get_value(max_random_length_padding, - Options, + random_length_padding = proplists:get_value(max_random_length_padding, + Options, (#ssh{})#ssh.random_length_padding) }; init_ssh(server = Role, Vsn, Version, Options, Socket) -> - AuthMethods = proplists:get_value(auth_methods, Options, + AuthMethods = proplists:get_value(auth_methods, Options, ?SUPPORTED_AUTH_METHODS), AuthMethodsAsList = string:tokens(AuthMethods, ","), {ok, PeerAddr} = inet:peername(Socket), @@ -1286,17 +1273,17 @@ init_ssh(server = Role, Vsn, Version, Options, Socket) -> kb_tries_left = 3, peer = {undefined, PeerAddr}, available_host_keys = supported_host_keys(Role, KeyCb, Options), - random_length_padding = proplists:get_value(max_random_length_padding, - Options, + random_length_padding = proplists:get_value(max_random_length_padding, + Options, (#ssh{})#ssh.random_length_padding) }. supported_host_keys(client, _, Options) -> try - case proplists:get_value(public_key, + case proplists:get_value(public_key, proplists:get_value(preferred_algorithms,Options,[]) ) of - undefined -> + undefined -> ssh_transport:default_algorithms(public_key); L -> L -- (L--ssh_transport:default_algorithms(public_key)) @@ -1311,7 +1298,7 @@ supported_host_keys(client, _, Options) -> {stop, {shutdown, Reason}} end; supported_host_keys(server, KeyCb, Options) -> - [atom_to_list(A) || A <- proplists:get_value(public_key, + [atom_to_list(A) || A <- proplists:get_value(public_key, proplists:get_value(preferred_algorithms,Options,[]), ssh_transport:default_algorithms(public_key) ), @@ -1322,10 +1309,19 @@ supported_host_keys(server, KeyCb, Options) -> available_host_key(KeyCb, Alg, Opts) -> element(1, catch KeyCb:host_key(Alg, Opts)) == ok. -send_msg(Msg, #state{socket = Socket, transport_cb = Transport}) -> - Transport:send(Socket, Msg). -handle_version({2, 0} = NumVsn, StrVsn, Ssh0) -> +send_msg(Msg, State=#state{ssh_params=Ssh0}) when is_tuple(Msg) -> + {Bytes, Ssh} = ssh_transport:ssh_packet(Msg, Ssh0), + send_bytes(Bytes, State), + State#state{ssh_params=Ssh}. + +send_bytes(Bytes, #state{socket = Socket, transport_cb = Transport}) -> + R = Transport:send(Socket, Bytes), +?IO_FORMAT('~p send_bytes ~p~n',[self(),R]), + R. + + +handle_version({2, 0} = NumVsn, StrVsn, Ssh0) -> Ssh = counterpart_versions(NumVsn, StrVsn, Ssh0), {ok, Ssh}; handle_version(_,_,_) -> @@ -1336,161 +1332,89 @@ string_version(#ssh{role = client, c_version = Vsn}) -> string_version(#ssh{role = server, s_version = Vsn}) -> Vsn. -send_event(FsmPid, Event) -> - gen_fsm:send_event(FsmPid, Event). -send_all_state_event(FsmPid, Event) -> - gen_fsm:send_all_state_event(FsmPid, Event). +cast(FsmPid, Event) -> + gen_statem:cast(FsmPid, Event). -sync_send_all_state_event(FsmPid, Event) -> - sync_send_all_state_event(FsmPid, Event, infinity). +call(FsmPid, Event) -> + call(FsmPid, Event, infinity). -sync_send_all_state_event(FsmPid, Event, Timeout) -> - try gen_fsm:sync_send_all_state_event(FsmPid, Event, Timeout) of - {closed, _Channel} -> +call(FsmPid, Event, Timeout) -> + try gen_statem:call(FsmPid, Event, Timeout) of + {closed, _R} -> + {error, closed}; + {killed, _R} -> {error, closed}; Result -> Result catch - exit:{noproc, _} -> + exit:{noproc, _R} -> {error, closed}; - exit:{normal, _} -> + exit:{normal, _R} -> {error, closed}; - exit:{{shutdown, _},_} -> + exit:{{shutdown, _R},_} -> {error, closed} end. -%% simulate send_all_state_event(self(), Event) -event(#ssh_msg_disconnect{} = Event, StateName, State) -> - handle_event(Event, StateName, State); -event(#ssh_msg_ignore{} = Event, StateName, State) -> - handle_event(Event, StateName, State); -event(#ssh_msg_debug{} = Event, StateName, State) -> - handle_event(Event, StateName, State); -event(#ssh_msg_unimplemented{} = Event, StateName, State) -> - handle_event(Event, StateName, State); -%% simulate send_event(self(), Event) -event(Event, StateName, State) -> - try - ?MODULE:StateName(Event, State) + +handle_connection_msg(Msg, StateName, State0 = + #state{starter = User, + connection_state = Connection0, + event_queue = Qev0}) -> + Renegotiation = renegotiation(StateName), + Role = role(StateName), + try ssh_connection:handle_msg(Msg, Connection0, Role) of + {{replies, Replies}, Connection} -> + case StateName of + {connected,_} -> + {Repls, State} = send_replies(Replies, + State0#state{connection_state=Connection}), + {next_state, StateName, State, Repls}; + _ -> + {ConnReplies, Replies} = + lists:splitwith(fun not_connected_filter/1, Replies), + {Repls, State} = send_replies(Replies, + State0#state{event_queue = Qev0 ++ ConnReplies}), + {next_state, StateName, State, Repls} + end; + + {noreply, Connection} -> + {next_state, StateName, State0#state{connection_state = Connection}}; + + {disconnect, Reason0, {{replies, Replies}, Connection}} -> + {Repls,State} = send_replies(Replies, State0#state{connection_state = Connection}), + case {Reason0,Role} of + {{_, Reason}, client} when ((StateName =/= {connected,client}) and (not Renegotiation)) -> + User ! {self(), not_connected, Reason}; + _ -> + ok + end, + {stop, {shutdown,normal}, Repls, State#state{connection_state = Connection}} + catch - throw:#ssh_msg_disconnect{} = DisconnectMsg -> - handle_disconnect(DisconnectMsg, State); - throw:{ErrorToDisplay, #ssh_msg_disconnect{} = DisconnectMsg} -> - handle_disconnect(DisconnectMsg, State, ErrorToDisplay); - _C:_Error -> - handle_disconnect(#ssh_msg_disconnect{code = error_code(StateName), - description = "Invalid state", - language = "en"}, State) + _:Error -> + {disconnect, _Reason, {{replies, Replies}, Connection}} = + ssh_connection:handle_msg( + #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, + description = "Internal error"}, + Connection0, Role), + {Repls,State} = send_replies(Replies, State0#state{connection_state = Connection}), + {stop, {shutdown,Error}, Repls, State#state{connection_state = Connection}} end. -error_code(key_exchange) -> - ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED; -error_code(new_keys) -> - ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED; -error_code(_) -> - ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE. - -generate_event(<> = Msg, StateName, - #state{ - role = Role, - starter = User, - renegotiate = Renegotiation, - connection_state = Connection0} = State0, EncData) - when Byte == ?SSH_MSG_GLOBAL_REQUEST; - Byte == ?SSH_MSG_REQUEST_SUCCESS; - Byte == ?SSH_MSG_REQUEST_FAILURE; - Byte == ?SSH_MSG_CHANNEL_OPEN; - Byte == ?SSH_MSG_CHANNEL_OPEN_CONFIRMATION; - Byte == ?SSH_MSG_CHANNEL_OPEN_FAILURE; - Byte == ?SSH_MSG_CHANNEL_WINDOW_ADJUST; - Byte == ?SSH_MSG_CHANNEL_DATA; - Byte == ?SSH_MSG_CHANNEL_EXTENDED_DATA; - Byte == ?SSH_MSG_CHANNEL_EOF; - Byte == ?SSH_MSG_CHANNEL_CLOSE; - Byte == ?SSH_MSG_CHANNEL_REQUEST; - Byte == ?SSH_MSG_CHANNEL_SUCCESS; - Byte == ?SSH_MSG_CHANNEL_FAILURE -> - try - ssh_message:decode(Msg) - of - ConnectionMsg -> - State1 = generate_event_new_state(State0, EncData), - try ssh_connection:handle_msg(ConnectionMsg, Connection0, Role) of - {{replies, Replies0}, Connection} -> - if StateName == connected -> - Replies = Replies0, - State2 = State1; - true -> - {ConnReplies, Replies} = - lists:splitwith(fun not_connected_filter/1, Replies0), - Q = State1#state.event_queue ++ ConnReplies, - State2 = State1#state{ event_queue = Q } - end, - State = send_replies(Replies, State2#state{connection_state = Connection}), - {next_state, StateName, next_packet(State)}; - {noreply, Connection} -> - {next_state, StateName, next_packet(State1#state{connection_state = Connection})}; - {disconnect, {_, Reason}, {{replies, Replies}, Connection}} when - Role == client andalso ((StateName =/= connected) and (not Renegotiation)) -> - State = send_replies(Replies, State1#state{connection_state = Connection}), - User ! {self(), not_connected, Reason}, - {stop, {shutdown, normal}, - next_packet(State#state{connection_state = Connection})}; - {disconnect, _Reason, {{replies, Replies}, Connection}} -> - State = send_replies(Replies, State1#state{connection_state = Connection}), - {stop, {shutdown, normal}, State#state{connection_state = Connection}} - catch - _:Error -> - {disconnect, _Reason, {{replies, Replies}, Connection}} = - ssh_connection:handle_msg( - #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, - description = "Internal error", - language = "en"}, Connection0, Role), - State = send_replies(Replies, State1#state{connection_state = Connection}), - {stop, {shutdown, Error}, State#state{connection_state = Connection}} - end - catch - _:_ -> - handle_disconnect( - #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, - description = "Bad packet received", - language = ""}, State0) - end; -generate_event(Msg, StateName, State0, EncData) -> - try - Event = ssh_message:decode(set_prefix_if_trouble(Msg,State0)), - State = generate_event_new_state(State0, EncData), - case Event of - #ssh_msg_kexinit{} -> - %% We need payload for verification later. - event({Event, Msg}, StateName, State); - _ -> - event(Event, StateName, State) - end - catch - _C:_E -> - DisconnectMsg = - #ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR, - description = "Encountered unexpected input", - language = "en"}, - handle_disconnect(DisconnectMsg, State0) - end. - - -set_prefix_if_trouble(Msg = <>, #state{ssh_params=SshParams}) +set_prefix_if_trouble(Msg = <>, #state{ssh_params=SshParams}) when Op == 30; Op == 31 -> case catch atom_to_list(kex(SshParams)) of - "ecdh-sha2-" ++ _ -> + "ecdh-sha2-" ++ _ -> <<"ecdh",Msg/binary>>; "diffie-hellman-group-exchange-" ++ _ -> <<"dh_gex",Msg/binary>>; "diffie-hellman-group" ++ _ -> <<"dh",Msg/binary>>; - _ -> + _ -> Msg end; set_prefix_if_trouble(Msg, _) -> @@ -1499,7 +1423,7 @@ set_prefix_if_trouble(Msg, _) -> kex(#ssh{algorithms=#alg{kex=Kex}}) -> Kex; kex(_) -> undefined. - +%%%---------------------------------------------------------------- handle_request(ChannelPid, ChannelId, Type, Data, WantReply, From, #state{connection_state = #connection{channel_cache = Cache}} = State0) -> @@ -1508,11 +1432,9 @@ handle_request(ChannelPid, ChannelId, Type, Data, WantReply, From, update_sys(Cache, Channel, Type, ChannelPid), Msg = ssh_connection:channel_request_msg(Id, Type, WantReply, Data), - Replies = [{connection_reply, Msg}], - State = add_request(WantReply, ChannelId, From, State0), - {{replies, Replies}, State}; + send_msg(Msg, add_request(WantReply, ChannelId, From, State0)); undefined -> - {{replies, []}, State0} + State0 end. handle_request(ChannelId, Type, Data, WantReply, From, @@ -1522,13 +1444,12 @@ handle_request(ChannelId, Type, Data, WantReply, From, #channel{remote_id = Id} -> Msg = ssh_connection:channel_request_msg(Id, Type, WantReply, Data), - Replies = [{connection_reply, Msg}], - State = add_request(WantReply, ChannelId, From, State0), - {{replies, Replies}, State}; + send_msg(Msg, add_request(WantReply, ChannelId, From, State0)); undefined -> - {{replies, []}, State0} + State0 end. +%%%---------------------------------------------------------------- handle_global_request({global_request, ChannelPid, "tcpip-forward" = Type, WantReply, < Connection = ssh_connection:unbind(IP, Port, Connection0), Msg = ssh_connection:global_request_msg(Type, WantReply, Data), - send_replies([{connection_reply, Msg}], State#state{connection_state = Connection}); + send_msg(Msg, State#state{connection_state = Connection}); handle_global_request({global_request, _, "cancel-tcpip-forward" = Type, WantReply, Data}, State) -> Msg = ssh_connection:global_request_msg(Type, WantReply, Data), - send_replies([{connection_reply, Msg}], State). + send_msg(Msg, State). +%%%---------------------------------------------------------------- handle_idle_timeout(#state{opts = Opts}) -> case proplists:get_value(idle_time, Opts, infinity) of infinity -> @@ -1594,21 +1516,10 @@ new_channel_id(#state{connection_state = #connection{channel_id_seed = Id} = {Id, State#state{connection_state = Connection#connection{channel_id_seed = Id + 1}}}. -generate_event_new_state(#state{ssh_params = - #ssh{recv_sequence = SeqNum0} - = Ssh} = State, EncData) -> - SeqNum = ssh_transport:next_seqnum(SeqNum0), - State#state{ssh_params = Ssh#ssh{recv_sequence = SeqNum}, - decoded_data_buffer = <<>>, - encoded_data_buffer = EncData, - undecoded_packet_length = undefined}. - -next_packet(#state{decoded_data_buffer = <<>>, - encoded_data_buffer = Buff, - ssh_params = #ssh{decrypt_block_size = BlockSize}, - socket = Socket, - transport_protocol = Protocol} = State) when Buff =/= <<>> -> - case size(Buff) >= erlang:max(8, BlockSize) of +prepare_for_next_packet(State = #state{transport_protocol = Protocol, + socket = Socket}, + Ssh, EncDataRest) -> + case size(EncDataRest) >= erlang:max(8, Ssh#ssh.decrypt_block_size) of true -> %% Enough data from the next packet has been received to %% decode the length indicator, fake a socket-recive @@ -1617,84 +1528,37 @@ next_packet(#state{decoded_data_buffer = <<>>, false -> inet:setopts(Socket, [{active, once}]) end, - State; - -next_packet(#state{socket = Socket} = State) -> - inet:setopts(Socket, [{active, once}]), - State. - -after_new_keys(#state{renegotiate = true} = State) -> - State1 = State#state{renegotiate = false, event_queue = []}, - lists:foldr(fun after_new_keys_events/2, {next_state, connected, State1}, State#state.event_queue); -after_new_keys(#state{renegotiate = false, - ssh_params = #ssh{role = client} = Ssh0} = State) -> - {Msg, Ssh} = ssh_auth:service_request_msg(Ssh0), - send_msg(Msg, State), - {next_state, service_request, State#state{ssh_params = Ssh}}; -after_new_keys(#state{renegotiate = false, - ssh_params = #ssh{role = server}} = State) -> - {next_state, service_request, State}. - -after_new_keys_events({sync, _Event, From}, {stop, _Reason, _StateData}=Terminator) -> - gen_fsm:reply(From, {error, closed}), - Terminator; -after_new_keys_events(_, {stop, _Reason, _StateData}=Terminator) -> - Terminator; -after_new_keys_events({sync, Event, From}, {next_state, StateName, StateData}) -> - case handle_sync_event(Event, From, StateName, StateData) of - {reply, Reply, NextStateName, NewStateData} -> - gen_fsm:reply(From, Reply), - {next_state, NextStateName, NewStateData}; - {next_state, NextStateName, NewStateData}-> - {next_state, NextStateName, NewStateData}; - {stop, Reason, Reply, NewStateData} -> - gen_fsm:reply(From, Reply), - {stop, Reason, NewStateData} - end; -after_new_keys_events({event, Event}, {next_state, StateName, StateData}) -> - case handle_event(Event, StateName, StateData) of - {next_state, NextStateName, NewStateData}-> - {next_state, NextStateName, NewStateData}; - {stop, Reason, NewStateData} -> - {stop, Reason, NewStateData} - end; -after_new_keys_events({connection_reply, _Data} = Reply, {StateName, State}) -> - NewState = send_replies([Reply], State), - {next_state, StateName, NewState}. + State#state{ssh_params = + Ssh#ssh{recv_sequence = ssh_transport:next_seqnum(Ssh#ssh.recv_sequence)}, + decoded_data_buffer = <<>>, + undecoded_packet_length = undefined, + encoded_data_buffer = EncDataRest}. +%%%---------------------------------------------------------------- +%%% Some other module has decided to disconnect: +disconnect(Msg = #ssh_msg_disconnect{}) -> + throw({keep_state_and_data, + [{next_event, internal, {disconnect, Msg, Msg#ssh_msg_disconnect.description}}]}). -handle_disconnect(DisconnectMsg, State) -> - handle_disconnect(own, DisconnectMsg, State). +disconnect(Msg = #ssh_msg_disconnect{}, ExtraInfo) -> + throw({keep_state_and_data, + [{next_event, internal, {disconnect, Msg, {Msg#ssh_msg_disconnect.description,ExtraInfo}}}]}). -handle_disconnect(#ssh_msg_disconnect{} = DisconnectMsg, State, Error) -> - handle_disconnect(own, DisconnectMsg, State, Error); -handle_disconnect(Type, #ssh_msg_disconnect{description = Desc} = Msg, #state{connection_state = Connection0, role = Role} = State0) -> - {disconnect, _, {{replies, Replies}, Connection}} = ssh_connection:handle_msg(Msg, Connection0, Role), - State = send_replies(disconnect_replies(Type, Msg, Replies), State0), - disconnect_fun(Desc, State#state.opts), - {stop, {shutdown, Desc}, State#state{connection_state = Connection}}. -handle_disconnect(Type, #ssh_msg_disconnect{description = Desc} = Msg, #state{connection_state = Connection0, - role = Role} = State0, ErrorMsg) -> - {disconnect, _, {{replies, Replies}, Connection}} = ssh_connection:handle_msg(Msg, Connection0, Role), - State = send_replies(disconnect_replies(Type, Msg, Replies), State0), - disconnect_fun(Desc, State#state.opts), - {stop, {shutdown, {Desc, ErrorMsg}}, State#state{connection_state = Connection}}. - -disconnect_replies(own, Msg, Replies) -> - [{connection_reply, Msg} | Replies]; -disconnect_replies(peer, _, Replies) -> - Replies. +%% %%% This server/client has decided to disconnect via the state machine: +disconnect(Msg=#ssh_msg_disconnect{description=Description}, _StateName, State0) -> + ?IO_FORMAT('~p ~p:~p disconnect ~p ~p~n',[self(),?MODULE,?LINE,Msg,_StateName]), + State = send_msg(Msg, State0), + disconnect_fun(Description, State#state.opts), +timer:sleep(400), + {stop, {shutdown,Description}, State}. +%%%---------------------------------------------------------------- counterpart_versions(NumVsn, StrVsn, #ssh{role = server} = Ssh) -> Ssh#ssh{c_vsn = NumVsn , c_version = StrVsn}; counterpart_versions(NumVsn, StrVsn, #ssh{role = client} = Ssh) -> Ssh#ssh{s_vsn = NumVsn , s_version = StrVsn}. -opposite_role(client) -> - server; -opposite_role(server) -> - client. connected_fun(User, PeerAddr, Method, Opts) -> case proplists:get_value(connectfun, Opts) of undefined -> @@ -1739,7 +1603,7 @@ ssh_info([client_version | Rest], #state{ssh_params = #ssh{c_vsn = IntVsn, ssh_info([server_version | Rest], #state{ssh_params =#ssh{s_vsn = IntVsn, s_version = StringVsn}} = State, Acc) -> ssh_info(Rest, State, [{server_version, {IntVsn, StringVsn}} | Acc]); -ssh_info([peer | Rest], #state{ssh_params = #ssh{peer = Peer}} = State, Acc) -> +ssh_info([peer | Rest], #state{ssh_params = #ssh{peer = Peer}} = State, Acc) -> ssh_info(Rest, State, [{peer, Peer} | Acc]); ssh_info([sockname | Rest], #state{socket = Socket} = State, Acc) -> {ok, SockName} = inet:sockname(Socket), @@ -1749,6 +1613,7 @@ ssh_info([user | Rest], #state{auth_user = User} = State, Acc) -> ssh_info([ _ | Rest], State, Acc) -> ssh_info(Rest, State, Acc). + ssh_channel_info([], _, Acc) -> Acc; @@ -1765,38 +1630,48 @@ ssh_channel_info([send_window | Rest], #channel{send_window_size = WinSize, ssh_channel_info([ _ | Rest], Channel, Acc) -> ssh_channel_info(Rest, Channel, Acc). + log_error(Reason) -> - Report = io_lib:format("Erlang ssh connection handler failed with reason: " - "~p ~n, Stacktrace: ~p ~n", - [Reason, erlang:get_stacktrace()]), - error_logger:error_report(Report), - "Internal error". - -not_connected_filter({connection_reply, _Data}) -> - true; -not_connected_filter(_) -> - false. - -send_replies([], State) -> - State; -send_replies([{connection_reply, Data} | Rest], #state{ssh_params = Ssh0} = State) -> - {Packet, Ssh} = ssh_transport:ssh_packet(Data, Ssh0), - send_msg(Packet, State), - send_replies(Rest, State#state{ssh_params = Ssh}); -send_replies([Msg | Rest], State) -> - catch send_reply(Msg), - send_replies(Rest, State). - -send_reply({channel_data, Pid, Data}) -> - Pid ! {ssh_cm, self(), Data}; -send_reply({channel_requst_reply, From, Data}) -> - gen_fsm:reply(From, Data); -send_reply({flow_control, Cache, Channel, From, Msg}) -> + Report = io_lib:format("Erlang ssh connection handler failed with reason:~n" + " ~p~n" + "Stacktrace:~n" + " ~p~n", + [Reason, erlang:get_stacktrace()]), + error_logger:error_report(Report). + + +%%%---------------------------------------------------------------- +not_connected_filter({connection_reply, _Data}) -> true; +not_connected_filter(_) -> false. + +%%%---------------------------------------------------------------- +send_replies(Repls, State) -> + lists:foldl(fun get_repl/2, + {[],State}, + Repls). + +get_repl({connection_reply,Msg}, {CallRepls,S}) -> + {CallRepls, send_msg(Msg,S)}; +get_repl({channel_data,undefined,Data}, Acc) -> + Acc; +get_repl({channel_data,Pid,Data}, Acc) -> + Pid ! {ssh_cm, self(), Data}, + Acc; +get_repl({channel_request_reply,From,Data}, {CallRepls,S}) -> + {[{reply,From,Data}|CallRepls], S}; +get_repl({flow_control,Cache,Channel,From,Msg}, {CallRepls,S}) -> ssh_channel:cache_update(Cache, Channel#channel{flow_control = undefined}), - gen_fsm:reply(From, Msg); -send_reply({flow_control, From, Msg}) -> - gen_fsm:reply(From, Msg). + {[{reply,From,Msg}|CallRepls], S}; +get_repl({flow_control,From,Msg}, {CallRepls,S}) -> + {[{reply,From,Msg}|CallRepls], S}; +get_repl(noreply, Acc) -> + Acc; +get_repl(X, Acc) -> + exit({get_repl,X,Acc}). + + +%%%---------------------------------------------------------------- disconnect_fun({disconnect,Msg}, Opts) -> disconnect_fun(Msg, Opts); disconnect_fun(_, undefined) -> @@ -1814,7 +1689,7 @@ unexpected_fun(UnexpectedMessage, Opts, #ssh{peer={_,Peer}}) -> undefined -> report; Fun -> - catch Fun(UnexpectedMessage, Peer) + catch Fun(UnexpectedMessage, Peer) end. @@ -1852,7 +1727,7 @@ remove_timer_ref(State) -> socket_control(Socket, Pid, Transport) -> case Transport:controlling_process(Socket, Pid) of ok -> - send_event(Pid, socket_control); + gen_statem:cast(Pid, socket_control); {error, Reason} -> {error, Reason} end. @@ -1893,4 +1768,3 @@ getopt(Opt, Socket) -> Other -> {error, {unexpected_getopts_return, Other}} end. - -- cgit v1.2.3