%%
%% %CopyrightBegin%
%%
%% Copyright Ericsson AB 2008-2012. All Rights Reserved.
%%
%% The contents of this file are subject to the Erlang Public License,
%% Version 1.1, (the "License"); you may not use this file except in
%% compliance with the License. You should have received a copy of the
%% Erlang Public License along with this software. If not, it can be
%% retrieved online at http://www.erlang.org/.
%%
%% Software distributed under the License is distributed on an "AS IS"
%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
%% the License for the specific language governing rights and limitations
%% under the License.
%%
%% %CopyrightEnd%
%%
%%
%%----------------------------------------------------------------------
%% Purpose: Handles the setup of an ssh connection, e.i. both the
%% setup SSH Transport Layer Protocol (RFC 4253) and Authentication
%% Protocol (RFC 4252). Details of the different protocols are
%% implemented in ssh_transport.erl, ssh_auth.erl
%% ----------------------------------------------------------------------
-module(ssh_connection_handler).
-behaviour(gen_fsm).
-include("ssh.hrl").
-include("ssh_transport.hrl").
-include("ssh_auth.hrl").
-include("ssh_connect.hrl").
-export([start_link/4, send/2, renegotiate/1, send_event/2,
connection_info/3,
peer_address/1]).
%% gen_fsm callbacks
-export([hello/2, kexinit/2, key_exchange/2, new_keys/2,
userauth/2, connected/2]).
-export([init/1, handle_event/3,
handle_sync_event/4, handle_info/3, terminate/3, code_change/4]).
%% spawn export
-export([ssh_info_handler/3]).
-record(state, {
transport_protocol, % ex: tcp
transport_cb,
transport_close_tag,
ssh_params, % #ssh{} - from ssh.hrl
socket, % socket()
decoded_data_buffer, % binary()
encoded_data_buffer, % binary()
undecoded_packet_length, % integer()
key_exchange_init_msg, % #ssh_msg_kexinit{}
renegotiate = false, % boolean()
manager, % pid()
connection_queue,
address,
port,
opts
}).
-define(DBG_MESSAGE, true).
%%====================================================================
%% Internal application API
%%====================================================================
%%--------------------------------------------------------------------
%% Function: start_link() -> ok,Pid} | ignore | {error,Error}
%% Description:Creates a gen_fsm process which calls Module:init/1 to
%% initialize. To ensure a synchronized start-up procedure, this function
%% does not return until Module:init/1 has returned.
%%--------------------------------------------------------------------
start_link(Role, Manager, Socket, Options) ->
gen_fsm:start_link(?MODULE, [Role, Manager, Socket, Options], []).
send(ConnectionHandler, Data) ->
send_all_state_event(ConnectionHandler, {send, Data}).
renegotiate(ConnectionHandler) ->
send_all_state_event(ConnectionHandler, renegotiate).
connection_info(ConnectionHandler, From, Options) ->
send_all_state_event(ConnectionHandler, {info, From, Options}).
%% Replaced with option to connection_info/3. For now keep
%% for backwards compatibility
peer_address(ConnectionHandler) ->
sync_send_all_state_event(ConnectionHandler, peer_address).
%%====================================================================
%% gen_fsm callbacks
%%====================================================================
%%--------------------------------------------------------------------
%% Function: init(Args) -> {ok, StateName, State} |
%% {ok, StateName, State, Timeout} |
%% ignore |
%% {stop, StopReason}
%% Description:Whenever a gen_fsm is started using gen_fsm:start/[3,4] or
%% gen_fsm:start_link/3,4, this function is called by the new process to
%% initialize.
%%--------------------------------------------------------------------
init([Role, Manager, Socket, SshOpts]) ->
process_flag(trap_exit, true),
{NumVsn, StrVsn} = ssh_transport:versions(Role, SshOpts),
ssh_bits:install_messages(ssh_transport:transport_messages(NumVsn)),
{Protocol, Callback, CloseTag} =
proplists:get_value(transport, SshOpts, {tcp, gen_tcp, tcp_closed}),
try init_ssh(Role, NumVsn, StrVsn, SshOpts, Socket) of
Ssh ->
{ok, hello, #state{ssh_params =
Ssh#ssh{send_sequence = 0, recv_sequence = 0},
socket = Socket,
decoded_data_buffer = <<>>,
encoded_data_buffer = <<>>,
transport_protocol = Protocol,
transport_cb = Callback,
transport_close_tag = CloseTag,
manager = Manager,
opts = SshOpts
}}
catch
exit:Reason ->
{stop, {shutdown, Reason}}
end.
%%--------------------------------------------------------------------
%% Function:
%% state_name(Event, State) -> {next_state, NextStateName, NextState}|
%% {next_state, NextStateName,
%% NextState, Timeout} |
%% {stop, Reason, NewState}
%% Description:There should be one instance of this function for each possible
%% state name. Whenever a gen_fsm receives an event sent using
%% gen_fsm:send_event/2, the instance of this function with the same name as
%% the current state name StateName is called to handle the event. It is also
%% called if a timeout occurs.
%%--------------------------------------------------------------------
hello(socket_control, #state{socket = Socket, ssh_params = Ssh} = State) ->
VsnMsg = ssh_transport:hello_version_msg(string_version(Ssh)),
send_msg(VsnMsg, State),
inet:setopts(Socket, [{packet, line}]),
{next_state, hello, next_packet(State)};
hello({info_line, _Line}, State) ->
{next_state, hello, next_packet(State)};
hello({version_exchange, Version}, #state{ssh_params = Ssh0,
socket = Socket} = State) ->
{NumVsn, StrVsn} = ssh_transport:handle_hello_version(Version),
case handle_version(NumVsn, StrVsn, Ssh0) of
{ok, Ssh1} ->
inet:setopts(Socket, [{packet,0}, {mode,binary}]),
{KeyInitMsg, SshPacket, Ssh} = ssh_transport:key_exchange_init_msg(Ssh1),
send_msg(SshPacket, State),
{next_state, kexinit, next_packet(State#state{ssh_params = Ssh,
key_exchange_init_msg =
KeyInitMsg})};
not_supported ->
DisconnectMsg =
#ssh_msg_disconnect{code =
?SSH_DISCONNECT_PROTOCOL_VERSION_NOT_SUPPORTED,
description = "Protocol version " ++ StrVsn
++ " not supported",
language = "en"},
handle_disconnect(DisconnectMsg, State)
end.
kexinit({#ssh_msg_kexinit{} = Kex, Payload},
#state{ssh_params = #ssh{role = Role} = Ssh0,
key_exchange_init_msg = OwnKex} =
State) ->
Ssh1 = ssh_transport:key_init(opposite_role(Role), Ssh0, Payload),
try ssh_transport:handle_kexinit_msg(Kex, OwnKex, Ssh1) of
{ok, NextKexMsg, Ssh} when Role == client ->
send_msg(NextKexMsg, State),
{next_state, key_exchange,
next_packet(State#state{ssh_params = Ssh})};
{ok, Ssh} when Role == server ->
{next_state, key_exchange,
next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
description = Desc,
language = "en"}, State)
end.
key_exchange(#ssh_msg_kexdh_init{} = Msg,
#state{ssh_params = #ssh{role = server} =Ssh0} = State) ->
try ssh_transport:handle_kexdh_init(Msg, Ssh0) of
{ok, KexdhReply, Ssh1} ->
send_msg(KexdhReply, State),
{ok, NewKeys, Ssh} = ssh_transport:new_keys_message(Ssh1),
send_msg(NewKeys, State),
{next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
description = Desc,
language = "en"}, State)
end;
key_exchange(#ssh_msg_kexdh_reply{} = Msg,
#state{ssh_params = #ssh{role = client} = Ssh0} = State) ->
try ssh_transport:handle_kexdh_reply(Msg, Ssh0) of
{ok, NewKeys, Ssh} ->
send_msg(NewKeys, State),
{next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
description = Desc,
language = "en"}, State)
end;
key_exchange(#ssh_msg_kex_dh_gex_group{} = Msg,
#state{ssh_params = #ssh{role = server} = Ssh0} = State) ->
try ssh_transport:handle_kex_dh_gex_group(Msg, Ssh0) of
{ok, NextKexMsg, Ssh1} ->
send_msg(NextKexMsg, State),
{ok, NewKeys, Ssh} = ssh_transport:new_keys_message(Ssh1),
send_msg(NewKeys, State),
{next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
description = Desc,
language = "en"}, State)
end;
key_exchange(#ssh_msg_kex_dh_gex_request{} = Msg,
#state{ssh_params = #ssh{role = client} = Ssh0} = State) ->
try ssh_transport:handle_kex_dh_gex_request(Msg, Ssh0) of
{ok, NextKexMsg, Ssh} ->
send_msg(NextKexMsg, State),
{next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
description = Desc,
language = "en"}, State)
end;
key_exchange(#ssh_msg_kex_dh_gex_reply{} = Msg,
#state{ssh_params = #ssh{role = client} = Ssh0} = State) ->
try ssh_transport:handle_kex_dh_gex_reply(Msg, Ssh0) of
{ok, NewKeys, Ssh} ->
send_msg(NewKeys, State),
{next_state, new_keys, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
description = Desc,
language = "en"}, State)
end.
new_keys(#ssh_msg_newkeys{} = Msg, #state{ssh_params = Ssh0} = State0) ->
try ssh_transport:handle_new_keys(Msg, Ssh0) of
{ok, Ssh} ->
{NextStateName, State} =
after_new_keys(State0#state{ssh_params = Ssh}),
{next_state, NextStateName, next_packet(State)}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State0);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_KEY_EXCHANGE_FAILED,
description = Desc,
language = "en"}, State0)
end.
userauth(#ssh_msg_service_request{name = "ssh-userauth"} = Msg,
#state{ssh_params = #ssh{role = server,
session_id = SessionId} = Ssh0} = State) ->
ssh_bits:install_messages(ssh_auth:userauth_messages()),
try ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0) of
{ok, {Reply, Ssh}} ->
send_msg(Reply, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE,
description = Desc,
language = "en"}, State)
end;
userauth(#ssh_msg_service_accept{name = "ssh-userauth"},
#state{ssh_params = #ssh{role = client,
service = "ssh-userauth"} = Ssh0} =
State) ->
{Msg, Ssh} = ssh_auth:init_userauth_request_msg(Ssh0),
send_msg(Msg, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})};
userauth(#ssh_msg_userauth_request{service = "ssh-connection",
method = "none"} = Msg,
#state{ssh_params = #ssh{session_id = SessionId, role = server,
service = "ssh-connection"} = Ssh0
} = State) ->
try ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0) of
{not_authorized, {_User, _Reason}, {Reply, Ssh}} ->
send_msg(Reply, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE,
description = Desc,
language = "en"}, State)
end;
userauth(#ssh_msg_userauth_request{service = "ssh-connection",
method = Method} = Msg,
#state{ssh_params = #ssh{session_id = SessionId, role = server,
service = "ssh-connection",
peer = {_, Address}} = Ssh0,
opts = Opts, manager = Pid} = State) ->
try ssh_auth:handle_userauth_request(Msg, SessionId, Ssh0) of
{authorized, User, {Reply, Ssh}} ->
send_msg(Reply, State),
ssh_userreg:register_user(User, Pid),
Pid ! ssh_connected,
connected_fun(User, Address, Method, Opts),
{next_state, connected,
next_packet(State#state{ssh_params = Ssh})};
{not_authorized, {User, Reason}, {Reply, Ssh}} ->
retry_fun(User, Reason, Opts),
send_msg(Reply, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE,
description = Desc,
language = "en"}, State)
end;
userauth(#ssh_msg_userauth_info_request{} = Msg,
#state{ssh_params = #ssh{role = client,
io_cb = IoCb} = Ssh0} = State) ->
try ssh_auth:handle_userauth_info_request(Msg, IoCb, Ssh0) of
{ok, {Reply, Ssh}} ->
send_msg(Reply, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE,
description = Desc,
language = "en"}, State)
end;
userauth(#ssh_msg_userauth_info_response{} = Msg,
#state{ssh_params = #ssh{role = server} = Ssh0} = State) ->
try ssh_auth:handle_userauth_info_response(Msg, Ssh0) of
{ok, {Reply, Ssh}} ->
send_msg(Reply, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})}
catch
#ssh_msg_disconnect{} = DisconnectMsg ->
handle_disconnect(DisconnectMsg, State);
_:Error ->
Desc = log_error(Error),
handle_disconnect(#ssh_msg_disconnect{code = ?SSH_DISCONNECT_SERVICE_NOT_AVAILABLE,
description = Desc,
language = "en"}, State)
end;
userauth(#ssh_msg_userauth_success{}, #state{ssh_params = #ssh{role = client},
manager = Pid} = State) ->
Pid ! ssh_connected,
{next_state, connected, next_packet(State)};
userauth(#ssh_msg_userauth_failure{},
#state{ssh_params = #ssh{role = client,
userauth_methods = []}}
= State) ->
Msg = #ssh_msg_disconnect{code =
?SSH_DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
description = "Unable to connect using the available"
" authentication methods",
language = "en"},
handle_disconnect(Msg, State);
%% Server tells us which authentication methods that are allowed
userauth(#ssh_msg_userauth_failure{authentications = Methodes},
#state{ssh_params = #ssh{role = client,
userauth_methods = none} = Ssh0} = State) ->
AuthMethods = string:tokens(Methodes, ","),
{Msg, Ssh} = ssh_auth:userauth_request_msg(
Ssh0#ssh{userauth_methods = AuthMethods}),
send_msg(Msg, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})};
%% The prefered authentication method failed try next method
userauth(#ssh_msg_userauth_failure{},
#state{ssh_params = #ssh{role = client} = Ssh0} = State) ->
case ssh_auth:userauth_request_msg(Ssh0) of
{disconnect, DisconnectMsg,{Msg, Ssh}} ->
send_msg(Msg, State),
handle_disconnect(DisconnectMsg, State#state{ssh_params = Ssh});
{Msg, Ssh} ->
send_msg(Msg, State),
{next_state, userauth, next_packet(State#state{ssh_params = Ssh})}
end;
userauth(#ssh_msg_userauth_banner{},
#state{ssh_params = #ssh{userauth_quiet_mode = true,
role = client}} = State) ->
{next_state, userauth, next_packet(State)};
userauth(#ssh_msg_userauth_banner{message = Msg},
#state{ssh_params =
#ssh{userauth_quiet_mode = false, role = client}} = State) ->
io:format("~s", [Msg]),
{next_state, userauth, next_packet(State)}.
connected({#ssh_msg_kexinit{}, _Payload} = Event, State) ->
kexinit(Event, State#state{renegotiate = true}).
%%--------------------------------------------------------------------
%% Function:
%% handle_event(Event, StateName, State) -> {next_state, NextStateName,
%% NextState} |
%% {next_state, NextStateName,
%% NextState, Timeout} |
%% {stop, Reason, NewState}
%% Description: Whenever a gen_fsm receives an event sent using
%% gen_fsm:send_all_state_event/2, this function is called to handle
%% the event.
%%--------------------------------------------------------------------
handle_event({send, Data}, StateName, #state{ssh_params = Ssh0} = State) ->
{Packet, Ssh} = ssh_transport:pack(Data, Ssh0),
send_msg(Packet, State),
{next_state, StateName, next_packet(State#state{ssh_params = Ssh})};
handle_event(#ssh_msg_disconnect{} = Msg, _StateName,
#state{manager = Pid} = State) ->
(catch ssh_connection_manager:event(Pid, Msg)),
{stop, normal, State};
handle_event(#ssh_msg_ignore{}, StateName, State) ->
{next_state, StateName, next_packet(State)};
handle_event(#ssh_msg_debug{always_display = true, message = DbgMsg},
StateName, State) ->
io:format("DEBUG: ~p\n", [DbgMsg]),
{next_state, StateName, next_packet(State)};
handle_event(#ssh_msg_debug{}, StateName, State) ->
{next_state, StateName, next_packet(State)};
handle_event(#ssh_msg_unimplemented{}, StateName, State) ->
{next_state, StateName, next_packet(State)};
handle_event(renegotiate, connected, #state{ssh_params = Ssh0}
= State) ->
{KeyInitMsg, SshPacket, Ssh} = ssh_transport:key_exchange_init_msg(Ssh0),
send_msg(SshPacket, State),
{next_state, connected,
next_packet(State#state{ssh_params = Ssh,
key_exchange_init_msg = KeyInitMsg,
renegotiate = true})};
handle_event(renegotiate, StateName, State) ->
%% Allready in keyexcahange so ignore
{next_state, StateName, State};
handle_event({info, From, Options}, StateName, #state{ssh_params = Ssh} = State) ->
spawn(?MODULE, ssh_info_handler, [Options, Ssh, From]),
{next_state, StateName, State};
handle_event({unknown, Data}, StateName, State) ->
Msg = #ssh_msg_unimplemented{sequence = Data},
send_msg(Msg, State),
{next_state, StateName, next_packet(State)}.
%%--------------------------------------------------------------------
%% Function:
%% handle_sync_event(Event, From, StateName,
%% State) -> {next_state, NextStateName, NextState} |
%% {next_state, NextStateName, NextState,
%% Timeout} |
%% {reply, Reply, NextStateName, NextState}|
%% {reply, Reply, NextStateName, NextState,
%% Timeout} |
%% {stop, Reason, NewState} |
%% {stop, Reason, Reply, NewState}
%% Description: Whenever a gen_fsm receives an event sent using
%% gen_fsm:sync_send_all_state_event/2,3, this function is called to handle
%% the event.
%%--------------------------------------------------------------------
%% Replaced with option to connection_info/3. For now keep
%% for backwards compatibility
handle_sync_event(peer_address, _From, StateName,
#state{ssh_params = #ssh{peer = {_, Address}}} = State) ->
{reply, {ok, Address}, StateName, State}.
%%--------------------------------------------------------------------
%% Function:
%% handle_info(Info,StateName,State)-> {next_state, NextStateName, NextState}|
%% {next_state, NextStateName, NextState,
%% Timeout} |
%% {stop, Reason, NewState}
%% Description: This function is called by a gen_fsm when it receives any
%% other message than a synchronous or asynchronous event
%% (or a system message).
%%--------------------------------------------------------------------
handle_info({Protocol, Socket, "SSH-" ++ _ = Version}, hello,
#state{socket = Socket,
transport_protocol = Protocol} = State ) ->
event({version_exchange, Version}, hello, State);
handle_info({Protocol, Socket, Info}, hello,
#state{socket = Socket,
transport_protocol = Protocol} = State) ->
event({info_line, Info}, hello, State);
handle_info({Protocol, Socket, Data}, Statename,
#state{socket = Socket,
transport_protocol = Protocol,
ssh_params = #ssh{decrypt_block_size = BlockSize,
recv_mac_size = MacSize} = Ssh0,
decoded_data_buffer = <<>>,
encoded_data_buffer = EncData0} = State0) ->
%% Implementations SHOULD decrypt the length after receiving the
%% first 8 (or cipher block size, whichever is larger) bytes of a
%% packet. (RFC 4253: Section 6 - Binary Packet Protocol)
case size(EncData0) + size(Data) >= erlang:max(8, BlockSize) of
true ->
{Ssh, SshPacketLen, DecData, EncData} =
ssh_transport:decrypt_first_block(<<EncData0/binary,
Data/binary>>, Ssh0),
case SshPacketLen > ?SSH_MAX_PACKET_SIZE of
true ->
DisconnectMsg =
#ssh_msg_disconnect{code =
?SSH_DISCONNECT_PROTOCOL_ERROR,
description = "Bad packet length "
++ integer_to_list(SshPacketLen),
language = "en"},
handle_disconnect(DisconnectMsg, State0);
false ->
RemainingSshPacketLen =
(SshPacketLen + ?SSH_LENGHT_INDICATOR_SIZE) -
BlockSize + MacSize,
State = State0#state{ssh_params = Ssh},
handle_ssh_packet_data(RemainingSshPacketLen,
DecData, EncData, Statename,
State)
end;
false ->
{next_state, Statename,
next_packet(State0#state{encoded_data_buffer =
<<EncData0/binary, Data/binary>>})}
end;
handle_info({Protocol, Socket, Data}, Statename,
#state{socket = Socket,
transport_protocol = Protocol,
decoded_data_buffer = DecData,
encoded_data_buffer = EncData,
undecoded_packet_length = Len} =
State) when is_integer(Len) ->
handle_ssh_packet_data(Len, DecData, <<EncData/binary, Data/binary>>,
Statename, State);
handle_info({CloseTag, _Socket}, _StateName,
#state{transport_close_tag = CloseTag,
ssh_params = #ssh{role = _Role, opts = _Opts}} = State) ->
DisconnectMsg =
#ssh_msg_disconnect{code = ?SSH_DISCONNECT_CONNECTION_LOST,
description = "Connection Lost",
language = "en"},
{stop, {shutdown, DisconnectMsg}, State};
%%% So that terminate will be run when supervisor is shutdown
handle_info({'EXIT', _Sup, Reason}, _StateName, State) ->
{stop, Reason, State};
handle_info(UnexpectedMessage, StateName, #state{ssh_params = SshParams} = State) ->
Msg = lists:flatten(io_lib:format(
"Unexpected message '~p' received in state '~p'\n"
"Role: ~p\n"
"Peer: ~p\n"
"Local Address: ~p\n", [UnexpectedMessage, StateName,
SshParams#ssh.role, SshParams#ssh.peer,
proplists:get_value(address, SshParams#ssh.opts)])),
error_logger:info_report(Msg),
{next_state, StateName, State}.
%%--------------------------------------------------------------------
%% Function: terminate(Reason, StateName, State) -> void()
%% Description:This function is called by a gen_fsm when it is about
%% to terminate. It should be the opposite of Module:init/1 and do any
%% necessary cleaning up. When it returns, the gen_fsm terminates with
%% Reason. The return value is ignored.
%%--------------------------------------------------------------------
terminate(normal, _, #state{transport_cb = Transport,
socket = Socket,
manager = Pid}) ->
(catch ssh_userreg:delete_user(Pid)),
(catch Transport:close(Socket)),
ok;
%% Terminated as manager terminated
terminate(shutdown, StateName, #state{ssh_params = Ssh0} = State) ->
DisconnectMsg =
#ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION,
description = "Application shutdown",
language = "en"},
{SshPacket, Ssh} = ssh_transport:ssh_packet(DisconnectMsg, Ssh0),
send_msg(SshPacket, State),
terminate(normal, StateName, State#state{ssh_params = Ssh});
terminate({shutdown, #ssh_msg_disconnect{} = Msg}, StateName, #state{ssh_params = Ssh0, manager = Pid} = State) ->
{SshPacket, Ssh} = ssh_transport:ssh_packet(Msg, Ssh0),
send_msg(SshPacket, State),
ssh_connection_manager:event(Pid, Msg),
terminate(normal, StateName, State#state{ssh_params = Ssh});
terminate(Reason, StateName, #state{ssh_params = Ssh0, manager = Pid} = State) ->
log_error(Reason),
DisconnectMsg =
#ssh_msg_disconnect{code = ?SSH_DISCONNECT_BY_APPLICATION,
description = "Internal error",
language = "en"},
{SshPacket, Ssh} = ssh_transport:ssh_packet(DisconnectMsg, Ssh0),
ssh_connection_manager:event(Pid, DisconnectMsg),
send_msg(SshPacket, State),
terminate(normal, StateName, State#state{ssh_params = Ssh}).
%%--------------------------------------------------------------------
%% Function:
%% code_change(OldVsn, StateName, State, Extra) -> {ok, StateName, NewState}
%% Description: Convert process state when code is changed
%%--------------------------------------------------------------------
code_change(_OldVsn, StateName, State, _Extra) ->
{ok, StateName, State}.
%%--------------------------------------------------------------------
%%% Internal functions
%%--------------------------------------------------------------------
init_ssh(client = Role, Vsn, Version, Options, Socket) ->
IOCb = case proplists:get_value(user_interaction, Options, true) of
true ->
ssh_io;
false ->
ssh_no_io
end,
AuthMethods = proplists:get_value(auth_methods, Options,
?SUPPORTED_AUTH_METHODS),
{ok, PeerAddr} = inet:peername(Socket),
PeerName = proplists:get_value(host, Options),
KeyCb = proplists:get_value(key_cb, Options, ssh_file),
#ssh{role = Role,
c_vsn = Vsn,
c_version = Version,
key_cb = KeyCb,
io_cb = IOCb,
userauth_quiet_mode = proplists:get_value(quiet_mode, Options, false),
opts = Options,
userauth_supported_methods = AuthMethods,
peer = {PeerName, PeerAddr},
available_host_keys = supported_host_keys(Role, KeyCb, Options)
};
init_ssh(server = Role, Vsn, Version, Options, Socket) ->
AuthMethods = proplists:get_value(auth_methods, Options,
?SUPPORTED_AUTH_METHODS),
{ok, PeerAddr} = inet:peername(Socket),
KeyCb = proplists:get_value(key_cb, Options, ssh_file),
#ssh{role = Role,
s_vsn = Vsn,
s_version = Version,
key_cb = KeyCb,
io_cb = proplists:get_value(io_cb, Options, ssh_io),
opts = Options,
userauth_supported_methods = AuthMethods,
peer = {undefined, PeerAddr},
available_host_keys = supported_host_keys(Role, KeyCb, Options)
}.
supported_host_keys(client, _, Options) ->
try
case extract_algs(proplists:get_value(pref_public_key_algs, Options, false), []) of
false ->
["ssh-rsa", "ssh-dss"];
Algs ->
Algs
end
catch
exit:Reason ->
{stop, {shutdown, Reason}}
end;
supported_host_keys(server, KeyCb, Options) ->
lists:foldl(fun(Type, Acc) ->
case available_host_key(KeyCb, Type, Options) of
{error, _} ->
Acc;
Alg ->
[Alg | Acc]
end
end, [],
%% Prefered alg last so no need to reverse
["ssh-dss", "ssh-rsa"]).
extract_algs(false, _) ->
false;
extract_algs([],[]) ->
false;
extract_algs([], NewList) ->
lists:reverse(NewList);
extract_algs([H|T], NewList) ->
case H of
'ssh-dss' ->
extract_algs(T, ["ssh-dss"|NewList]);
'ssh-rsa' ->
extract_algs(T, ["ssh-rsa"|NewList])
end.
available_host_key(KeyCb, "ssh-dss"= Alg, Opts) ->
case KeyCb:host_key('ssh-dss', Opts) of
{ok, _} ->
Alg;
Other ->
Other
end;
available_host_key(KeyCb, "ssh-rsa" = Alg, Opts) ->
case KeyCb:host_key('ssh-rsa', Opts) of
{ok, _} ->
Alg;
Other ->
Other
end.
send_msg(Msg, #state{socket = Socket, transport_cb = Transport}) ->
Transport:send(Socket, Msg).
handle_version({2, 0} = NumVsn, StrVsn, Ssh0) ->
Ssh = counterpart_versions(NumVsn, StrVsn, Ssh0),
{ok, Ssh};
handle_version(_,_,_) ->
not_supported.
string_version(#ssh{role = client, c_version = Vsn}) ->
Vsn;
string_version(#ssh{role = server, s_version = Vsn}) ->
Vsn.
send_event(FsmPid, Event) ->
gen_fsm:send_event(FsmPid, Event).
send_all_state_event(FsmPid, Event) ->
gen_fsm:send_all_state_event(FsmPid, Event).
sync_send_all_state_event(FsmPid, Event) ->
gen_fsm:sync_send_all_state_event(FsmPid, Event).
%% simulate send_all_state_event(self(), Event)
event(#ssh_msg_disconnect{} = Event, StateName, State) ->
handle_event(Event, StateName, State);
event(#ssh_msg_ignore{} = Event, StateName, State) ->
handle_event(Event, StateName, State);
event(#ssh_msg_debug{} = Event, StateName, State) ->
handle_event(Event, StateName, State);
event(#ssh_msg_unimplemented{} = Event, StateName, State) ->
handle_event(Event, StateName, State);
%% simulate send_event(self(), Event)
event(Event, StateName, State) ->
?MODULE:StateName(Event, State).
generate_event(<<?BYTE(Byte), _/binary>> = Msg, StateName,
#state{manager = Pid} = State0, EncData)
when Byte == ?SSH_MSG_GLOBAL_REQUEST;
Byte == ?SSH_MSG_REQUEST_SUCCESS;
Byte == ?SSH_MSG_REQUEST_FAILURE;
Byte == ?SSH_MSG_CHANNEL_OPEN;
Byte == ?SSH_MSG_CHANNEL_OPEN_CONFIRMATION;
Byte == ?SSH_MSG_CHANNEL_OPEN_FAILURE;
Byte == ?SSH_MSG_CHANNEL_WINDOW_ADJUST;
Byte == ?SSH_MSG_CHANNEL_DATA;
Byte == ?SSH_MSG_CHANNEL_EXTENDED_DATA;
Byte == ?SSH_MSG_CHANNEL_EOF;
Byte == ?SSH_MSG_CHANNEL_CLOSE;
Byte == ?SSH_MSG_CHANNEL_REQUEST;
Byte == ?SSH_MSG_CHANNEL_SUCCESS;
Byte == ?SSH_MSG_CHANNEL_FAILURE ->
try
ssh_connection_manager:event(Pid, Msg),
State = generate_event_new_state(State0, EncData),
next_packet(State),
{next_state, StateName, State}
catch
exit:{noproc, Reason} ->
{stop, {shutdown, Reason}, State0}
end;
generate_event(Msg, StateName, State0, EncData) ->
Event = ssh_bits:decode(Msg),
State = generate_event_new_state(State0, EncData),
case Event of
#ssh_msg_kexinit{} ->
%% We need payload for verification later.
event({Event, Msg}, StateName, State);
_ ->
event(Event, StateName, State)
end.
generate_event_new_state(#state{ssh_params =
#ssh{recv_sequence = SeqNum0}
= Ssh} = State, EncData) ->
SeqNum = ssh_transport:next_seqnum(SeqNum0),
State#state{ssh_params = Ssh#ssh{recv_sequence = SeqNum},
decoded_data_buffer = <<>>,
encoded_data_buffer = EncData,
undecoded_packet_length = undefined}.
next_packet(#state{decoded_data_buffer = <<>>,
encoded_data_buffer = Buff,
ssh_params = #ssh{decrypt_block_size = BlockSize},
socket = Socket,
transport_protocol = Protocol} = State) when Buff =/= <<>> ->
case size(Buff) >= erlang:max(8, BlockSize) of
true ->
%% Enough data from the next packet has been received to
%% decode the length indicator, fake a socket-recive
%% message so that the data will be processed
self() ! {Protocol, Socket, <<>>};
false ->
inet:setopts(Socket, [{active, once}])
end,
State;
next_packet(#state{socket = Socket} = State) ->
inet:setopts(Socket, [{active, once}]),
State.
after_new_keys(#state{renegotiate = true} = State) ->
{connected, State#state{renegotiate = false}};
after_new_keys(#state{renegotiate = false,
ssh_params = #ssh{role = client} = Ssh0} = State) ->
ssh_bits:install_messages(ssh_auth:userauth_messages()),
{Msg, Ssh} = ssh_auth:service_request_msg(Ssh0),
send_msg(Msg, State),
{userauth, State#state{ssh_params = Ssh}};
after_new_keys(#state{renegotiate = false,
ssh_params = #ssh{role = server}} = State) ->
{userauth, State}.
handle_ssh_packet_data(RemainingSshPacketLen, DecData, EncData, StateName,
State) ->
EncSize = size(EncData),
case RemainingSshPacketLen > EncSize of
true ->
{next_state, StateName,
next_packet(State#state{decoded_data_buffer = DecData,
encoded_data_buffer = EncData,
undecoded_packet_length =
RemainingSshPacketLen})};
false ->
handle_ssh_packet(RemainingSshPacketLen, StateName,
State#state{decoded_data_buffer = DecData,
encoded_data_buffer = EncData})
end.
handle_ssh_packet(Length, StateName, #state{decoded_data_buffer = DecData0,
encoded_data_buffer = EncData0,
ssh_params = Ssh0,
transport_protocol = _Protocol,
socket = _Socket} = State0) ->
{Ssh1, DecData, EncData, Mac} =
ssh_transport:unpack(EncData0, Length, Ssh0),
SshPacket = <<DecData0/binary, DecData/binary>>,
case ssh_transport:is_valid_mac(Mac, SshPacket, Ssh1) of
true ->
PacketData = ssh_transport:msg_data(SshPacket),
{Ssh1, Msg} = ssh_transport:decompress(Ssh1, PacketData),
generate_event(Msg, StateName,
State0#state{ssh_params = Ssh1,
%% Important to be set for
%% next_packet
decoded_data_buffer = <<>>}, EncData);
false ->
DisconnectMsg =
#ssh_msg_disconnect{code = ?SSH_DISCONNECT_PROTOCOL_ERROR,
description = "Bad mac",
language = "en"},
handle_disconnect(DisconnectMsg, State0)
end.
handle_disconnect(#ssh_msg_disconnect{} = Msg, State) ->
{stop, {shutdown, Msg}, State}.
counterpart_versions(NumVsn, StrVsn, #ssh{role = server} = Ssh) ->
Ssh#ssh{c_vsn = NumVsn , c_version = StrVsn};
counterpart_versions(NumVsn, StrVsn, #ssh{role = client} = Ssh) ->
Ssh#ssh{s_vsn = NumVsn , s_version = StrVsn}.
opposite_role(client) ->
server;
opposite_role(server) ->
client.
connected_fun(User, PeerAddr, Method, Opts) ->
case proplists:get_value(connectfun, Opts) of
undefined ->
ok;
Fun ->
catch Fun(User, PeerAddr, Method)
end.
retry_fun(_, undefined, _) ->
ok;
retry_fun(User, {error, Reason}, Opts) ->
case proplists:get_value(failfun, Opts) of
undefined ->
ok;
Fun ->
catch Fun(User, Reason)
end;
retry_fun(User, Reason, Opts) ->
case proplists:get_value(infofun, Opts) of
undefined ->
ok;
Fun ->
catch Fun(User, Reason)
end.
ssh_info_handler(Options, Ssh, From) ->
Info = ssh_info(Options, Ssh, []),
ssh_connection_manager:send_msg({channel_requst_reply, From, Info}).
ssh_info([], _, Acc) ->
Acc;
ssh_info([client_version | Rest], #ssh{c_vsn = IntVsn,
c_version = StringVsn} = SshParams, Acc) ->
ssh_info(Rest, SshParams, [{client_version, {IntVsn, StringVsn}} | Acc]);
ssh_info([server_version | Rest], #ssh{s_vsn = IntVsn,
s_version = StringVsn} = SshParams, Acc) ->
ssh_info(Rest, SshParams, [{server_version, {IntVsn, StringVsn}} | Acc]);
ssh_info([peer | Rest], #ssh{peer = Peer} = SshParams, Acc) ->
ssh_info(Rest, SshParams, [{peer, Peer} | Acc]);
ssh_info([ _ | Rest], SshParams, Acc) ->
ssh_info(Rest, SshParams, Acc).
log_error(Reason) ->
Report = io_lib:format("Erlang ssh connection handler failed with reason: "
"~p ~n, Stacktace: ~p ~n"
"please report this to [email protected] \n",
[Reason, erlang:get_stacktrace()]),
error_logger:error_report(Report),
"Internal error".