%%
%% %CopyrightBegin%
%%
%% Copyright Ericsson AB 2010-2013. All Rights Reserved.
%%
%% The contents of this file are subject to the Erlang Public License,
%% Version 1.1, (the "License"); you may not use this file except in
%% compliance with the License. You should have received a copy of the
%% Erlang Public License along with this software. If not, it can be
%% retrieved online at http://www.erlang.org/.
%%
%% Software distributed under the License is distributed on an "AS IS"
%% basis, WITHOUT WARRANTY OF ANY KIND, either express or implied. See
%% the License for the specific language governing rights and limitations
%% under the License.
%%
%% %CopyrightEnd%
%%

%% Diameter transport over TCP: gen_tcp by default (or a gen_tcp-like
%% module given by a 'module' option), optionally running TLS via ssl.

-module(diameter_tcp).
-behaviour(gen_server).

%% interface
-export([start/3]).

%% child start from supervisor
-export([start_link/1]).

%% child start from here
-export([init/1]).

%% gen_server callbacks
-export([handle_call/3,
         handle_cast/2,
         handle_info/2,
         code_change/3,
         terminate/2]).

-export([info/1]).  %% service_info callback

-export([ports/0,
         ports/1]).

-include_lib("diameter/include/diameter.hrl").

%% Keys into process dictionary.
-define(INFO_KEY, info).
-define(REF_KEY, ref).

%% Raise an error tagged with this module and the source line.
-define(ERROR(T), erlang:error({T, ?MODULE, ?LINE})).

-define(DEFAULT_PORT, 3868).  %% RFC 3588, ch 2.1

%% Milliseconds (erlang:start_timer/3) a listener waits, once no
%% accepting transport remains attached, before exiting.
-define(LISTENER_TIMEOUT, 30000).

%% Default for the fragment_timer transport option, in milliseconds.
-define(DEFAULT_FRAGMENT_TIMEOUT, 1000).

%% Guards used to validate the fragment_timer option: a non-negative
%% integer that fits in 32 bits, or infinity.
-define(IS_UINT32(N), (is_integer(N) andalso 0 =< N andalso 0 == N bsr 32)).
-define(IS_TIMEOUT(N), (infinity == N orelse ?IS_UINT32(N))).

%% cb_info passed to ssl: run TLS over Mod, with Mod's tcp, tcp_closed
%% and tcp_error message tags.
-define(TCP_CB(Mod), {Mod, tcp, tcp_closed, tcp_error}).

%% The same gen_server implementation supports three different kinds
%% of processes: an actual transport process, one that will club it to
%% death should the parent die before a connection is established, and
%% a process owning the listening port.

%% Listener process state.
-record(listener, {socket :: inet:socket(),         %% listening socket
                   count = 1 :: non_neg_integer(),  %% attached accepting processes
                   tref :: reference()}).           %% close timer

%% Monitor process state.
-record(monitor, {parent :: pid(),
                  %% The default is evaluated by whoever builds the record:
                  %% the transport process itself, in i/1.
                  transport = self() :: pid()}).

-type length() :: 0..16#FFFFFF.      %% message length from Diameter header
-type size() :: non_neg_integer().   %% accumulated binary size
-type frag() :: {length(), size(), binary(), list(binary())}
              | binary().

%% Accepting/connecting transport process state.
-record(transport,
        {socket :: inet:socket() | ssl:sslsocket(),  %% accept/connect socket
         parent :: pid(),                  %% of process that started us
         module :: module(),               %% gen_tcp-like module
         frag = <<>> :: frag(),            %% message fragment
         ssl :: boolean() | [term()],      %% true/false, or ssl options
                                           %% while TLS is undecided
         timeout :: infinity | 0..16#FFFFFFFF,  %% fragment timeout
         flush = false :: boolean()}).     %% flush fragment at timeout

%% The usual transport using gen_tcp can be replaced by anything
%% sufficiently gen_tcp-like by passing a 'module' option as the first
%% (for simplicity) transport option. The transport_module diameter_etcp
%% uses this to set itself as the module to call, its start/3 just
%% calling start/3 here with the option set.

%% ---------------------------------------------------------------------------
%% # start/3
%% ---------------------------------------------------------------------------

%% Start a transport process of the given type (accept or connect)
%% under the tcp supervisor, handing it the service's host IP addresses
%% as candidate local addresses.

