diff options
-rw-r--r-- | lib/ssh/src/ssh_connection_handler.erl | 13 | ||||
-rw-r--r-- | lib/ssh/src/ssh_connection_manager.erl | 49 | ||||
-rw-r--r-- | lib/ssh/src/ssh_transport.erl | 4 |
3 files changed, 56 insertions, 10 deletions
diff --git a/lib/ssh/src/ssh_connection_handler.erl b/lib/ssh/src/ssh_connection_handler.erl index 787d82c4db..74a6ac7d19 100644 --- a/lib/ssh/src/ssh_connection_handler.erl +++ b/lib/ssh/src/ssh_connection_handler.erl @@ -223,11 +223,13 @@ key_exchange(#ssh_msg_kexdh_reply{} = Msg, catch #ssh_msg_disconnect{} = DisconnectMsg -> handle_disconnect(DisconnectMsg, State); + {ErrorToDisplay, #ssh_msg_disconnect{} = DisconnectMsg} -> + handle_disconnect(DisconnectMsg, State, ErrorToDisplay); _:Error -> Desc = log_error(Error), handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED, - description = Desc, - language = "en"}, State) + description = Desc, + language = "en"}, State) end; key_exchange(#ssh_msg_kex_dh_gex_group{} = Msg, @@ -673,6 +675,11 @@ terminate({shutdown, #ssh_msg_disconnect{} = Msg}, StateName, #state{ssh_params send_msg(SshPacket, State), ssh_connection_manager:event(Pid, Msg), terminate(normal, StateName, State#state{ssh_params = Ssh}); +terminate({shutdown, {#ssh_msg_disconnect{} = Msg, ErrorMsg}}, StateName, #state{ssh_params = Ssh0, manager = Pid} = State) -> + {SshPacket, Ssh} = ssh_transport:ssh_packet(Msg, Ssh0), + send_msg(SshPacket, State), + ssh_connection_manager:event(Pid, Msg, ErrorMsg), + terminate(normal, StateName, State#state{ssh_params = Ssh}); terminate(Reason, StateName, #state{ssh_params = Ssh0, manager = Pid} = State) -> log_error(Reason), DisconnectMsg = @@ -950,6 +957,8 @@ handle_ssh_packet(Length, StateName, #state{decoded_data_buffer = DecData0, handle_disconnect(#ssh_msg_disconnect{} = Msg, State) -> {stop, {shutdown, Msg}, State}. +handle_disconnect(#ssh_msg_disconnect{} = Msg, State, ErrorMsg) -> + {stop, {shutdown, {Msg, ErrorMsg}}, State}. counterpart_versions(NumVsn, StrVsn, #ssh{role = server} = Ssh) -> Ssh#ssh{c_vsn = NumVsn , c_version = StrVsn}; diff --git a/lib/ssh/src/ssh_connection_manager.erl b/lib/ssh/src/ssh_connection_manager.erl index 94a9ed505f..9536eb9dec 100644 --- a/lib/ssh/src/ssh_connection_manager.erl +++ b/lib/ssh/src/ssh_connection_manager.erl @@ -40,8 +40,7 @@ close/2, stop/1, send/5, send_eof/2]). --export([open_channel/6, reply_request/3, request/6, request/7, global_request/4, event/2, - cast/2]). +-export([open_channel/6, reply_request/3, request/6, request/7, global_request/4, event/2, event/3, cast/2]). %% Internal application API and spawn -export([send_msg/1, ssh_channel_info_handler/3]). @@ -110,10 +109,11 @@ global_request(ConnectionManager, Type, true = Reply, Data) -> global_request(ConnectionManager, Type, false = Reply, Data) -> cast(ConnectionManager, {global_request, self(), Type, Reply, Data}). - + +event(ConnectionManager, BinMsg, ErrorMsg) -> + call(ConnectionManager, {ssh_msg, self(), BinMsg, ErrorMsg}). event(ConnectionManager, BinMsg) -> call(ConnectionManager, {ssh_msg, self(), BinMsg}). - info(ConnectionManager) -> info(ConnectionManager, {info, all}). @@ -262,8 +262,7 @@ handle_call({ssh_msg, Pid, Msg}, From, %% To avoid that not all data sent by the other side is processes before %% possible crash in ssh_connection_handler takes down the connection. - gen_server:reply(From, ok), - + gen_server:reply(From, ok), ConnectionMsg = decode_ssh_msg(Msg), try ssh_connection:handle_msg(ConnectionMsg, Connection0, Pid, Role) of {{replies, Replies}, Connection} -> @@ -294,7 +293,45 @@ handle_call({ssh_msg, Pid, Msg}, From, disconnect_fun(Reason, SSHOpts), {stop, {shutdown, Error}, State#state{connection_state = Connection}} end; +handle_call({ssh_msg, Pid, Msg, ErrorMsg}, From, + #state{connection_state = Connection0, + role = Role, opts = Opts, connected = IsConnected, + client = ClientPid} + = State) -> + %% To avoid that not all data sent by the other side is processes before + %% possible crash in ssh_connection_handler takes down the connection. + gen_server:reply(From, ok), + ConnectionMsg = decode_ssh_msg(Msg), + try ssh_connection:handle_msg(ConnectionMsg, Connection0, Pid, Role) of + {{replies, Replies}, Connection} -> + lists:foreach(fun send_msg/1, Replies), + {noreply, State#state{connection_state = Connection}}; + {noreply, Connection} -> + {noreply, State#state{connection_state = Connection}}; + {disconnect, {_, Reason}, {{replies, Replies}, Connection}} + when Role == client andalso (not IsConnected) -> + lists:foreach(fun send_msg/1, Replies), + ClientPid ! {self(), not_connected, {Reason, ErrorMsg}}, + {stop, {shutdown, normal}, State#state{connection = Connection}}; + {disconnect, Reason, {{replies, Replies}, Connection}} -> + lists:foreach(fun send_msg/1, Replies), + SSHOpts = proplists:get_value(ssh_opts, Opts), + disconnect_fun(Reason, SSHOpts), + {stop, {shutdown, normal}, State#state{connection_state = Connection}} + catch + _:Error -> + {disconnect, Reason, {{replies, Replies}, Connection}} = + ssh_connection:handle_msg( + #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, + description = "Internal error", + language = "en"}, Connection0, undefined, + Role), + lists:foreach(fun send_msg/1, Replies), + SSHOpts = proplists:get_value(ssh_opts, Opts), + disconnect_fun(Reason, SSHOpts), + {stop, {shutdown, Error}, State#state{connection_state = Connection}} + end; handle_call({global_request, Pid, _, _, _} = Request, From, #state{connection_state = #connection{channel_cache = Cache}} = State0) -> diff --git a/lib/ssh/src/ssh_transport.erl b/lib/ssh/src/ssh_transport.erl index 1abb69921d..a47a55b707 100644 --- a/lib/ssh/src/ssh_transport.erl +++ b/lib/ssh/src/ssh_transport.erl @@ -356,12 +356,12 @@ handle_kexdh_reply(#ssh_msg_kexdh_reply{public_host_key = HostKey, f = F, {ok, SshPacket, Ssh#ssh{shared_secret = K, exchanged_hash = H, session_id = sid(Ssh, H)}}; - _Error -> + Error -> Disconnect = #ssh_msg_disconnect{ code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED, description = "Key exchange failed", language = "en"}, - throw(Disconnect) + throw({Error, Disconnect}) end. handle_kex_dh_gex_request(#ssh_msg_kex_dh_gex_request{min = _Min, |