diff options
Diffstat (limited to 'lib/ssh/src/ssh_connection_handler.erl')
-rw-r--r-- | lib/ssh/src/ssh_connection_handler.erl | 101 |
1 files changed, 68 insertions, 33 deletions
diff --git a/lib/ssh/src/ssh_connection_handler.erl b/lib/ssh/src/ssh_connection_handler.erl index 3bdca4ba94..e6e5749e07 100644 --- a/lib/ssh/src/ssh_connection_handler.erl +++ b/lib/ssh/src/ssh_connection_handler.erl @@ -483,17 +483,22 @@ userauth(#ssh_msg_userauth_request{service = "ssh-connection", service = "ssh-connection", peer = {_, Address}} = Ssh0, opts = Opts, starter = Pid} = State) -> - case ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0) of - {authorized, User, {Reply, Ssh}} -> - send_msg(Reply, State), - Pid ! ssh_connected, - connected_fun(User, Address, Method, Opts), - {next_state, connected, - next_packet(State#state{auth_user = User, ssh_params = Ssh})}; - {not_authorized, {User, Reason}, {Reply, Ssh}} -> - retry_fun(User, Address, Reason, Opts), - send_msg(Reply, State), - {next_state, userauth, next_packet(State#state{ssh_params = Ssh})} + case lists:member(Method, Ssh0#ssh.userauth_methods) of + true -> + case ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0) of + {authorized, User, {Reply, Ssh}} -> + send_msg(Reply, State), + Pid ! ssh_connected, + connected_fun(User, Address, Method, Opts), + {next_state, connected, + next_packet(State#state{auth_user = User, ssh_params = Ssh})}; + {not_authorized, {User, Reason}, {Reply, Ssh}} -> + retry_fun(User, Address, Reason, Opts), + send_msg(Reply, State), + {next_state, userauth, next_packet(State#state{ssh_params = Ssh})} + end; + false -> + userauth(Msg#ssh_msg_userauth_request{method="none"}, State) end; userauth(#ssh_msg_userauth_info_request{} = Msg, @@ -750,15 +755,12 @@ handle_sync_event({info, ChannelPid}, _From, StateName, {reply, {ok, Result}, StateName, State}; handle_sync_event(stop, _, _StateName, #state{connection_state = Connection0, - role = Role, - opts = Opts} = State0) -> - {disconnect, Reason, {{replies, Replies}, Connection}} = + role = Role} = State0) -> + {disconnect, _Reason, {{replies, Replies}, Connection}} = ssh_connection:handle_msg(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, description = "User closed down connection", language = "en"}, Connection0, Role), State = send_replies(Replies, State0), - SSHOpts = proplists:get_value(ssh_opts, Opts), - disconnect_fun(Reason, SSHOpts), {stop, normal, ok, State#state{connection_state = Connection}}; @@ -987,15 +989,38 @@ handle_info({check_cache, _ , _}, #connection{channel_cache = Cache}} = State) -> {next_state, StateName, check_cache(State, Cache)}; -handle_info(UnexpectedMessage, StateName, #state{ssh_params = SshParams} = State) -> - Msg = lists:flatten(io_lib:format( - "Unexpected message '~p' received in state '~p'\n" - "Role: ~p\n" - "Peer: ~p\n" - "Local Address: ~p\n", [UnexpectedMessage, StateName, - SshParams#ssh.role, SshParams#ssh.peer, - proplists:get_value(address, SshParams#ssh.opts)])), - error_logger:info_report(Msg), +handle_info(UnexpectedMessage, StateName, #state{opts = Opts, + ssh_params = SshParams} = State) -> + case unexpected_fun(UnexpectedMessage, Opts, SshParams) of + report -> + Msg = lists:flatten( + io_lib:format( + "Unexpected message '~p' received in state '~p'\n" + "Role: ~p\n" + "Peer: ~p\n" + "Local Address: ~p\n", [UnexpectedMessage, StateName, + SshParams#ssh.role, SshParams#ssh.peer, + proplists:get_value(address, SshParams#ssh.opts)])), + error_logger:info_report(Msg); + + skip -> + ok; + + Other -> + Msg = lists:flatten( + io_lib:format("Call to fun in 'unexpectedfun' failed:~n" + "Return: ~p\n" + "Message: ~p\n" + "Role: ~p\n" + "Peer: ~p\n" + "Local Address: ~p\n", [Other, UnexpectedMessage, + SshParams#ssh.role, + element(2,SshParams#ssh.peer), + proplists:get_value(address, SshParams#ssh.opts)] + )), + + error_logger:error_report(Msg) + end, {next_state, StateName, State}. %%-------------------------------------------------------------------- @@ -1151,9 +1176,9 @@ init_ssh(client = Role, Vsn, Version, Options, Socket) -> }; init_ssh(server = Role, Vsn, Version, Options, Socket) -> - AuthMethods = proplists:get_value(auth_methods, Options, ?SUPPORTED_AUTH_METHODS), + AuthMethodsAsList = string:tokens(AuthMethods, ","), {ok, PeerAddr} = inet:peername(Socket), KeyCb = proplists:get_value(key_cb, Options, ssh_file), @@ -1164,6 +1189,8 @@ init_ssh(server = Role, Vsn, Version, Options, Socket) -> io_cb = proplists:get_value(io_cb, Options, ssh_io), opts = Options, userauth_supported_methods = AuthMethods, + userauth_methods = AuthMethodsAsList, + kb_tries_left = 3, peer = {undefined, PeerAddr}, available_host_keys = supported_host_keys(Role, KeyCb, Options) }. @@ -1275,7 +1302,6 @@ generate_event(<<?BYTE(Byte), _/binary>> = Msg, StateName, #state{ role = Role, starter = User, - opts = Opts, renegotiate = Renegotiation, connection_state = Connection0} = State0, EncData) when Byte == ?SSH_MSG_GLOBAL_REQUEST; @@ -1315,21 +1341,17 @@ generate_event(<<?BYTE(Byte), _/binary>> = Msg, StateName, User ! {self(), not_connected, Reason}, {stop, {shutdown, normal}, next_packet(State#state{connection_state = Connection})}; - {disconnect, Reason, {{replies, Replies}, Connection}} -> + {disconnect, _Reason, {{replies, Replies}, Connection}} -> State = send_replies(Replies, State1#state{connection_state = Connection}), - SSHOpts = proplists:get_value(ssh_opts, Opts), - disconnect_fun(Reason, SSHOpts), {stop, {shutdown, normal}, State#state{connection_state = Connection}} catch _:Error -> - {disconnect, Reason, {{replies, Replies}, Connection}} = + {disconnect, _Reason, {{replies, Replies}, Connection}} = ssh_connection:handle_msg( #ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION, description = "Internal error", language = "en"}, Connection0, Role), State = send_replies(Replies, State1#state{connection_state = Connection}), - SSHOpts = proplists:get_value(ssh_opts, Opts), - disconnect_fun(Reason, SSHOpts), {stop, {shutdown, Error}, State#state{connection_state = Connection}} end; @@ -1576,12 +1598,14 @@ handle_disconnect(#ssh_msg_disconnect{} = DisconnectMsg, State, Error) -> handle_disconnect(Type, #ssh_msg_disconnect{description = Desc} = Msg, #state{connection_state = Connection0, role = Role} = State0) -> {disconnect, _, {{replies, Replies}, Connection}} = ssh_connection:handle_msg(Msg, Connection0, Role), State = send_replies(disconnect_replies(Type, Msg, Replies), State0), + disconnect_fun(Desc, State#state.opts), {stop, {shutdown, Desc}, State#state{connection_state = Connection}}. handle_disconnect(Type, #ssh_msg_disconnect{description = Desc} = Msg, #state{connection_state = Connection0, role = Role} = State0, ErrorMsg) -> {disconnect, _, {{replies, Replies}, Connection}} = ssh_connection:handle_msg(Msg, Connection0, Role), State = send_replies(disconnect_replies(Type, Msg, Replies), State0), + disconnect_fun(Desc, State#state.opts), {stop, {shutdown, {Desc, ErrorMsg}}, State#state{connection_state = Connection}}. disconnect_replies(own, Msg, Replies) -> @@ -1700,6 +1724,8 @@ send_reply({flow_control, Cache, Channel, From, Msg}) -> send_reply({flow_control, From, Msg}) -> gen_fsm:reply(From, Msg). +disconnect_fun({disconnect,Msg}, Opts) -> + disconnect_fun(Msg, Opts); disconnect_fun(_, undefined) -> ok; disconnect_fun(Reason, Opts) -> @@ -1710,6 +1736,15 @@ disconnect_fun(Reason, Opts) -> catch Fun(Reason) end. +unexpected_fun(UnexpectedMessage, Opts, #ssh{peer={_,Peer}}) -> + case proplists:get_value(unexpectedfun, Opts) of + undefined -> + report; + Fun -> + catch Fun(UnexpectedMessage, Peer) + end. + + check_cache(#state{opts = Opts} = State, Cache) -> %% Check the number of entries in Cache case proplists:get_value(size, ets:info(Cache)) of |