%%
%% %CopyrightBegin%
%%
%% Copyright Ericsson AB 2010-2013. All Rights Reserved.
%%
%% The contents of this file are subject to the Erlang Public License,
%% Version 1.1, (the "License"); you may not use this file except in
%% compliance with the License. You should have received a copy of the
%% Erlang Public License along with this software. If not, it can be
%% retrieved online at http://www.erlang.org/.
%%
%% Software distributed under the License is distributed on an "AS IS"
%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
%% the License for the specific language governing rights and limitations
%% under the License.
%%
%% %CopyrightEnd%
%%

%%
%% Tests of capabilities exchange between Diameter nodes. In
%% particular, of error and event handling.
%%

-module(diameter_capx_SUITE).

-export([suite/0,
         all/0,
         groups/0,
         init_per_group/2,
         end_per_group/2,
         init_per_testcase/2,
         end_per_testcase/2]).

%% testcases
-export([start/1,
         start_services/1,
         add_listeners/1,
         s_no_common_application/1,
         c_no_common_application/1,
         s_no_common_security/1,
         c_no_common_security/1,
         s_unknown_peer/1,
         c_unknown_peer/1,
         s_unable/1,
         c_unable/1,
         s_client_reject/1,
         c_client_reject/1,
         remove_listeners/1,
         stop_services/1,
         stop/1]).

%% diameter callbacks
-export([peer_up/4,
         peer_down/4]).

-include("diameter.hrl").

%% ===========================================================================

-define(util, diameter_util).

-define(CLIENT, client).
-define(SERVER, server).

-define(ADDR, {127,0,0,1}).

-define(REALM, "erlang.org").
-define(HOST(Name), Name ++ "." ++ ?REALM).

%% Config for diameter:start_service/2.
-define(SERVICE(Name),
        [{'Origin-Realm', ?REALM},
         {'Host-IP-Address', [?ADDR]},
         {'Vendor-Id', 12345},
         {'Product-Name', "OTP/diameter"},
         {'Auth-Application-Id', [?DIAMETER_APP_ID_COMMON]},
         {'Acct-Application-Id', [?DIAMETER_APP_ID_ACCOUNTING]}
         | [{application, [{alias, A},
                           {dictionary, D},
                           {module, [?MODULE, A]}]}
            || {A,D} <- [{base3588, diameter_gen_base_rfc3588},
                         {acct3588, diameter_gen_base_accounting},
                         {base6733, diameter_gen_base_rfc6733},
                         {acct6733, diameter_gen_acct_rfc6733}]]]).

-define(A, list_to_atom).
-define(L, atom_to_list).

