From 3f77d6d4326122486be1536d5184f1153117598b Mon Sep 17 00:00:00 2001 From: Hans Nilsson Date: Fri, 27 Apr 2018 15:48:48 +0200 Subject: ssh: Refactor connection_msg handling --- lib/ssh/src/ssh_connection_handler.erl | 153 ++++++++++++++------------------- 1 file changed, 66 insertions(+), 87 deletions(-) (limited to 'lib/ssh') diff --git a/lib/ssh/src/ssh_connection_handler.erl b/lib/ssh/src/ssh_connection_handler.erl index 2ad6f779d3..c9d389d887 100644 --- a/lib/ssh/src/ssh_connection_handler.erl +++ b/lib/ssh/src/ssh_connection_handler.erl @@ -559,6 +559,10 @@ renegotiation(_) -> false. #data{} ) -> gen_statem:event_handler_result(state_name()) . +-define(CONNECTION_MSG(Msg), + [{next_event, internal, prepare_next_packet}, + {next_event,internal,{conn_msg,Msg}}]). + %% . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . . callback_mode() -> @@ -1017,52 +1021,50 @@ handle_event(_, #ssh_msg_debug{} = Msg, _, D) -> debug_fun(Msg, D), keep_state_and_data; -handle_event(internal, Msg=#ssh_msg_global_request{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_request_success{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_request_failure{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_open{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_open_confirmation{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_open_failure{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_window_adjust{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_data{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_extended_data{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_eof{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_close{}, {connected,server} = StateName, D) -> - handle_connection_msg(Msg, StateName, cache_request_idle_timer_check(D)); - -handle_event(internal, Msg=#ssh_msg_channel_close{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_request{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); - -handle_event(internal, Msg=#ssh_msg_channel_success{}, StateName, D) -> - update_inet_buffers(D#data.socket), - handle_connection_msg(Msg, StateName, D); +handle_event(internal, {conn_msg,Msg}, StateName, #data{starter = User, + connection_state = Connection0, + event_queue = Qev0} = D0) -> + Role = role(StateName), + Rengotation = renegotiation(StateName), + try ssh_connection:handle_msg(Msg, Connection0, Role) of + {disconnect, Reason0, RepliesConn} -> + {Repls, D} = send_replies(RepliesConn, D0), + case {Reason0,Role} of + {{_, Reason}, client} when ((StateName =/= {connected,client}) + and (not Rengotation)) -> + User ! {self(), not_connected, Reason}; + _ -> + ok + end, + {stop_and_reply, {shutdown,normal}, Repls, D}; -handle_event(internal, Msg=#ssh_msg_channel_failure{}, StateName, D) -> - handle_connection_msg(Msg, StateName, D); + {Replies, Connection} when is_list(Replies) -> + {Repls, D} = + case StateName of + {connected,_} -> + send_replies(Replies, D0#data{connection_state=Connection}); + _ -> + {ConnReplies, NonConnReplies} = lists:splitwith(fun not_connected_filter/1, Replies), + send_replies(NonConnReplies, D0#data{event_queue = Qev0 ++ ConnReplies}) + end, + case Msg of + #ssh_msg_channel_close{} when StateName == {connected,server} -> + {keep_state, cache_request_idle_timer_check(D), Repls}; + #ssh_msg_channel_success{} -> + update_inet_buffers(D#data.socket), + {keep_state, D, Repls}; + _ -> + {keep_state, D, Repls} + end + catch + Class:Error -> + {Repls, D1} = send_replies(ssh_connection:handle_stop(Connection0), D0), + {Shutdown, D} = ?send_disconnect(?SSH_DISCONNECT_BY_APPLICATION, + io_lib:format("Internal error: ~p:~p",[Class,Error]), + StateName, D1), + {stop_and_reply, Shutdown, Repls, D} + end; handle_event(cast, renegotiate, {connected,Role}, D) -> {KeyInitMsg, SshPacket, Ssh} = ssh_transport:key_exchange_init_msg(D#data.ssh_params), @@ -1319,6 +1321,7 @@ handle_event(info, {Proto, Sock, Info}, {hello,_}, #data{socket = Sock, {keep_state_and_data, [{next_event, internal, {info_line,Info}}]} end; + handle_event(info, {Proto, Sock, NewData}, StateName, D0 = #data{socket = Sock, transport_protocol = Proto}) -> try ssh_transport:handle_packet_part( @@ -1336,13 +1339,29 @@ handle_event(info, {Proto, Sock, NewData}, StateName, D0 = #data{socket = Sock, try ssh_message:decode(set_kex_overload_prefix(DecryptedBytes,D1)) of - Msg = #ssh_msg_kexinit{} -> + #ssh_msg_kexinit{} = Msg -> {keep_state, D1, [{next_event, internal, prepare_next_packet}, {next_event, internal, {Msg,DecryptedBytes}} ]}; + + #ssh_msg_global_request{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_request_success{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_request_failure{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_open{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_open_confirmation{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_open_failure{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_window_adjust{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_data{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_extended_data{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_eof{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_close{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_request{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_failure{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + #ssh_msg_channel_success{} = Msg -> {keep_state, D1, ?CONNECTION_MSG(Msg)}; + Msg -> {keep_state, D1, [{next_event, internal, prepare_next_packet}, - {next_event, internal, Msg} + {next_event, internal, Msg} ]} catch C:E -> @@ -1433,7 +1452,7 @@ handle_event(info, {'DOWN', _Ref, process, ChannelPid, _Reason}, _, D) -> end, [], Cache), {keep_state, cache_check_set_idle_timer(D)}; -handle_event(info, idle_timer_timeout, StateName, _) -> +handle_event(info, idle_timer_timeout, _, _) -> {stop, {shutdown, "Timeout"}}; %%% So that terminate will be run when supervisor is shutdown @@ -1758,46 +1777,6 @@ call(FsmPid, Event, Timeout) -> end. -handle_connection_msg(Msg, StateName, D0 = #data{starter = User, - connection_state = Connection0, - event_queue = Qev0}) -> - Renegotiation = renegotiation(StateName), - Role = role(StateName), - try ssh_connection:handle_msg(Msg, Connection0, Role) of - {disconnect, Reason0, RepliesConn} -> - {Repls, D} = send_replies(RepliesConn, D0), - case {Reason0,Role} of - {{_, Reason}, client} when ((StateName =/= {connected,client}) and (not Renegotiation)) -> - User ! {self(), not_connected, Reason}; - _ -> - ok - end, - {stop_and_reply, {shutdown,normal}, Repls, D}; - - {[], Connection} -> - {keep_state, D0#data{connection_state = Connection}}; - - {Replies, Connection} when is_list(Replies) -> - {Repls, D} = - case StateName of - {connected,_} -> - send_replies(Replies, D0#data{connection_state=Connection}); - _ -> - {ConnReplies, NonConnReplies} = lists:splitwith(fun not_connected_filter/1, Replies), - send_replies(NonConnReplies, D0#data{event_queue = Qev0 ++ ConnReplies}) - end, - {keep_state, D, Repls} - - catch - Class:Error -> - {Repls, D1} = send_replies(ssh_connection:handle_stop(Connection0), D0), - {Shutdown, D} = ?send_disconnect(?SSH_DISCONNECT_BY_APPLICATION, - io_lib:format("Internal error: ~p:~p",[Class,Error]), - StateName, D1), - {stop_and_reply, Shutdown, Repls, D} - end. - - set_kex_overload_prefix(Msg = <>, #data{ssh_params=SshParams}) when Op == 30; Op == 31 -- cgit v1.2.3