aboutsummaryrefslogtreecommitdiffstats
path: root/lib/diameter/src/base/diameter_peer_fsm.erl
diff options
context:
space:
mode:
Diffstat (limited to 'lib/diameter/src/base/diameter_peer_fsm.erl')
-rw-r--r--lib/diameter/src/base/diameter_peer_fsm.erl146
1 files changed, 86 insertions, 60 deletions
diff --git a/lib/diameter/src/base/diameter_peer_fsm.erl b/lib/diameter/src/base/diameter_peer_fsm.erl
index 46d231da74..601e48e817 100644
--- a/lib/diameter/src/base/diameter_peer_fsm.erl
+++ b/lib/diameter/src/base/diameter_peer_fsm.erl
@@ -129,6 +129,7 @@
%% the request was sent explicitly with
%% diameter:call/4.
strict :: boolean(),
+ ack = false :: boolean(),
length_errors :: exit | handle | discard,
incoming_maxlen :: integer() | infinity}).
@@ -235,7 +236,7 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) ->
Tmo = proplists:get_value(capx_timeout, Opts, ?CAPX_TIMEOUT),
Strictness = proplists:get_value(capx_strictness, Opts, true),
- OnLengthErr = proplists:get_value(length_errors, Opts, exit),
+ LengthErr = proplists:get_value(length_errors, Opts, exit),
{TPid, Addrs} = start_transport(T, Rest, Svc),
@@ -247,7 +248,7 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) ->
dictionary = Dict0,
mode = M,
service = svc(Svc, Addrs),
- length_errors = OnLengthErr,
+ length_errors = LengthErr,
strict = Strictness,
incoming_maxlen = Maxlen}.
%% The transport returns its local ip addresses so that different
@@ -442,9 +443,18 @@ transition({connection_timeout = T, TPid},
transition({connection_timeout, _}, _) ->
ok;
+%% Requests for acknowledgements to the transport.
+transition({diameter, ack}, S) ->
+ S#state{ack = true};
+
%% Incoming message from the transport.
-transition({diameter, {recv, MsgT}}, S) ->
- incoming(MsgT, S);
+transition({diameter, {recv, Msg}}, S) ->
+ incoming(recv(Msg, S), S);
+
+%% Handler of an incoming request is telling of its existence.
+transition({handler, Pid}, _) ->
+ put_route(Pid),
+ ok;
%% Timeout when still in the same state ...
transition({timeout = T, PS}, #state{state = PS}) ->
@@ -458,7 +468,7 @@ transition({timeout, _}, _) ->
transition({send, Msg}, S) ->
outgoing(Msg, S);
transition({send, Msg, Route}, S) ->
- put_route(Route),
+ route_outgoing(Route),
outgoing(Msg, S);
%% Request for graceful shutdown at remove_transport, stop_service of
@@ -487,12 +497,13 @@ transition({'DOWN', _, process, WPid, _},
transition({'DOWN', _, process, TPid, _},
#state{transport = TPid}
= S) ->
- start_next(S);
+ start_next(S#state{ack = false});
%% Transport has died after connection timeout, or handler process has
%% died.
-transition({'DOWN', _, process, Pid, _}, _) ->
- erase_route(Pid),
+transition({'DOWN', _, process, Pid, _}, #state{transport = TPid}) ->
+ is_reference(erase_route(Pid))
+ andalso send(TPid, false), %% answer not forthcoming
ok;
%% State query.
@@ -502,37 +513,56 @@ transition({state, Pid}, #state{state = S, transport = TPid}) ->
%% Crash on anything unexpected.
-%% put_route/1
-%%
+%% route_outgoing/1
+
%% Map identifiers in an outgoing request to be able to lookup the
%% handler process when the answer is received.
-
-put_route({Pid, Ref, Seqs}) ->
+route_outgoing({Pid, Ref, Seqs}) -> %% request
MRef = monitor(process, Pid),
put(Pid, Seqs),
- put(Seqs, {Pid, Ref, MRef}).
+ put(Seqs, {Pid, Ref, MRef});
+
+%% Remove a mapping made for an incoming request.
+route_outgoing(Pid)
+ when is_pid(Pid) -> %% answer
+ MRef = erase_route(Pid),
+ undefined == MRef orelse demonitor(MRef).
-%% get_route/1
+%% put_route/1
+
+%% Monitor on a handler process for an incoming request.
+put_route(Pid) ->
+ MRef = monitor(process, Pid),
+ put(Pid, MRef).
-get_route(#diameter_packet{header = #diameter_header{is_request = false}}
- = Pkt) ->
+%% get_route/2
+
+%% incoming answer
+get_route(_, #diameter_packet{header = #diameter_header{is_request = false}}
+ = Pkt) ->
Seqs = diameter_codec:sequence_numbers(Pkt),
case erase(Seqs) of
{Pid, Ref, MRef} ->
demonitor(MRef),
erase(Pid),
{Pid, Ref, self()};
- undefined ->
+ undefined -> %% request unknown
false
end;
-get_route(_) ->
- false.
+%% incoming request
+get_route(Ack, _) ->
+ Ack.
%% erase_route/1
erase_route(Pid) ->
- erase(erase(Pid)).
+ case erase(Pid) of
+ {_,_} = Seqs ->
+ erase(Seqs);
+ T ->
+ T
+ end.
%% capx/1
@@ -611,29 +641,24 @@ encode(Rec, Dict) ->
%% incoming/2
-incoming({Msg, NPid}, S) ->
- try recv(Msg, S) of
- T ->
- NPid ! {diameter, discard},
- T
- catch
- {?MODULE, Name, Pkt} ->
- incoming(Name, Pkt, NPid, S)
- end;
+incoming({recv = T, Name, Pkt}, #state{parent = Pid, ack = Ack} = S) ->
+ Pid ! {T, self(), get_route(Ack, Pkt), Name, Pkt},
+ rcv(Name, Pkt, S);
-incoming(Msg, S) ->
- try
- recv(Msg, S)
- catch
- {?MODULE, Name, Pkt} ->
- incoming(Name, Pkt, false, S)
- end.
+incoming(#diameter_header{is_request = R}, #state{transport = TPid,
+ ack = Ack}) ->
+ R andalso Ack andalso send(TPid, false),
+ ok;
+
+incoming(<<_:32, 1:1, _/bits>>, #state{ack = true} = S) ->
+ send(S#state.transport, false),
+ ok;
-%% incoming/4
+incoming(<<_/bits>>, _) ->
+ ok;
-incoming(Name, Pkt, NPid, #state{parent = Pid} = S) ->
- Pid ! {recv, self(), get_route(Pkt), Name, Pkt, NPid},
- rcv(Name, Pkt, S).
+incoming(T, _) ->
+ T.
%% recv/2
@@ -658,18 +683,19 @@ recv1(_,
#diameter_packet{header = H, bin = Bin},
#state{incoming_maxlen = M})
when M < size(Bin) ->
- invalid(false, incoming_maxlen_exceeded, {size(Bin), H});
+ invalid(false, incoming_maxlen_exceeded, {size(Bin), H}),
+ H;
%% Ignore anything but an expected CER/CEA if so configured. This is
%% non-standard behaviour.
-recv1(Name, _, #state{state = {'Wait-CEA', _, _},
- strict = false})
+recv1(Name, #diameter_packet{header = H}, #state{state = {'Wait-CEA', _, _},
+ strict = false})
when Name /= 'CEA' ->
- ok;
-recv1(Name, _, #state{state = recv_CER,
- strict = false})
+ H;
+recv1(Name, #diameter_packet{header = H}, #state{state = recv_CER,
+ strict = false})
when Name /= 'CER' ->
- ok;
+ H;
%% Incoming request after outgoing DPR: discard. Don't discard DPR, so
%% both ends don't do so when sending simultaneously.
@@ -677,13 +703,15 @@ recv1(Name,
#diameter_packet{header = #diameter_header{is_request = true} = H},
#state{dpr = {_,_,_}})
when Name /= 'DPR' ->
- invalid(false, recv_after_outgoing_dpr, H);
+ invalid(false, recv_after_outgoing_dpr, H),
+ H;
%% Incoming request after incoming DPR: discard.
recv1(_,
#diameter_packet{header = #diameter_header{is_request = true} = H},
#state{dpr = true}) ->
- invalid(false, recv_after_incoming_dpr, H);
+ invalid(false, recv_after_incoming_dpr, H),
+ H;
%% DPA with identifier mismatch, or in response to a DPR initiated by
%% the service.
@@ -701,7 +729,7 @@ recv1('DPA' = N,
%% Any other message with a header and no length errors: send to the
%% parent.
recv1(Name, Pkt, #state{}) ->
- throw({?MODULE, Name, Pkt}).
+ {recv, Name, Pkt}.
%% recv/3
@@ -720,10 +748,12 @@ recv(#diameter_header{}
#diameter_packet{bin = Bin},
#state{length_errors = E}) ->
T = {size(Bin), bit_size(Bin) rem 8, H},
- invalid(E, message_length_mismatch, T);
+ invalid(E, message_length_mismatch, T),
+ Bin;
recv(false, #diameter_packet{bin = Bin}, #state{length_errors = E}) ->
- invalid(E, truncated_header, Bin).
+ invalid(E, truncated_header, Bin),
+ Bin.
%% Note that counters here only count discarded messages.
invalid(E, Reason, T) ->
@@ -779,14 +809,10 @@ rcv('DPA' = N,
diameter_peer:close(TPid),
{stop, N};
-%% Ignore anything else, an unsolicited DPA in particular. Note that
-%% dpa_timeout deals with the case in which the peer sends the wrong
-%% identifiers in DPA.
-rcv(N, #diameter_packet{header = H}, _)
- when N == 'CER';
- N == 'CEA';
- N == 'DPR';
- N == 'DPA' ->
+%% Ignore an unsolicited DPA in particular. Note that dpa_timeout
+%% deals with the case in which the peer sends the wrong identifiers
+%% in DPA.
+rcv('DPA' = N, #diameter_packet{header = H}, _) ->
?LOG(ignored, N),
%% Note that these aren't counted in the normal recv counter.
diameter_stats:incr({diameter_codec:msg_id(H), recv, ignored}),