aboutsummaryrefslogtreecommitdiffstats
path: root/lib/diameter/src/base/diameter_peer_fsm.erl
diff options
context:
space:
mode:
Diffstat (limited to 'lib/diameter/src/base/diameter_peer_fsm.erl')
-rw-r--r--lib/diameter/src/base/diameter_peer_fsm.erl223
1 files changed, 131 insertions, 92 deletions
diff --git a/lib/diameter/src/base/diameter_peer_fsm.erl b/lib/diameter/src/base/diameter_peer_fsm.erl
index 46d231da74..1b0dc417e5 100644
--- a/lib/diameter/src/base/diameter_peer_fsm.erl
+++ b/lib/diameter/src/base/diameter_peer_fsm.erl
@@ -128,7 +128,12 @@
%% outgoing DPR; boolean says whether or not
%% the request was sent explicitly with
%% diameter:call/4.
+ codec :: #{string_decode := boolean(),
+ strict_mbit := boolean(),
+ rfc := 3588 | 6733,
+ ordered_encode := false},
strict :: boolean(),
+ ack = false :: boolean(),
length_errors :: exit | handle | discard,
incoming_maxlen :: integer() | infinity}).
@@ -159,10 +164,7 @@
%% # start/3
%% ---------------------------------------------------------------------------
--spec start(T, [Opt], {[diameter:service_opt()],
- [node()],
- module(),
- #diameter_service{}})
+-spec start(T, [Opt], {map(), [node()], module(), #diameter_service{}})
-> {reference(), pid()}
when T :: {connect|accept, diameter:transport_ref()},
Opt :: diameter:transport_opt().
@@ -221,9 +223,10 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) ->
erlang:monitor(process, WPid),
wait(Ack, WPid),
diameter_stats:reg(Ref),
- diameter_codec:setopts([{common_dictionary, Dict0} | SvcOpts]),
- {_,_} = Mask = proplists:get_value(sequence, SvcOpts),
- Maxlen = proplists:get_value(incoming_maxlen, SvcOpts, 16#FFFFFF),
+
+ #{sequence := Mask, incoming_maxlen := Maxlen}
+ = SvcOpts,
+
{[Cs,Ds], Rest} = proplists:split(Opts, [capabilities_cb, disconnect_cb]),
putr(?CB_KEY, {Ref, [F || {_,F} <- Cs]}),
putr(?DPR_KEY, [F || {_, F} <- Ds]),
@@ -235,7 +238,7 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) ->
Tmo = proplists:get_value(capx_timeout, Opts, ?CAPX_TIMEOUT),
Strictness = proplists:get_value(capx_strictness, Opts, true),
- OnLengthErr = proplists:get_value(length_errors, Opts, exit),
+ LengthErr = proplists:get_value(length_errors, Opts, exit),
{TPid, Addrs} = start_transport(T, Rest, Svc),
@@ -247,9 +250,14 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) ->
dictionary = Dict0,
mode = M,
service = svc(Svc, Addrs),
- length_errors = OnLengthErr,
+ length_errors = LengthErr,
strict = Strictness,
- incoming_maxlen = Maxlen}.
+ incoming_maxlen = Maxlen,
+ codec = maps:with([string_decode,
+ strict_mbit,
+ rfc,
+ ordered_encode],
+ SvcOpts#{ordered_encode => false})}.
%% The transport returns its local ip addresses so that different
%% transports on the same service can use different local addresses.
%% The local addresses are put into Host-IP-Address avps here when
@@ -442,9 +450,18 @@ transition({connection_timeout = T, TPid},
transition({connection_timeout, _}, _) ->
ok;
+%% Requests for acknowledgements to the transport.
+transition({diameter, ack}, S) ->
+ S#state{ack = true};
+
%% Incoming message from the transport.
-transition({diameter, {recv, MsgT}}, S) ->
- incoming(MsgT, S);
+transition({diameter, {recv, Msg}}, S) ->
+ incoming(recv(Msg, S), S);
+
+%% Handler of an incoming request is telling of its existence.
+transition({handler, Pid}, _) ->
+ put_route(Pid),
+ ok;
%% Timeout when still in the same state ...
transition({timeout = T, PS}, #state{state = PS}) ->
@@ -458,7 +475,7 @@ transition({timeout, _}, _) ->
transition({send, Msg}, S) ->
outgoing(Msg, S);
transition({send, Msg, Route}, S) ->
- put_route(Route),
+ route_outgoing(Route),
outgoing(Msg, S);
%% Request for graceful shutdown at remove_transport, stop_service of
@@ -487,12 +504,13 @@ transition({'DOWN', _, process, WPid, _},
transition({'DOWN', _, process, TPid, _},
#state{transport = TPid}
= S) ->
- start_next(S);
+ start_next(S#state{ack = false});
%% Transport has died after connection timeout, or handler process has
%% died.
-transition({'DOWN', _, process, Pid, _}, _) ->
- erase_route(Pid),
+transition({'DOWN', _, process, Pid, _}, #state{transport = TPid}) ->
+ is_reference(erase_route(Pid))
+ andalso send(TPid, false), %% answer not forthcoming
ok;
%% State query.
@@ -502,37 +520,56 @@ transition({state, Pid}, #state{state = S, transport = TPid}) ->
%% Crash on anything unexpected.
-%% put_route/1
-%%
+%% route_outgoing/1
+
%% Map identifiers in an outgoing request to be able to lookup the
%% handler process when the answer is received.
-
-put_route({Pid, Ref, Seqs}) ->
+route_outgoing({Pid, Ref, Seqs}) -> %% request
MRef = monitor(process, Pid),
put(Pid, Seqs),
- put(Seqs, {Pid, Ref, MRef}).
+ put(Seqs, {Pid, Ref, MRef});
-%% get_route/1
+%% Remove a mapping made for an incoming request.
+route_outgoing(Pid)
+ when is_pid(Pid) -> %% answer
+ MRef = erase_route(Pid),
+ undefined == MRef orelse demonitor(MRef).
-get_route(#diameter_packet{header = #diameter_header{is_request = false}}
- = Pkt) ->
+%% put_route/1
+
+%% Monitor on a handler process for an incoming request.
+put_route(Pid) ->
+ MRef = monitor(process, Pid),
+ put(Pid, MRef).
+
+%% get_route/2
+
+%% incoming answer
+get_route(_, #diameter_packet{header = #diameter_header{is_request = false}}
+ = Pkt) ->
Seqs = diameter_codec:sequence_numbers(Pkt),
case erase(Seqs) of
{Pid, Ref, MRef} ->
demonitor(MRef),
erase(Pid),
{Pid, Ref, self()};
- undefined ->
+ undefined -> %% request unknown
false
end;
-get_route(_) ->
- false.
+%% incoming request
+get_route(Ack, _) ->
+ Ack.
%% erase_route/1
erase_route(Pid) ->
- erase(erase(Pid)).
+ case erase(Pid) of
+ {_,_} = Seqs ->
+ erase(Seqs);
+ T ->
+ T
+ end.
%% capx/1
@@ -560,7 +597,8 @@ send_CER(#state{state = {'Wait-Conn-Ack', Tmo},
mode = {connect, Remote},
service = #diameter_service{capabilities = LCaps},
transport = TPid,
- dictionary = Dict}
+ dictionary = Dict,
+ codec = Opts}
= S) ->
OH = LCaps#diameter_caps.origin_host,
req_send_CER(OH, Remote)
@@ -570,7 +608,7 @@ send_CER(#state{state = {'Wait-Conn-Ack', Tmo},
#diameter_packet{header = #diameter_header{end_to_end_id = Eid,
hop_by_hop_id = Hid}}
= Pkt
- = encode(CER, Dict),
+ = encode(CER, Opts, Dict),
incr(send, Pkt, Dict),
send(TPid, Pkt),
?LOG(send, 'CER'),
@@ -599,41 +637,36 @@ build_CER(#state{service = #diameter_service{capabilities = LCaps},
{ok, CER} = diameter_capx:build_CER(LCaps, Dict),
CER.
-%% encode/2
+%% encode/3
-encode(Rec, Dict) ->
+encode(Rec, Opts, Dict) ->
Seq = diameter_session:sequence({_,_} = getr(?SEQUENCE_KEY)),
Hdr = #diameter_header{version = ?DIAMETER_VERSION,
end_to_end_id = Seq,
hop_by_hop_id = Seq},
- diameter_codec:encode(Dict, #diameter_packet{header = Hdr,
- msg = Rec}).
+ diameter_codec:encode(Dict, Opts, #diameter_packet{header = Hdr,
+ msg = Rec}).
%% incoming/2
-incoming({Msg, NPid}, S) ->
- try recv(Msg, S) of
- T ->
- NPid ! {diameter, discard},
- T
- catch
- {?MODULE, Name, Pkt} ->
- incoming(Name, Pkt, NPid, S)
- end;
+incoming({recv = T, Name, Pkt}, #state{parent = Pid, ack = Ack} = S) ->
+ Pid ! {T, self(), get_route(Ack, Pkt), Name, Pkt},
+ rcv(Name, Pkt, S);
-incoming(Msg, S) ->
- try
- recv(Msg, S)
- catch
- {?MODULE, Name, Pkt} ->
- incoming(Name, Pkt, false, S)
- end.
+incoming(#diameter_header{is_request = R}, #state{transport = TPid,
+ ack = Ack}) ->
+ R andalso Ack andalso send(TPid, false),
+ ok;
-%% incoming/4
+incoming(<<_:32, 1:1, _/bits>>, #state{ack = true} = S) ->
+ send(S#state.transport, false),
+ ok;
-incoming(Name, Pkt, NPid, #state{parent = Pid} = S) ->
- Pid ! {recv, self(), get_route(Pkt), Name, Pkt, NPid},
- rcv(Name, Pkt, S).
+incoming(<<_/bits>>, _) ->
+ ok;
+
+incoming(T, _) ->
+ T.
%% recv/2
@@ -658,18 +691,19 @@ recv1(_,
#diameter_packet{header = H, bin = Bin},
#state{incoming_maxlen = M})
when M < size(Bin) ->
- invalid(false, incoming_maxlen_exceeded, {size(Bin), H});
+ invalid(false, incoming_maxlen_exceeded, {size(Bin), H}),
+ H;
%% Ignore anything but an expected CER/CEA if so configured. This is
%% non-standard behaviour.
-recv1(Name, _, #state{state = {'Wait-CEA', _, _},
- strict = false})
+recv1(Name, #diameter_packet{header = H}, #state{state = {'Wait-CEA', _, _},
+ strict = false})
when Name /= 'CEA' ->
- ok;
-recv1(Name, _, #state{state = recv_CER,
- strict = false})
+ H;
+recv1(Name, #diameter_packet{header = H}, #state{state = recv_CER,
+ strict = false})
when Name /= 'CER' ->
- ok;
+ H;
%% Incoming request after outgoing DPR: discard. Don't discard DPR, so
%% both ends don't do so when sending simultaneously.
@@ -677,13 +711,15 @@ recv1(Name,
#diameter_packet{header = #diameter_header{is_request = true} = H},
#state{dpr = {_,_,_}})
when Name /= 'DPR' ->
- invalid(false, recv_after_outgoing_dpr, H);
+ invalid(false, recv_after_outgoing_dpr, H),
+ H;
%% Incoming request after incoming DPR: discard.
recv1(_,
#diameter_packet{header = #diameter_header{is_request = true} = H},
#state{dpr = true}) ->
- invalid(false, recv_after_incoming_dpr, H);
+ invalid(false, recv_after_incoming_dpr, H),
+ H;
%% DPA with identifier mismatch, or in response to a DPR initiated by
%% the service.
@@ -701,7 +737,7 @@ recv1('DPA' = N,
%% Any other message with a header and no length errors: send to the
%% parent.
recv1(Name, Pkt, #state{}) ->
- throw({?MODULE, Name, Pkt}).
+ {recv, Name, Pkt}.
%% recv/3
@@ -720,10 +756,12 @@ recv(#diameter_header{}
#diameter_packet{bin = Bin},
#state{length_errors = E}) ->
T = {size(Bin), bit_size(Bin) rem 8, H},
- invalid(E, message_length_mismatch, T);
+ invalid(E, message_length_mismatch, T),
+ Bin;
recv(false, #diameter_packet{bin = Bin}, #state{length_errors = E}) ->
- invalid(E, truncated_header, Bin).
+ invalid(E, truncated_header, Bin),
+ Bin.
%% Note that counters here only count discarded messages.
invalid(E, Reason, T) ->
@@ -767,26 +805,23 @@ rcv('DPA' = N,
= Pkt,
#state{dictionary = Dict0,
transport = TPid,
- dpr = {X, Hid, Eid}}) ->
+ dpr = {X, Hid, Eid},
+ codec = Opts}) ->
?LOG(recv, N),
X orelse begin
%% Only count DPA in response to a DPR sent by the
%% service: explicit DPR is counted in the same way
%% as other explicitly sent requests.
incr(recv, H, Dict0),
- incr_rc(recv, diameter_codec:decode(Dict0, Pkt), Dict0)
+ incr_rc(recv, diameter_codec:decode(Dict0, Opts, Pkt), Dict0)
end,
diameter_peer:close(TPid),
{stop, N};
-%% Ignore anything else, an unsolicited DPA in particular. Note that
-%% dpa_timeout deals with the case in which the peer sends the wrong
-%% identifiers in DPA.
-rcv(N, #diameter_packet{header = H}, _)
- when N == 'CER';
- N == 'CEA';
- N == 'DPR';
- N == 'DPA' ->
+%% Ignore an unsolicited DPA in particular. Note that dpa_timeout
+%% deals with the case in which the peer sends the wrong identifiers
+%% in DPA.
+rcv('DPA' = N, #diameter_packet{header = H}, _) ->
?LOG(ignored, N),
%% Note that these aren't counted in the normal recv counter.
diameter_stats:incr({diameter_codec:msg_id(H), recv, ignored}),
@@ -839,7 +874,7 @@ outgoing(#diameter_packet{header = #diameter_header{application_id = 0,
invalid(false, dpr_after_dpr, H) %% DPR sent: discard
end;
-%% Explict CER or DWR: discard. These are sent by us.
+%% Explicit CER or DWR: discard. These are sent by us.
outgoing(#diameter_packet{header = #diameter_header{application_id = 0,
cmd_code = C,
is_request = true}
@@ -875,15 +910,21 @@ header(Bin) -> %% DWR
%% Incoming CER or DPR.
handle_request(Name,
- #diameter_packet{header = H} = Pkt,
- #state{dictionary = Dict0} = S) ->
+ #diameter_packet{header = H}
+ = Pkt,
+ #state{dictionary = Dict0,
+ codec = Opts}
+ = S) ->
?LOG(recv, Name),
incr(recv, H, Dict0),
- send_answer(Name, diameter_codec:decode(Dict0, Pkt), S).
+ send_answer(Name, diameter_codec:decode(Dict0, Opts, Pkt), S).
%% send_answer/3
-send_answer(Type, ReqPkt, #state{transport = TPid, dictionary = Dict} = S) ->
+send_answer(Type, ReqPkt, #state{transport = TPid,
+ dictionary = Dict,
+ codec = Opts}
+ = S) ->
incr_error(recv, ReqPkt, Dict),
#diameter_packet{header = H,
@@ -902,7 +943,7 @@ send_answer(Type, ReqPkt, #state{transport = TPid, dictionary = Dict} = S) ->
msg = Msg,
transport_data = TD},
- AnsPkt = diameter_codec:encode(Dict, Pkt),
+ AnsPkt = diameter_codec:encode(Dict, Opts, Pkt),
incr(send, AnsPkt, Dict),
incr_rc(send, AnsPkt, Dict),
@@ -929,8 +970,6 @@ build_answer('CER',
= Pkt,
#state{dictionary = Dict0}
= S) ->
- diameter_codec:setopts([{string_decode, false}]),
-
{SupportedApps, RCaps, CEA} = recv_CER(CER, S),
[RC, IS] = Dict0:'#get-'(['Result-Code', 'Inband-Security-Id'], CEA),
@@ -1131,18 +1170,16 @@ recv_CER(CER, #state{service = Svc, dictionary = Dict}) ->
handle_CEA(#diameter_packet{header = H}
= Pkt,
#state{dictionary = Dict0,
- service = #diameter_service{capabilities = LCaps}}
+ service = #diameter_service{capabilities = LCaps},
+ codec = Opts}
= S) ->
incr(recv, H, Dict0),
#diameter_packet{}
= DPkt
- = diameter_codec:decode(Dict0, Pkt),
-
- diameter_codec:setopts([{string_decode, false}]),
+ = diameter_codec:decode(Dict0, Opts, Pkt),
RC = result_code(incr_rc(recv, DPkt, Dict0)),
-
{SApps, IS, RCaps} = recv_CEA(DPkt, S),
#diameter_caps{origin_host = {OH, DH}}
@@ -1330,8 +1367,9 @@ dpr([], [Reason | _], S) ->
-record(opts, {cause, timeout}).
-send_dpr(Reason, Opts, #state{dictionary = Dict,
- service = #diameter_service{capabilities = Caps}}
+send_dpr(Reason, DprOpts, #state{dictionary = Dict,
+ service = #diameter_service{capabilities = Caps},
+ codec = Opts}
= S) ->
#opts{cause = Cause, timeout = Tmo}
= lists:foldl(fun opt/2,
@@ -1340,7 +1378,7 @@ send_dpr(Reason, Opts, #state{dictionary = Dict,
_ -> ?REBOOT
end,
timeout = dpa_timeout()},
- Opts),
+ DprOpts),
#diameter_caps{origin_host = {OH, _},
origin_realm = {OR, _}}
= Caps,
@@ -1348,6 +1386,7 @@ send_dpr(Reason, Opts, #state{dictionary = Dict,
Pkt = encode(['DPR', {'Origin-Host', OH},
{'Origin-Realm', OR},
{'Disconnect-Cause', Cause}],
+ Opts,
Dict),
send_dpr(false, Pkt, Tmo, S).