%%%-------------------------------------------------------------------
%%% @author Simon Cornish <[email protected]>
%%% @copyright (C) 2015, Simon Cornish
%%% @doc
%%% Provide manipulatable TCP-level relaying for testing SSH
%%% @end
%%% Created : 7 May 2015 by Simon Cornish <[email protected]>
%%%-------------------------------------------------------------------
-module(ssh_relay).
-behaviour(gen_server).
%% API
-export([start_link/4]).
-export([stop/1]).
-export([hold/4, release/2, release_next/3]).
%% gen_server callbacks
-export([init/1, handle_call/3, handle_cast/2, handle_info/2,
terminate/2, code_change/3]).
%% A hold on one relay direction: while it is installed, data travelling
%% in that direction is queued instead of being forwarded.
-record(hold, {
%% socket passed to do_hold/5 when the hold was placed
port,
%% number of messages still to hold, or a non-integer such as 'all'
n,
%% timeout exactly as passed to hold/4
tmo,
%% timer reference from erlang:send_after/3, or 'undefined' when no timer
tref,
%% held binaries, oldest first (new data is appended)
q = []
}).
%% gen_server state.
-record(state, {
%% address and port the relay listens on
local_addr,
local_port,
%% address and port of the server the relay connects onwards to
peer_addr,
peer_port,
%% pid of the acceptor process spawned by init/1
lpid,
%% socket accepted from the local client ('undefined' until accepted)
local,
%% outgoing socket to the peer ('undefined' until connected)
peer,
%% #hold{} for direction local -> peer, or 'undefined'
tx_hold,
%% #hold{} for direction peer -> local, or 'undefined'
rx_hold
}).
%% Timeout in ms for each gen_tcp:accept/2 attempt made by the acceptor.
-define(ACCEPT_TMO, 200).
%%%===================================================================
%%% API
%%%===================================================================
%%--------------------------------------------------------------------
%% @doc
%% Start holding back messages travelling in direction Dir.
%%
%% Dir is 'tx' (local -> peer) or 'rx' (peer -> local); any other
%% value is rejected with {error, einval}.
%%
%% N is how many messages to hold, or 'all'. The held messages are sent,
%% in arrival order, together with the (N+1)th message. They are also
%% sent when release/2 is called, or when Tmo milliseconds have elapsed.
%% A Tmo that is not a positive integer (for example 'infinity') means
%% the hold never times out.
%%
%% Returns {error, ealready} if a hold already exists for Dir.
%%
%% @spec hold(Srv, Dir, N, Tmo) -> ok | {error, Error}
%% @end
%%--------------------------------------------------------------------
hold(Srv, Dir, N, Tmo) ->
    Request = {hold, Dir, N, Tmo},
    gen_server:call(Srv, Request).
%%--------------------------------------------------------------------
%% @doc
%% Immediately send every message held in direction Dir, in arrival
%% order, and remove the hold.
%%
%% Returns {error, enoent} if there is no hold for Dir. Returns
%% {error, einval} if Dir is neither 'tx' nor 'rx'.
%%
%% @spec release(Srv, Dir) -> ok | {error, Error}
%% @end
%%--------------------------------------------------------------------
release(Srv, Dir) ->
    Request = {release, Dir},
    gen_server:call(Srv, Request).
%%--------------------------------------------------------------------
%% @doc
%% Intended to release all messages held in direction Dir once the next
%% message has been seen in direction TriggerDir.
%%
%% NOT IMPLEMENTED YET: the server currently answers every
%% release_next request with {error, nyi}. It does not check Dir or
%% TriggerDir, and it never reports enoent or einval for this request.
%%
%% @spec release_next(Srv, Dir, TriggerDir) -> {error, nyi}
%% @end
%%--------------------------------------------------------------------
release_next(Srv, Dir, TriggerDir) ->
    gen_server:call(Srv, {release_next, Dir, TriggerDir}).
%%--------------------------------------------------------------------
%% @doc
%% Starts the relay server, linked to the caller.
%%
%% The server listens on ListenAddr:ListenPort and accepts a single
%% local connection; the listening socket is closed after the first
%% accept. ListenAddr {0,0,0,0} means all interfaces. Once a client has
%% connected, the server opens a connection to PeerAddr:PeerPort and
%% relays data between the two sockets.
%%
%% If the listening socket cannot be opened, the result is
%% {error, Error}, where Error is the {error, Reason} tuple returned by
%% gen_tcp:listen/2.
%%
%% @spec start_link(ListenAddr, ListenPort, PeerAddr, PeerPort) ->
%%           {ok, Pid} | {error, Error}
%% @end
%%--------------------------------------------------------------------
start_link(ListenAddr, ListenPort, PeerAddr, PeerPort) ->
    gen_server:start_link(?MODULE, [ListenAddr, ListenPort, PeerAddr, PeerPort], []).