start({Type, Ref}, #diameter_service{capabilities = Caps}, Opts) ->
    diameter_tcp_sup:start(),  %% tcp supervisors are started on demand
    {Mod, TransportOpts} = split(Opts),
    HostAddrs = Caps#diameter_caps.host_ip_address,
    diameter_tcp_sup:start_child({Type, Ref, Mod, self(), TransportOpts, HostAddrs}).

%% Peel off a leading {module, M} option; gen_tcp otherwise.

split(Opts) ->
    case Opts of
        [{module, M} | Rest] ->
            {M, Rest};
        _ ->
            {gen_tcp, Opts}
    end.

%% start_link/1

start_link(T) ->
    SpawnOpts = diameter_lib:spawn_opts(server, []),
    proc_lib:start_link(?MODULE, init, [T], infinity, SpawnOpts).
%% ---------------------------------------------------------------------------
%% # info/1
%% ---------------------------------------------------------------------------

%% Build the service_info for a transport from the {Mod, Sock} pair
%% saved by publish/4. The result holds the local socket name, the peer
%% name and the socket statistics. Each of these is left out unless its
%% lookup returns {ok, _}. When Mod is ssl, {ssl, [{peercert, Cert}]}
%% follows them.
%%
%% The ssl entry is appended to the result, not placed in the list of
%% {Key, Fun} lookups handed to info/3. It holds a value rather than a
%% fun, so info/3 would fail with badfun when applying it.

info({Mod, Sock}) ->
    Found = lists:flatmap(fun(K) -> info(Mod, K, Sock) end,
                          [{socket, fun sockname/2},
                           {peer, fun peername/2},
                           {statistics, fun getstat/2}]),
    Found ++ ssl_info(Mod, Sock).

%% Apply one {Key, Fun} lookup: [{Key, Value}] on {ok, Value}, [] on
%% anything else.

info(Mod, {K,F}, Sock) ->
    case F(Mod, Sock) of
        {ok, V} ->
            [{K,V}];
        _ ->
            []
    end.

%% ssl-specific info, present only when the transport module is ssl.

ssl_info(ssl = M, Sock) ->
    [{M, ssl_info(Sock)}];
ssl_info(_, _) ->
    [].

%% The peer certificate, if ssl:peercert/1 returns one.

ssl_info(Sock) ->
    [{peercert, C} || {ok, C} <- [ssl:peercert(Sock)]].

%% ---------------------------------------------------------------------------
%% # init/1
%% ---------------------------------------------------------------------------

%% Entered through proc_lib:start_link/5 from start_link/1. Each clause
%% of i/1 acknowledges the start itself with proc_lib:init_ack/1, and
%% the process then becomes a gen_server through
%% gen_server:enter_loop/3.

init(T) ->
    gen_server:enter_loop(?MODULE, [], i(T)).

%% i/1

%% A transport process.
i({T, Ref, Mod, Pid, Opts, Addrs})
  when T == accept;
       T == connect ->
    erlang:monitor(process, Pid),
    %% Since accept/connect might block indefinitely, spawn a process
    %% that does nothing but kill us with the parent until call
    %% returns.
    {ok, MPid} = diameter_tcp_sup:start_child(#monitor{parent = Pid}),
    {SslOpts, Rest0} = ssl(Opts),
    {OwnOpts, Rest} = own(Rest0),
    Tmo = proplists:get_value(fragment_timer,
                              OwnOpts,
                              ?DEFAULT_FRAGMENT_TIMEOUT),
    ?IS_TIMEOUT(Tmo) orelse ?ERROR({fragment_timer, Tmo}),
    Sock = i(T, Ref, Mod, Pid, SslOpts, Rest, Addrs),
    MPid ! {stop, self()},  %% tell the monitor to die
    %% With ssl_options true the connection runs TLS from the start
    %% (see i/7). Otherwise the socket belongs to Mod, possibly to be
    %% upgraded after capabilities exchange.
    M = case SslOpts of
            true -> ssl;
            _ -> Mod
        end,
    setopts(M, Sock),
    %% Put the reference in the process dictionary since it is needed
    %% to advertise the ssl socket after a TLS upgrade.
    putr(?REF_KEY, Ref),
    infinity == Tmo orelse erlang:start_timer(Tmo, self(), flush),
    #transport{parent = Pid,
               module = M,
               socket = Sock,
               ssl = SslOpts,
               timeout = Tmo};

%% A monitor process to kill the transport if the parent dies.
%%
%% In principle a link between the transport and killer processes
%% could do the same thing: have the accepting/connecting process be
%% killed when the killer process dies as a consequence of parent
%% death. However, a link can be unlinked and this is exactly what
%% gen_tcp seems to do. Links should be left to supervisors.
i(#monitor{parent = Pid, transport = TPid} = S) ->
    proc_lib:init_ack({ok, self()}),
    erlang:monitor(process, Pid),
    erlang:monitor(process, TPid),
    S;

%% A process owning the listening socket shared by accepting transports
%% with the same reference (see listener/2).
i({listen, LRef, APid, {Mod, Opts, Addrs}}) ->
    {[LA, LP], Rest} = proplists:split(Opts, [ip, port]),
    LAddr = get_addr(LA, Addrs),
    LPort = get_port(LP),
    {ok, LSock} = Mod:listen(LPort, gen_opts(LAddr, Rest)),
    true = diameter_reg:add_new({?MODULE, listener, {LRef, {LAddr, LSock}}}),
    proc_lib:init_ack({ok, self(), {LAddr, LSock}}),
    erlang:monitor(process, APid),
    start_timer(#listener{socket = LSock}).

%% own/1
%%
%% Split off the options handled by this module (fragment_timer) from
%% those passed on.

own(Opts) ->
    {Own, Rest} = proplists:split(Opts, [fragment_timer]),
    {lists:append(Own), Rest}.

%% ssl/1
%%
%% Extract ssl_options. The result is false when the option is absent,
%% and otherwise true or the configured options list.

ssl(Opts) ->
    {[SslOpts], Rest} = proplists:split(Opts, [ssl_options]),
    {ssl_opts(SslOpts), Rest}.

ssl_opts([]) ->
    false;
ssl_opts([{ssl_options, true}]) ->
    true;
ssl_opts([{ssl_options, Opts}])
  when is_list(Opts) ->
    Opts;
ssl_opts(L) ->
    ?ERROR({ssl_options, L}).

%% i/7

%% Establish a TLS connection before capabilities exchange ...
i(Type, Ref, Mod, Pid, true, Opts, Addrs) ->
    i(Type, Ref, ssl, Pid, [{cb_info, ?TCP_CB(Mod)} | Opts], Addrs);

%% ... or not.
i(Type, Ref, Mod, Pid, _, Opts, Addrs) ->
    i(Type, Ref, Mod, Pid, Opts, Addrs).
%% i/6
%%
%% Accept or connect, acknowledge the starting process with the local
%% address, register the resulting socket and tell the parent that the
%% transport is up. Returns the socket.

i(accept = Type, Ref, Mod, Pid, Opts, Addrs) ->
    {LocalAddr, ListenSock} = listener(Ref, {Mod, Opts, Addrs}),
    proc_lib:init_ack({ok, self(), [LocalAddr]}),
    Sock = ok(accept(Mod, ListenSock)),
    publish(Mod, Type, Ref, Sock),
    diameter_peer:up(Pid),
    Sock;

i(connect = Type, Ref, Mod, Pid, Opts, Addrs) ->
    {[LocalOpt, RemoteOpt, PortOpt], Rest}
        = proplists:split(Opts, [ip, raddr, rport]),
    LocalAddr = get_addr(LocalOpt, Addrs),
    RemoteAddr = get_addr(RemoteOpt, []),
    RemotePort = get_port(PortOpt),
    proc_lib:init_ack({ok, self(), [LocalAddr]}),
    ConnectOpts = gen_opts(LocalAddr, Rest),
    Sock = ok(connect(Mod, RemoteAddr, RemotePort, ConnectOpts)),
    publish(Mod, Type, Ref, Sock),
    diameter_peer:up(Pid, {RemoteAddr, RemotePort}),
    Sock.

%% Register the socket under {?MODULE, Type, {Ref, Sock}} and remember
%% {Mod, Sock} for info/1.

publish(Mod, T, Ref, Sock) ->
    Key = {?MODULE, T, {Ref, Sock}},
    true = diameter_reg:add_new(Key),
    putr(?INFO_KEY, {Mod, Sock}).

%% Unwrap {ok, Value}. Anything else terminates the process with
%% {shutdown, Result}.

ok(Result) ->
    case Result of
        {ok, Value} ->
            Value;
        _ ->
            x(Result)
    end.

x(Reason) ->
    exit({shutdown, Reason}).

%% listener/2
%%
%% Return {LAddr, LSock} for the listening socket of LRef. Attach to an
%% existing listener process if one is registered, otherwise start one.

listener(LRef, T) ->
    Existing = diameter_reg:match({?MODULE, listener, {LRef, '_'}}),
    l(Existing, LRef, T).

%% Existing process with the listening socket ...
l([{{?MODULE, listener, {_, AddrSock}}, ListenerPid}], _, _) ->
    ListenerPid ! {accept, self()},
    AddrSock;

%% ... or not: start one.
l([], LRef, T) ->
    {ok, _, AddrSock} = diameter_tcp_sup:start_child({listen, LRef, self(), T}),
    AddrSock.

%% get_addr/2

get_addr(As, Def) ->
    Addr = addr(As, Def),
    diameter_lib:ipaddr(Addr).

%% With no explicit address option, take the first of the default
%% addresses. For a local address these are the service's host IP
%% addresses. Having neither an option nor a default is an error, and
%% so is having more than one option.
addr([], [Addr | _]) ->
    Addr;
addr([{_, Addr}], _) ->
    Addr;
addr(As, Addrs) ->
    ?ERROR({invalid_addrs, As, Addrs}).

%% get_port/1
%%
%% The port from a single option, or ?DEFAULT_PORT when none is given.
%% More than one option is an error.

get_port([]) ->
    ?DEFAULT_PORT;
get_port([{_, Port}]) ->
    Port;
get_port(Ps) ->
    ?ERROR({invalid_ports, Ps}).

%% gen_opts/2
%%
%% Socket options for listen/connect. This module imposes binary,
%% {packet, 0} and {active, once} itself, so the caller supplying any
%% of them is an error.

gen_opts(LAddr, Opts) ->
    case proplists:split(Opts, [binary, packet, active]) of
        {[[], [], []], _} ->
            ok;
        _ ->
            ?ERROR({reserved_options, Opts})
    end,
    [binary, {packet, 0}, {active, once}, {ip, LAddr} | Opts].
%% ---------------------------------------------------------------------------
%% # ports/1
%% ---------------------------------------------------------------------------

%% List {Type, PortOrSocket, Pid} for the sockets this module has
%% registered (listen, accept or connect). ports/1 restricts the list
%% to a single transport reference. PortOrSocket is the local port
%% number, or the socket itself when the port cannot be looked up.

ports() ->
    Regs = diameter_reg:match({?MODULE, '_', '_'}),
    [{type(Kind), resolve(Kind, Info), Pid}
     || {{?MODULE, Kind, {_, Info}}, Pid} <- Regs].

ports(Ref) ->
    Regs = diameter_reg:match({?MODULE, '_', {Ref, '_'}}),
    [{type(Kind), resolve(Kind, Info), Pid}
     || {{?MODULE, Kind, {R, Info}}, Pid} <- Regs,
        R == Ref].

%% Registered 'listener' entries are reported as 'listen'.
type(Kind) ->
    case Kind of
        listener -> listen;
        _ -> Kind
    end.

%% Listener registrations hold {LAddr, LSock}; the others hold the
%% socket itself.
sock(listener, {_LAddr, LSock}) ->
    LSock;
sock(_, Sock) ->
    Sock.

resolve(Type, S) ->
    Sock = sock(Type, S),
    try portnr(Sock) of
        {ok, PortNr} ->
            PortNr;
        _ ->
            Sock
    catch
        _:_ ->
            Sock
    end.

portnr(Sock) ->
    Mod = case is_port(Sock) of
              true -> gen_tcp;
              false -> ssl
          end,
    portnr(Mod, Sock).

%% ---------------------------------------------------------------------------
%% # handle_call/3
%% ---------------------------------------------------------------------------

handle_call(_Req, _From, State) ->
    {reply, nok, State}.

%% ---------------------------------------------------------------------------
%% # handle_cast/2
%% ---------------------------------------------------------------------------

handle_cast(_Msg, State) ->
    {noreply, State}.

%% ---------------------------------------------------------------------------
%% # handle_info/2
%% ---------------------------------------------------------------------------

%% Dispatch on the kind of process this is. A monitor process only ever
%% receives a message that ends its life.

handle_info(Msg, #transport{} = State) ->
    {noreply, #transport{} = t(Msg, State)};

handle_info(Msg, #listener{} = State) ->
    {noreply, #listener{} = l(Msg, State)};

handle_info(Msg, #monitor{} = State) ->
    m(Msg, State),
    x(Msg).

%% ---------------------------------------------------------------------------
%% # code_change/3
%% ---------------------------------------------------------------------------

code_change(_OldVsn, State, _Extra) ->
    {ok, State}.

%% ---------------------------------------------------------------------------
%% # terminate/2
%% ---------------------------------------------------------------------------

terminate(_Reason, _State) ->
    ok.
%% ---------------------------------------------------------------------------

%% Store/fetch in the process dictionary under a module-tagged key.

putr(Key, Val) ->
    put({?MODULE, Key}, Val).

getr(Key) ->
    get({?MODULE, Key}).

%% start_timer/1
%%
%% Start the close timer once no accepting process remains attached to
%% a listener. Otherwise return the state unchanged.

start_timer(#listener{count = 0} = S) ->
    S#listener{tref = erlang:start_timer(?LISTENER_TIMEOUT, self(), close)};
start_timer(S) ->
    S.

%% m/2
%%
%% Transition monitor state.

%% Transport is telling us to die.
m({stop, TPid}, #monitor{transport = TPid}) ->
    ok;

%% Transport has died.
m({'DOWN', _, process, TPid, _}, #monitor{transport = TPid}) ->
    ok;

%% Transport parent has died.
m({'DOWN', _, process, Pid, _}, #monitor{parent = Pid,
                                         transport = TPid}) ->
    exit(TPid, {shutdown, parent}).

%% l/2
%%
%% Transition listener state.

%% Another accept transport is attaching.
l({accept, TPid}, #listener{count = N} = S) ->
    erlang:monitor(process, TPid),
    S#listener{count = N+1};

%% Accepting process has died.
l({'DOWN', _, process, _, _}, #listener{count = N} = S) ->
    start_timer(S#listener{count = N-1});

%% Timeout after the last accepting process has died.
l({timeout, TRef, close = T}, #listener{tref = TRef,
                                        count = 0}) ->
    x(T);

%% Stale timer: superseded, or a transport has attached since.
l({timeout, _, close}, #listener{} = S) ->
    S.

%% t/2
%%
%% Transition transport state.

t(T,S) ->
    case transition(T,S) of
        ok ->
            S;
        #transport{} = NS ->
            NS;
        {stop, Reason} ->
            x(Reason);
        stop ->
            x(T)
    end.

%% transition/2
%%
%% The ssl field is true or false once it is known whether the
%% connection runs over TLS. While that is still undecided it holds the
%% configured ssl options list, and the socket is plain TCP.

%% Initial incoming message when we might need to upgrade to TLS:
%% don't request another message until we know.
transition({tcp, Sock, Bin}, #transport{socket = Sock,
                                        parent = Pid,
                                        frag = Head,
                                        module = M,
                                        ssl = Opts}
                             = S)
  when is_list(Opts) ->
    case rcv(Head, Bin) of
        {Msg, B} when is_binary(Msg) ->
            diameter_peer:recv(Pid, Msg),
            S#transport{frag = B};
        Frag ->
            setopts(M, Sock),
            S#transport{frag = Frag}
    end;

%% Incoming message.
transition({P, Sock, Bin}, #transport{socket = Sock,
                                      module = M,
                                      ssl = B}
                           = S)
  when P == tcp, not B;
       P == ssl, B ->
    setopts(M, Sock),
    recv(Bin, S);

%% Capabilities exchange has decided on whether or not to run over TLS.
transition({diameter, {tls, Ref, Type, B}}, #transport{parent = Pid}
                                            = S) ->
    #transport{socket = Sock,
               module = M}
        = NS
        = tls_handshake(Type, B, S),
    Pid ! {diameter, {tls, Ref}},
    setopts(M, Sock),
    NS#transport{ssl = B};

%% Connection closed. On a plain TCP socket, ssl is either false or a
%% still-unused options list, and the guard has to accept both. A guard
%% of the form `not B' raises for a list, which fails the guard and
%% leaves the message to crash the process with function_clause.
transition({C, Sock}, #transport{socket = Sock,
                                 ssl = B})
  when C == tcp_closed, B /= true;
       C == ssl_closed, B == true ->
    stop;

%% Socket error: same reasoning about the ssl field as above.
transition({E, Sock, _Reason} = T, #transport{socket = Sock,
                                              ssl = B}
                                   = S)
  when E == tcp_error, B /= true;
       E == ssl_error, B == true ->
    ?ERROR({T,S});

%% Outgoing message.
transition({diameter, {send, Bin}}, #transport{socket = Sock,
                                               module = M}) ->
    case send(M, Sock, Bin) of
        ok ->
            ok;
        {error, Reason} ->
            {stop, {send, Reason}}
    end;

%% Request to close the transport connection.
transition({diameter, {close, Pid}}, #transport{parent = Pid,
                                                socket = Sock,
                                                module = M}) ->
    M:close(Sock),
    stop;

%% Timeout for reception of outstanding packets.
transition({timeout, _TRef, flush}, #transport{timeout = Tmo} = S) ->
    erlang:start_timer(Tmo, self(), flush),
    flush(S);

%% Request for the local port number.
transition({resolve_port, Pid}, #transport{socket = Sock,
                                           module = M})
  when is_pid(Pid) ->
    Pid ! portnr(M, Sock),
    ok;

%% Parent process has died.
transition({'DOWN', _, process, Pid, _}, #transport{parent = Pid}) ->
    stop.

%% Crash on anything unexpected.

%% tls_handshake/3
%%
%% In the case that no tls message is received (eg. the service hasn't
%% been configured to advertise TLS support) we will simply never ask
%% for another TCP message, which will force the watchdog to
%% eventually take us down.

%% TLS has already been established with the connection.
tls_handshake(_, _, #transport{ssl = true} = S) ->
    S;

%% Capabilities exchange negotiated TLS but transport was not
%% configured with an options list.
tls_handshake(_, true, #transport{ssl = false}) ->
    ?ERROR(no_ssl_options);

%% Capabilities exchange negotiated TLS: upgrade the connection.
tls_handshake(Type, true, #transport{socket = Sock,
                                     module = M,
                                     ssl = Opts}
                          = S) ->
    {ok, SSock} = tls(Type, Sock, [{cb_info, ?TCP_CB(M)} | Opts]),
    Ref = getr(?REF_KEY),
    true = diameter_reg:add_new({?MODULE, Type, {Ref, SSock}}),
    S#transport{socket = SSock,
                module = ssl};

%% Capabilities exchange has not negotiated TLS.
tls_handshake(_, false, S) ->
    S.

tls(connect, Sock, Opts) ->
    ssl:connect(Sock, Opts);
tls(accept, Sock, Opts) ->
    ssl:ssl_accept(Sock, Opts).

%% recv/2
%%
%% Reassemble fragmented messages and extract multiple messages that
%% arrived in a single packet (eg. coalesced by Nagle's algorithm).

recv(Bin, #transport{parent = Pid, frag = Head} = S) ->
    case rcv(Head, Bin) of
        {Msg, B} when is_binary(Msg) ->
            diameter_peer:recv(Pid, Msg),
            recv(B, S#transport{frag = <<>>});
        Frag ->
            S#transport{frag = Frag,
                        flush = false}
    end.

%% rcv/2

%% No previous fragment.
rcv(<<>>, Bin) ->
    rcv(Bin);

%% Not even the first four bytes of the header.
rcv(Head, Bin)
  when is_binary(Head) ->
    rcv(<<Head/binary, Bin/binary>>);

%% Or enough to know how many bytes to extract.
rcv({Len, N, Head, Acc}, Bin) ->
    rcv(Len, N + byte_size(Bin), Head, [Bin | Acc]).

%% rcv/4

%% Extract a message for which we have all bytes.
rcv(Len, N, Head, Acc)
  when Len =< N ->
    recv1(Len, bin(Head, Acc));

%% Wait for more packets.
rcv(Len, N, Head, Acc) ->
    {Len, N, Head, Acc}.

%% rcv/1

%% Nothing left.
rcv(<<>> = Bin) ->
    Bin;

%% The Message Length isn't even sufficient for a header. Chances are
%% things will go south from here but if we're lucky then the bytes we
%% have extend to an intended message boundary and we can recover by
%% simply receiving them. Make it so.
rcv(<<_:1/binary, Len:24, _/binary>> = Bin)
  when Len < 20 ->
    {Bin, <<>>};

%% Enough bytes to extract a message.
rcv(<<_:1/binary, Len:24, _/binary>> = Bin)
  when Len =< byte_size(Bin) ->
    recv1(Len, Bin);

%% Or not: wait for more packets.
rcv(<<_:1/binary, Len:24, _/binary>> = Head) ->
    {Len, byte_size(Head), Head, []};

%% Not even 4 bytes yet.
rcv(Head) ->
    Head.

%% recv1/2

recv1(Len, Bin) ->
    <