aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--lib/ssh/src/ssh_cli.erl12
-rw-r--r--lib/ssh/src/ssh_connection.erl5
-rw-r--r--lib/ssh/test/ssh_connection_SUITE.erl45
3 files changed, 52 insertions, 10 deletions
diff --git a/lib/ssh/src/ssh_cli.erl b/lib/ssh/src/ssh_cli.erl
index 77453e8fd7..18841e3d2d 100644
--- a/lib/ssh/src/ssh_cli.erl
+++ b/lib/ssh/src/ssh_cli.erl
@@ -457,17 +457,17 @@ bin_to_list(I) when is_integer(I) ->
start_shell(ConnectionHandler, State) ->
Shell = State#state.shell,
- ConnectionInfo = ssh_connection_handler:info(ConnectionHandler,
+ ConnectionInfo = ssh_connection_handler:connection_info(ConnectionHandler,
[peer, user]),
ShellFun = case is_function(Shell) of
true ->
- {ok, User} =
+ User =
proplists:get_value(user, ConnectionInfo),
case erlang:fun_info(Shell, arity) of
{arity, 1} ->
fun() -> Shell(User) end;
{arity, 2} ->
- [{_, PeerAddr}] =
+ {_, PeerAddr} =
proplists:get_value(peer, ConnectionInfo),
fun() -> Shell(User, PeerAddr) end;
_ ->
@@ -485,9 +485,9 @@ start_shell(_ConnectionHandler, Cmd, #state{exec={M, F, A}} = State) ->
State#state{group = Group, buf = empty_buf()};
start_shell(ConnectionHandler, Cmd, #state{exec=Shell} = State) when is_function(Shell) ->
- ConnectionInfo = ssh_connection_handler:info(ConnectionHandler,
+ ConnectionInfo = ssh_connection_handler:connection_info(ConnectionHandler,
[peer, user]),
- {ok, User} =
+ User =
proplists:get_value(user, ConnectionInfo),
ShellFun =
case erlang:fun_info(Shell, arity) of
@@ -496,7 +496,7 @@ start_shell(ConnectionHandler, Cmd, #state{exec=Shell} = State) when is_function
{arity, 2} ->
fun() -> Shell(Cmd, User) end;
{arity, 3} ->
- [{_, PeerAddr}] =
+ {_, PeerAddr} =
proplists:get_value(peer, ConnectionInfo),
fun() -> Shell(Cmd, User, PeerAddr) end;
_ ->
diff --git a/lib/ssh/src/ssh_connection.erl b/lib/ssh/src/ssh_connection.erl
index b377614949..33849f4527 100644
--- a/lib/ssh/src/ssh_connection.erl
+++ b/lib/ssh/src/ssh_connection.erl
@@ -782,9 +782,8 @@ handle_cli_msg(#connection{channel_cache = Cache} = Connection,
erlang:monitor(process, Pid),
Channel = Channel0#channel{user = Pid},
ssh_channel:cache_update(Cache, Channel),
- Reply = {connection_reply,
- channel_success_msg(RemoteId)},
- {{replies, [{channel_data, Pid, Reply0}, Reply]}, Connection};
+ {Reply, Connection1} = reply_msg(Channel, Connection, Reply0),
+ {{replies, [Reply]}, Connection1};
_Other ->
Reply = {connection_reply,
channel_failure_msg(RemoteId)},
diff --git a/lib/ssh/test/ssh_connection_SUITE.erl b/lib/ssh/test/ssh_connection_SUITE.erl
index f4f0682b40..0b057f10de 100644
--- a/lib/ssh/test/ssh_connection_SUITE.erl
+++ b/lib/ssh/test/ssh_connection_SUITE.erl
@@ -37,7 +37,8 @@ suite() ->
all() ->
[
{group, openssh_payload},
- interrupted_send
+ interrupted_send,
+ start_shell
].
groups() ->
[{openssh_payload, [], [simple_exec,
@@ -276,6 +277,39 @@ interrupted_send(Config) when is_list(Config) ->
ssh:stop_daemon(Pid).
%%--------------------------------------------------------------------
+start_shell() ->
+ [{doc, "Start a shell"}].
+
+start_shell(Config) when is_list(Config) ->
+ PrivDir = ?config(priv_dir, Config),
+ UserDir = filename:join(PrivDir, nopubkey), % to make sure we don't use public-key-auth
+ file:make_dir(UserDir),
+ SysDir = ?config(data_dir, Config),
+ {Pid, Host, Port} = ssh_test_lib:daemon([{system_dir, SysDir},
+ {user_dir, UserDir},
+ {password, "morot"},
+ {shell, fun(U, H) -> start_our_shell(U, H) end} ]),
+
+ ConnectionRef = ssh_test_lib:connect(Host, Port, [{silently_accept_hosts, true},
+ {user, "foo"},
+ {password, "morot"},
+ {user_interaction, true},
+ {user_dir, UserDir}]),
+
+ {ok, ChannelId0} = ssh_connection:session_channel(ConnectionRef, infinity),
+ ok = ssh_connection:shell(ConnectionRef,ChannelId0),
+
+ receive
+ {ssh_cm,ConnectionRef, {data, ChannelId, 0, <<"Enter command\r\n">>}} ->
+ ok
+ after 5000 ->
+ ct:fail("CLI Timeout")
+ end,
+
+ ssh:close(ConnectionRef),
+ ssh:stop_daemon(Pid).
+
+%%--------------------------------------------------------------------
%% Internal functions ------------------------------------------------
%%--------------------------------------------------------------------
big_cat_rx(ConnectionRef, ChannelId) ->
@@ -308,3 +342,12 @@ collect_data(ConnectionRef, ChannelId, Acc) ->
after 5000 ->
timeout
end.
+
+%%%-------------------------------------------------------------------
+% This is taken from the ssh example code.
+start_our_shell(_User, _Peer) ->
+ spawn(fun() ->
+ io:format("Enter command\n")
+ %% Don't actually loop, just exit
+ end).
+