%%--------------------------------------------------------------------
%% @doc
%% Asks the server to stop, without waiting for it to do so.
%%
%% The caller first unlinks from Srv, presumably so that it does not
%% receive an exit signal from the server. It then sends the 'stop'
%% message, which handle_info/2 turns into a normal termination. Returns
%% 'stop', the value of the send.
%% @end
%%--------------------------------------------------------------------
stop(Srv) ->
    true = unlink(Srv),
    erlang:send(Srv, stop).
%%%===================================================================
%%% gen_server callbacks
%%%===================================================================
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Initializes the server.
%%
%% Opens a listening socket on ListenAddr:ListenPort and spawns a
%% monitored acceptor process (listen/2) that waits for the local
%% client. The peer address and port are only recorded here; the
%% outgoing connection is made once a local client has been accepted.
%%
%% ListenAddr {0,0,0,0} listens on all interfaces. Any other address is
%% passed as the 'ifaddr' listen option. Elements after PeerPort in the
%% argument list are accepted but currently ignored.
%%
%% @spec init(Args) -> {ok, State} | {stop, Reason}
%% @end
%%--------------------------------------------------------------------
init([ListenAddr, ListenPort, PeerAddr, PeerPort | _Options]) ->
    IfAddr = case ListenAddr of
                 {0,0,0,0} -> [];
                 _ -> [{ifaddr, ListenAddr}]
             end,
    %% {active, false}: sockets accepted from this listener start passive,
    %% so no data arrives as messages until the gen_server has taken
    %% ownership and handled {activate, Sock}.
    ListenOpts = [{reuseaddr, true}, {backlog, 1}, {active, false}, binary | IfAddr],
    case gen_tcp:listen(ListenPort, ListenOpts) of
        {ok, LSock} ->
            Parent = self(),
            {LPid, _MRef} = spawn_monitor(fun() -> listen(Parent, LSock) end),
            S = #state{local_addr = ListenAddr,
                       local_port = ListenPort,
                       lpid = LPid,
                       peer_addr = PeerAddr,
                       peer_port = PeerPort
                      },
            {ok, S};
        Error ->
            {stop, Error}
    end.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling call messages.