-define(event,  #diameter_event).
-define(caps,   #diameter_caps).
-define(packet, #diameter_packet).

-define(fail(T), erlang:error({T, process_info(self(), messages)})).

-define(TIMEOUT, 10000).

-define(DICTS, [rfc3588, rfc6733]).

%% ===========================================================================

suite() ->
    [{timetrap, {seconds, 60}}].

all() -> [start,
          start_services,
          add_listeners]
      ++ [{group, D, P} || D <- ?DICTS, P <- [[], [parallel]]]
      ++ [remove_listeners,
          stop_services,
          stop].

groups() ->
    Tc = lists:flatmap(fun tc/1, tc()),
    [{D, [], Tc} || D <- ?DICTS].

%% Generate a unique hostname for each testcase so that watchdogs
%% don't prevent a connection from being brought up immediately.
init_per_testcase(Name, Config) ->
    Uniq = ["." ++ integer_to_list(N) || N <- tuple_to_list(now())],
    [{host, lists:flatten([?L(Name) | Uniq])} | Config].

init_per_group(Name, Config) ->
    [{rfc, Name} | Config].

end_per_group(_, _) ->
    ok.

end_per_testcase(N, _)
  when N == start;
       N == start_services;
       N == add_listeners;
       N == remove_listeners;
       N == stop_services;
       N == stop ->
    ok;

end_per_testcase(Name, Config) ->
    CRef = ?util:read_priv(Config, Name),
    ok = diameter:remove_transport(?CLIENT, CRef).

%% Testcases all come in two flavours, client and server.
tc(Name) ->
    [?A([C,$_|?L(Name)]) || C <- "cs"].

tc() ->
    [no_common_application,
     no_common_security,
     unknown_peer,
     unable,
     client_reject].

%% ===========================================================================
%% start/stop testcases

start(_Config) ->
    ok = diameter:start().

start_services(_Config) ->
    ok = diameter:start_service(?SERVER, ?SERVICE(?SERVER)),
    ok = diameter:start_service(?CLIENT, ?SERVICE(?CLIENT)).

%% One server that responds only to base accounting, one that responds
%% to both this and the common application. Share a common service just
%% to simplify config, and because we can.
add_listeners(Config) ->
    Acct = [listen(?SERVER,
                   [{capabilities, [{'Origin-Host', ?HOST(H)},
                                    {'Auth-Application-Id', []}]},
                    {applications, [A]},
                    {capabilities_cb, [fun server_capx/3, acct]}])
            || {A,H} <- [{acct3588, "acct3588-srv"},
                         {acct6733, "acct6733-srv"}]],
    Base = [listen(?SERVER,
                   [{capabilities, [{'Origin-Host', ?HOST(H)}]},
                    {applications, A},
                    {capabilities_cb, [fun server_capx/3, base]}])
            || {A,H} <- [{[base3588, acct3588], "base3588-srv"},
                         {[base6733, acct6733], "base6733-srv"}]],
    ?util:write_priv(Config, ?MODULE, {Base, Acct}). %% lref/2 reads

remove_listeners(_Config) ->
    ok = diameter:remove_transport(?SERVER, true).

stop_services(_Config) ->
    ok = diameter:stop_service(?CLIENT),
    ok = diameter:stop_service(?SERVER).

stop(_Config) ->
    ok = diameter:stop().

%% ===========================================================================
%% All the testcases come in pairs, one for receiving an event on the
%% client side, one on the server side. Note that testcases will
%% receive events resulting from other testcases when running in
%% parallel since the events are per service. The unique client
%% Origin-Host for each testcase plus transport references are used to
%% ensure that only the relevant event is extracted from the mailbox.
%% Don't bother extracting events that aren't relevant.

%% ====================
%% Ask the accounting server to speak the common application and expect
%% DIAMETER_NO_COMMON_APPLICATION = 5010.

s_no_common_application(Config) ->
    server_closed(Config, fun no_common_application/1, 5010).

c_no_common_application(Config) ->
    client_closed(Config, "acct-srv", fun no_common_application/1, 5010).

no_common_application(Config) ->
    [Common, _Acct] = apps(Config),
    connect(Config, acct, [{capabilities, [{'Acct-Application-Id', []}]},
                           {applications, [Common]}]).

%% ====================
%% Ask the base server to speak accounting with an unknown security
%% method and expect DIAMETER_NO_COMMON_SECURITY = 5017.

s_no_common_security(Config) ->
    server_closed(Config, fun no_common_security/1, 5017).

c_no_common_security(Config) ->
    client_closed(Config, "base-srv", fun no_common_security/1, 5017).

no_common_security(Config) ->
    [Common, _Acct] = apps(Config),
    connect(Config, base, [{capabilities, [{'Acct-Application-Id', []},
                                           {'Inband-Security-Id', [17, 18]}]},
                           {applications, [Common]}]).

%% ====================
%% Have the base server reject a decent CER with the protocol error
%% DIAMETER_UNKNOWN_PEER = 3010.

s_unknown_peer(Config) ->
    server_reject(Config, fun base/1, 3010).

c_unknown_peer(Config) ->
    Dict0 = dict0(Config),
    true = diameter:subscribe(?CLIENT),
    OH = host(Config, "base-srv"),

    {CRef, _} = base(Config),

    {'CEA', ?caps{}, ?packet{msg = Msg}} = client_recv(CRef),

    ['diameter_base_answer-message' | _] = Dict0:'#get-'(Msg),
    [OH, 3010] = Dict0:'#get-'(['Origin-Host', 'Result-Code'], Msg).

base(Config) ->
    connect(Config, base, [{applications, apps(Config)}]).

%% ====================
%% Have the base server reject a decent CER with the non-protocol
%% error DIAMETER_UNABLE_TO_COMPLY = 5012.

s_unable(Config) ->
    server_reject(Config, fun base/1, 5012).

c_unable(Config) ->
    client_closed(Config, "base-srv", fun base/1, 5012).

%% ====================
%% Have the client reject a decent CEA.

s_client_reject(Config) ->
    true = diameter:subscribe(?SERVER),
    OH = host(Config),

    {_, LRef} = client_reject(Config),

    receive
        ?event{service = ?SERVER,
               info = {up, LRef,
                           {_, ?caps{origin_host = {_, OH}}},
                           {listen, _},
                           ?packet{}}}
                    = Info ->
            Info
    after ?TIMEOUT ->
            ?fail({LRef, OH})
    end.

c_client_reject(Config) ->
    Dict0 = dict0(Config),
    true = diameter:subscribe(?CLIENT),
    OH = host(Config, "acct-srv"),

    {CRef, _} = client_reject(Config),

    {'CEA', {capabilities_cb, _, discard},
            ?caps{origin_host = {_, OH}},
            ?packet{msg = CEA}}
        = client_recv(CRef),

    [diameter_base_CEA | _] = Dict0:'#get-'(CEA),
    [2001] = Dict0:'#get-'(['Result-Code'], CEA).

client_reject(Config) ->
    connect(Config, acct, [{capabilities_cb, fun client_capx/2},
                           {applications, apps(Config)}]).

%% ===========================================================================

%% server_closed/3

server_closed(Config, F, RC) ->
    true = diameter:subscribe(?SERVER),
    OH = host(Config),

    {_, LRef} = F(Config),

    receive
        ?event{service = ?SERVER,
               info = {closed, LRef,
                               {'CER', RC,
                                       ?caps{origin_host = {_, OH}},
                                       ?packet{}}
                               = Reason,
                               {listen, _}}} ->
            Reason
    after ?TIMEOUT ->
            ?fail({LRef, OH})
    end.

%% server_reject/3

server_reject(Config, F, RC) ->
    true = diameter:subscribe(?SERVER),
    OH = host(Config),

    {_, LRef} = F(Config),

    receive
        ?event{service = ?SERVER,
               info = {closed, LRef,
                               {'CER', {capabilities_cb, _, RC},
                                       ?caps{origin_host = {_, OH}},
                                       ?packet{}}
                               = Reason,
                               {listen, _}}} ->
            Reason
    after ?TIMEOUT ->
            ?fail({LRef, OH})
    end.

%% cliient_closed/4

client_closed(Config, Host, F, RC) ->
    true = diameter:subscribe(?CLIENT),
    OH = host(Config, Host),

    {CRef, _} = F(Config),

    {'CEA', RC, ?caps{origin_host = {_, OH}}, ?packet{}}
        = client_recv(CRef).

srv(Config, Host) ->
    "rfc" ++ N = atom_to_list(proplists:get_value(rfc, Config)),
    [H, "srv" = S] = string:tokens(Host, "-"),
    H ++ N ++ "-" ++ S.

host(Config, Name) ->
    ?HOST(srv(Config, Name)).

%% client_recv/1

client_recv(CRef) ->
    receive
        ?event{service = ?CLIENT,
               info = {closed, CRef, Reason, {connect, _}}} ->
            Reason
    after ?TIMEOUT ->
            ?fail(CRef)
    end.

%% server_capx/3

server_capx(_, ?caps{origin_host = {_, [_,$_|"unknown_peer." ++ _]}}, _) ->
    unknown;

server_capx(_, ?caps{origin_host = {_, [_,$_|"unable." ++ _]}}, _) ->
    5012;  %% DIAMETER_UNABLE_TO_COMPLY

server_capx(_, ?caps{origin_host = {OH,DH}}, _) ->
    io:format("connection: ~p -> ~p~n", [DH,OH]),
    ok.

%% client_capx/2

client_capx(_, ?caps{origin_host = {[_,$_|"client_reject." ++ _], _}}) ->
    discard.

%% ===========================================================================

dict0(Config) ->
    case proplists:get_value(rfc, Config) of
        rfc3588 -> diameter_gen_base_rfc3588;
        rfc6733 -> diameter_gen_base_rfc6733
    end.

apps(Config) ->
    case proplists:get_value(rfc, Config) of
        rfc3588 -> [base3588, acct3588];
        rfc6733 -> [base6733, acct6733]
    end.

host(Config) ->
    {_, H} = lists:keyfind(host, 1, Config),
    ?HOST(H).

listen(Name, Opts) ->
    ?util:listen(Name, tcp, Opts).

connect(Config, T, Opts) ->
    {_, H} = lists:keyfind(host, 1, Config),
    LRef = lref(Config, T),
    CRef = connect(LRef, [{capabilities, [{'Origin-Host', ?HOST(H)}]}
                          | Opts]),
    Name = lists:takewhile(fun(C) -> C /= $. end, H),
    ?util:write_priv(Config, Name, CRef),  %% end_per_testcase reads
    {CRef, LRef}.

connect(LRef, Opts) ->
    [PortNr] = ?util:lport(tcp, LRef, 20),
    {ok, CRef} = diameter:add_transport(?CLIENT,
                                        {connect, opts(PortNr, Opts)}),
    CRef.

opts(PortNr, Opts) ->
    [{transport_module, diameter_tcp},
     {transport_config, [{raddr, ?ADDR},
                         {rport, PortNr},
                         {ip, ?ADDR},
                         {port, 0}]}
     | Opts].

lref(rfc3588, [LRef, _]) ->
    LRef;
lref(rfc6733, [_, LRef]) ->
    LRef;

lref(Config, T) ->
    lref(proplists:get_value(rfc, Config),
         case ?util:read_priv(Config, ?MODULE) of
             {R, _} when T == base ->
                 R;
             {_, R} when T == acct ->
                 R
         end).

%% ===========================================================================
%% diameter callbacks

peer_up(?SERVER,
        {_, ?caps{origin_host = {"acct" ++ _,
                                 [_,$_|"client_reject." ++ _]}}},
        State,
        _) ->
    State.

peer_down(?SERVER,
          {_, ?caps{origin_host = {"acct" ++ _,
                                   [_,$_|"client_reject." ++ _]}}},
          State,
          _) ->
    State.