%%
%% {hold, Dir, N, Tmo} and {release, Dir} take Dir = tx (local -> peer)
%% or rx (peer -> local). Any other Dir is answered with {error, einval}.
%% {release_next, Dir, TriggerDir} is not implemented yet and is always
%% answered with {error, nyi}. Any other request is answered with
%% {unhandled, Request}.
%%
%% @spec handle_call(Request, From, State) -> {reply, Reply, State}
%% @end
%%--------------------------------------------------------------------
handle_call({hold, tx, N, Tmo}, _From, State) ->
    do_hold(#state.tx_hold, State#state.peer, N, Tmo, State);
handle_call({hold, rx, N, Tmo}, _From, State) ->
    do_hold(#state.rx_hold, State#state.local, N, Tmo, State);
handle_call({hold, _BadDir, _N, _Tmo}, _From, State) ->
    {reply, {error, einval}, State};
handle_call({release, tx}, _From, State) ->
    do_release(#state.tx_hold, State);
handle_call({release, rx}, _From, State) ->
    do_release(#state.rx_hold, State);
handle_call({release, _BadDir}, _From, State) ->
    {reply, {error, einval}, State};
handle_call({release_next, _Dir, _TriggerDir}, _From, State) ->
    {reply, {error, nyi}, State};
handle_call(Request, _From, State) ->
    {reply, {unhandled, Request}, State}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling cast messages.
%%
%% This server defines no casts: every cast is ignored and the state is
%% kept as it is.
%%
%% @spec handle_cast(Msg, State) -> {noreply, State}
%% @end
%%--------------------------------------------------------------------
handle_cast(_Ignored, S) ->
    {noreply, S}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Handling all non call/cast messages.
%%
%% Socket messages are routed by which socket they came from:
%%   - from the local socket: data, errors and close notifications go to
%%     do_local/2;
%%   - from the peer socket: they go to do_peer/2.
%%
%% Other messages:
%%   - {accept, Sock}: the acceptor announces a new connection.
%%   - {activate, Sock}: switch that socket to active mode.
%%   - {release, Pos}: a hold timeout has fired.
%%   - 'stop' (see stop/1): terminate normally.
%%   - 'DOWN' from the acceptor: logged.
%%   - anything else: logged and ignored.
%%
%% @spec handle_info(Info, State) -> {noreply, State} |
%%                                   {stop, Reason, State}
%% @end
%%--------------------------------------------------------------------
handle_info({tcp, Sock, Data}, #state{local = Sock} = S) ->
    {noreply, do_local(Data, S)};
handle_info({tcp_error, Sock, Reason}, #state{local = Sock} = S) ->
    {noreply, do_local({error, Reason}, S)};
handle_info({tcp_closed, Sock}, #state{local = Sock} = S) ->
    {noreply, do_local(closed, S)};
handle_info({tcp, Sock, Data}, #state{peer = Sock} = S) ->
    {noreply, do_peer(Data, S)};
handle_info({tcp_error, Sock, Reason}, #state{peer = Sock} = S) ->
    {noreply, do_peer({error, Reason}, S)};
handle_info({tcp_closed, Sock}, #state{peer = Sock} = S) ->
    {noreply, do_peer(closed, S)};
handle_info({accept, Sock}, S) ->
    {noreply, do_accept(Sock, S)};
handle_info({activate, Sock}, S) ->
    inet:setopts(Sock, [{active, true}]),
    {noreply, S};
handle_info({release, Pos}, S) ->
    {reply, _Ignored, NewS} = do_release(Pos, S),
    {noreply, NewS};
handle_info(stop, S) ->
    {stop, normal, S};
handle_info({'DOWN', _MRef, _, LPid, Reason}, #state{lpid = LPid} = S) ->
    io:format("Acceptor has finished: ~p~n", [Reason]),
    {noreply, S};
handle_info(Info, S) ->
    io:format("Unhandled info: ~p~n", [Info]),
    {noreply, S}.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Called by gen_server just before the process terminates.
%%
%% No explicit cleanup is done here. Presumably the sockets, which are
%% owned by this process, are closed when it exits. The return value is
%% ignored by gen_server.
%%
%% @spec terminate(Reason, State) -> ok
%% @end
%%--------------------------------------------------------------------
terminate(_Why, _S) ->
    ok.
%%--------------------------------------------------------------------
%% @private
%% @doc
%% Convert process state when code is changed.
%%
%% No conversion is performed: the state is returned unchanged.
%%
%% @spec code_change(OldVsn, State, Extra) -> {ok, NewState}
%% @end
%%--------------------------------------------------------------------
code_change(_FromVsn, S, _Extra) ->
    {ok, S}.
%%%===================================================================
%%% Internal functions
%%%===================================================================
%% Install a hold at record position Pos of S (#state.tx_hold or
%% #state.rx_hold). Port, N and Tmo are stored in the new #hold{}.
%%
%% A positive integer Tmo arms a timer that sends {release, Pos} to this
%% process after Tmo ms. Any other Tmo (e.g. 'infinity') arms no timer.
%%
%% Replies {error, ealready} if a hold is already installed at Pos.
do_hold(Pos, Port, N, Tmo, S) ->
    case element(Pos, S) of
        undefined ->
            TRef = case is_integer(Tmo) andalso Tmo > 0 of
                       true -> erlang:send_after(Tmo, self(), {release, Pos});
                       false -> undefined
                   end,
            NewHold = #hold{port = Port, n = N, tmo = Tmo, tref = TRef},
            {reply, ok, setelement(Pos, S, NewHold)};
        _Existing ->
            {reply, {error, ealready}, S}
    end.
%% Release the hold stored at record position HPos of S (#state.tx_hold
%% or #state.rx_hold). Sends every queued message, oldest first, cancels
%% the hold's timer and clears the hold.
%%
%% Replies {error, enoent} when there is no hold at HPos.
do_release(HPos, S) when element(HPos, S) == undefined ->
    {reply, {error, enoent}, S};
do_release(HPos, S) ->
    #hold{tref = TRef, q = Q} = element(HPos, S),
    %% Send through the socket currently in the state, not #hold.port.
    %% A hold placed before the local client was accepted captured
    %% 'undefined' as its port, and sending the queued data there would
    %% crash the server. do_local/2 and do_peer/2 also use the current
    %% socket.
    Sock = if HPos =:= #state.tx_hold -> S#state.peer;
              HPos =:= #state.rx_hold -> S#state.local
           end,
    %% NOTE(review): the yield between sends is kept as-is; presumably it
    %% helps each held message go out separately - TODO confirm.
    lists:foreach(fun(M) -> gen_tcp:send(Sock, M), erlang:yield() end, Q),
    %% tref is 'undefined' when the hold was placed without a timeout.
    case is_reference(TRef) of
        true -> erlang:cancel_timer(TRef);
        false -> ok
    end,
    %% Drop a timeout message that may already be queued, so that it
    %% cannot release a later hold in the same direction.
    receive
        {release, HPos} -> ok
    after 0 ->
        ok
    end,
    {reply, ok, setelement(HPos, S, undefined)}.
%% Entry point of the acceptor process spawned by init/1. It monitors
%% the gen_server so that do_listen/2 sees a 'DOWN' message and gives up
%% if the server goes away while no client has connected.
listen(Parent, LSock) ->
    _MonRef = erlang:monitor(process, Parent),
    do_listen(Parent, LSock).
%% Acceptor loop.
%%
%% gen_tcp:accept/2 cannot also wait on the mailbox (there is no
%% select-like semantic), so this polls. Each round blocks in accept for
%% at most ?ACCEPT_TMO ms. On timeout it briefly checks for a 'DOWN'
%% message or 'stop', and stops if either is present.
%%
%% On the first accepted connection the socket is handed over to Parent,
%% and the listening socket is then closed, so only one client is ever
%% accepted. Any other accept result closes the listening socket and
%% exits with {accept, Result}.
do_listen(Parent, LSock) ->
    case gen_tcp:accept(LSock, ?ACCEPT_TMO) of
        {error, timeout} ->
            receive
                Msg when element(1, Msg) == 'DOWN' ->
                    ok;
                stop ->
                    ok
            after 1 ->
                do_listen(Parent, LSock)
            end;
        {ok, Sock} ->
            %% {activate, Sock} is only sent after ownership has moved, so
            %% once Parent makes the socket active its data goes to Parent.
            %% do_flush/2 forwards anything that still reached this process.
            Parent ! {accept, Sock},
            gen_tcp:controlling_process(Sock, Parent),
            Parent ! {activate, Sock},
            do_flush(Parent, Sock),
            gen_tcp:close(LSock);
        Other ->
            gen_tcp:close(LSock),
            exit({accept, Other})
    end.
%% Forward to Parent, in order, every tcp, tcp_error or tcp_closed
%% message for Sock in this process's mailbox. Stops once none has
%% arrived for 1 ms and returns ok. Messages for other sockets are left
%% in the mailbox.
do_flush(Parent, Sock) ->
    Next = receive
               {tcp, Sock, _Data} = M -> {forward, M};
               {tcp_error, Sock, _Reason} = M -> {forward, M};
               {tcp_closed, Sock} = M -> {forward, M}
           after 1 ->
               done
           end,
    case Next of
        {forward, Msg} ->
            Parent ! Msg,
            do_flush(Parent, Sock);
        done ->
            ok
    end.
%% Called when the acceptor reports the local client socket. Opens the
%% onward connection to the configured peer in active binary mode and
%% records both sockets in the state. Exits with {connect, Result} if
%% the peer cannot be reached.
do_accept(Local, S) ->
    #state{peer_addr = PeerAddr, peer_port = PeerPort} = S,
    Result = gen_tcp:connect(PeerAddr, PeerPort, [{active, true}, binary]),
    case Result of
        {ok, PeerSock} ->
            S#state{local = Local, peer = PeerSock};
        _ ->
            exit({connect, Result})
    end.
%% Handle an event from the local socket. A binary is data travelling
%% local -> peer and is relayed subject to the tx hold. Anything else is
%% an error or close notification and terminates the relay process.
do_local(Data, S) when is_binary(Data) ->
    relay_data(Data, #state.tx_hold, S#state.peer, S);
do_local(Error, _S) ->
    exit({local, Error}).

%% Handle an event from the peer socket. A binary is data travelling
%% peer -> local and is relayed subject to the rx hold. Anything else is
%% an error or close notification and terminates the relay process.
do_peer(Data, S) when is_binary(Data) ->
    relay_data(Data, #state.rx_hold, S#state.local, S);
do_peer(Error, _S) ->
    exit({peer, Error}).

%% Relay logic shared by both directions. Sock is the destination socket
%% and HoldPos the record position of that direction's hold in S.
%%   - No hold: send Data straight to Sock.
%%   - Hold with nothing left to hold (n == 0): send the queued messages
%%     in arrival order, then Data. Cancel the hold's timer and clear
%%     the hold.
%%   - Otherwise: append Data to the queue. Decrement n when it is an
%%     integer; a non-integer such as 'all' is left unchanged, so the hold
%%     lasts until release or timeout.
relay_data(Data, HoldPos, Sock, S) ->
    case element(HoldPos, S) of
        undefined ->
            gen_tcp:send(Sock, Data),
            S;
        #hold{n = N, q = Q, tref = TRef} when N == 0 ->
            lists:foreach(fun(M) -> gen_tcp:send(Sock, M) end, Q),
            gen_tcp:send(Sock, Data),
            %% tref is 'undefined' when the hold was placed without a timeout.
            case is_reference(TRef) of
                true -> erlang:cancel_timer(TRef);
                false -> ok
            end,
            %% Drop a timeout message that may already be queued, so that
            %% it cannot release a later hold in the same direction.
            receive
                {release, HoldPos} -> ok
            after 0 ->
                ok
            end,
            setelement(HoldPos, S, undefined);
        #hold{n = N, q = Q} = Hold ->
            Left = if is_integer(N) -> N - 1;
                      true -> N
                   end,
            setelement(HoldPos, S, Hold#hold{q = Q ++ [Data], n = Left})
    end.