aboutsummaryrefslogtreecommitdiffstats
path: root/lib/diameter
diff options
context:
space:
mode:
Diffstat (limited to 'lib/diameter')
-rw-r--r--lib/diameter/src/base/diameter_capx.erl4
-rw-r--r--lib/diameter/src/base/diameter_peer_fsm.erl156
2 files changed, 110 insertions, 50 deletions
diff --git a/lib/diameter/src/base/diameter_capx.erl b/lib/diameter/src/base/diameter_capx.erl
index 138e76411e..293eb0c196 100644
--- a/lib/diameter/src/base/diameter_capx.erl
+++ b/lib/diameter/src/base/diameter_capx.erl
@@ -224,10 +224,6 @@ rCER(CER, #diameter_service{capabilities = LCaps} = Svc) ->
RCaps,
CEA#diameter_base_CEA{'Result-Code' = ?SUCCESS})}.
-%% TODO: 5.3 of RFC 3588 says we MUST return DIAMETER_NO_COMMON_APPLICATION
-%% in the CEA and SHOULD disconnect the transport. However, we have
-%% no way to guarantee the send before disconnecting.
-
build_CEA([], _, _, CEA) ->
CEA#diameter_base_CEA{'Result-Code' = ?NOAPP};
diff --git a/lib/diameter/src/base/diameter_peer_fsm.erl b/lib/diameter/src/base/diameter_peer_fsm.erl
index bc7625025a..3f1610b325 100644
--- a/lib/diameter/src/base/diameter_peer_fsm.erl
+++ b/lib/diameter/src/base/diameter_peer_fsm.erl
@@ -142,9 +142,12 @@ init(T) ->
proc_lib:init_ack({ok, self()}),
gen_server:enter_loop(?MODULE, [], i(T)).
-i({WPid, {M, _} = T, Opts, #diameter_service{capabilities = Caps} = Svc0}) ->
+i({WPid, T, Opts, #diameter_service{capabilities = Caps} = Svc0}) ->
putr(dwa, dwa(Caps)),
- {ok, TPid, Svc} = start_transport(T, Opts, Svc0),
+ {M, Ref} = T,
+ {[Ts], Rest} = proplists:split(Opts, [capabilities_cb]),
+ putr(capabilities_cb, {Ref, [F || {_,F} <- Ts]}),
+ {ok, TPid, Svc} = start_transport(T, Rest, Svc0),
erlang:monitor(process, TPid),
erlang:monitor(process, WPid),
#state{parent = WPid,
@@ -198,10 +201,13 @@ handle_info(T, #state{} = State) ->
?LOG(stop, T),
x(T, State)
catch
- throw: {?MODULE, Tag, Reason} ->
+ {?MODULE, Tag, Reason} ->
?LOG(Tag, {Reason, T}),
{stop, {shutdown, Reason}, State}
end.
+%% The form of the exception caught here is historical. It's
+%% significant that it's not a 2-tuple, as in ?FAILURE(Reason),
+%% since these are caught elsewhere.
x(Reason, #state{} = S) ->
close_wd(Reason, S),
@@ -226,6 +232,9 @@ putr(Key, Val) ->
getr(Key) ->
get({?MODULE, Key}).
+eraser(Key) ->
+ erase({?MODULE, Key}).
+
%% transition/2
%% Connection to peer.
@@ -467,13 +476,13 @@ send_answer(Type, ReqPkt, #state{transport = TPid} = S) ->
transport_data = TD}
= ReqPkt,
- {Answer, PostF} = build_answer(Type, V, ReqPkt, S),
+ {Msg, PostF} = build_answer(Type, V, ReqPkt, S),
Pkt = #diameter_packet{header = #diameter_header{version = V,
end_to_end_id = Eid,
hop_by_hop_id = Hid,
is_proxiable = P},
- msg = Answer,
+ msg = Msg,
transport_data = TD},
send(TPid, diameter_codec:encode(?BASE, Pkt)),
@@ -494,27 +503,31 @@ build_answer('CER',
= Pkt,
#state{service = Svc}
= S) ->
- #diameter_service{capabilities = #diameter_caps{origin_host = OH}}
+ {SupportedApps, RCaps, #diameter_base_CEA{'Result-Code' = RC,
+ 'Inband-Security-Id' = [IS]}
+ = CEA}
+ = recv_CER(CER, S),
+
+ #diameter_service{capabilities = LCaps}
= Svc,
- {SupportedApps,
- #diameter_caps{origin_host = DH} = RCaps,
- #diameter_base_CEA{'Result-Code' = RC}
- = CEA}
- = recv_CER(CER, S),
+ #diameter_caps{origin_host = {OH, DH}}
+ = Caps
+ = capz(LCaps, RCaps),
try
2001 == RC %% DIAMETER_SUCCESS
- orelse ?THROW({sent_CEA, RC}),
+ orelse ?THROW({result_code, RC}),
register_everywhere({?MODULE, connection, OH, DH})
- orelse ?THROW({election_lost, 4003}),
- #diameter_base_CEA{'Inband-Security-Id' = [IS]}
- = CEA,
- {CEA, [fun open/5, Pkt, SupportedApps, RCaps, {accept, IS}]}
+ orelse ?THROW({result_code, 4003}), %% DIAMETER_ELECTION_LOST
+ caps_cb(Caps)
+ of
+ ok -> {CEA, [fun open/5, Pkt, SupportedApps, Caps, {accept, IS}]}
catch
- ?FAILURE({Reason, RC}) ->
- {answer('CER', S) ++ [{'Result-Code', RC}],
- [fun close/2, {'CER', Reason, DH}]}
+ ?FAILURE(discard = T) ->
+ close({'CER', T, DH}, S);
+ ?FAILURE({result_code, N}) ->
+ {answer_message(cea(S), N), [fun close/2, {'CER', N, DH}]}
end;
%% The error checks below are similar to those in diameter_service for
@@ -522,18 +535,40 @@ build_answer('CER',
build_answer(Type, V, #diameter_packet{header = H, errors = Es} = Pkt, S) ->
FailedAvp = failed_avp([A || {_,A} <- Es]),
- Ans = answer(answer(Type, S), V, H, Es),
- {set(Ans, FailedAvp), if 'CER' == Type ->
+ Msg = answer_message(answer(Type, S), rc(V, H, Es)),
+ {set(Msg, FailedAvp), if 'CER' == Type ->
[fun close/2, {Type, V, Pkt}];
true ->
ok
end}.
+cea(S) ->
+ answer('CER', S).
+
+%% answer_message/2
+
+answer_message([_ | Avps], RC)
+ when 3000 =< RC, RC < 4000 ->
+ ['answer-message', {'Result-Code', RC}
+ | lists:filter(fun is_origin/1, Avps)];
+
+answer_message(Msg, RC) ->
+ Msg ++ [{'Result-Code', RC}].
+
+is_origin({N, _}) ->
+ N == 'Origin-Host'
+ orelse N == 'Origin-Realm'
+ orelse N == 'Origin-State-Id'.
+
+%% failed_avp/1
+
failed_avp([] = No) ->
No;
failed_avp(Avps) ->
[{'Failed-AVP', [[{'AVP', Avps}]]}].
+%% set/2
+
set(Ans, []) ->
Ans;
set(['answer-message' | _] = Ans, FailedAvp) ->
@@ -541,18 +576,22 @@ set(['answer-message' | _] = Ans, FailedAvp) ->
set([_|_] = Ans, FailedAvp) ->
Ans ++ FailedAvp.
-answer([_, OH, OR | _], _, #diameter_header{is_error = true}, _) ->
- ['answer-message', OH, OR, {'Result-Code', 3008}];
+%% rc/3
-answer([_, OH, OR | _], _, _, [Bs|_])
+rc(_, #diameter_header{is_error = true}, _) ->
+ 3008; %% DIAMETER_INVALID_HDR_BITS
+
+rc(_, _, [Bs|_])
when is_bitstring(Bs) ->
- ['answer-message', OH, OR, {'Result-Code', 3009}];
+ 3009; %% DIAMETER_INVALID_HDR_BITS
+
+rc(?DIAMETER_VERSION, _, Es) ->
+ rc(Es);
-answer(Ans, ?DIAMETER_VERSION, _, Es) ->
- Ans ++ [{'Result-Code', rc(Es)}];
+rc(_, _, _) ->
+ 5011. %% DIAMETER_UNSUPPORTED_VERSION
-answer(Ans, _, _, _) ->
- Ans ++ [{'Result-Code', 5011}]. %% DIAMETER_UNSUPPORTED_VERSION
+%% rc/1
rc([]) ->
2001; %% DIAMETER_SUCCESS
@@ -595,12 +634,14 @@ a('CER', #diameter_caps{vendor_id = Vid,
origin_host = Host,
origin_realm = Realm,
host_ip_address = Addrs,
- product_name = Name}) ->
+ product_name = Name,
+ origin_state_id = OSI}) ->
['CEA', {'Origin-Host', Host},
{'Origin-Realm', Realm},
{'Host-IP-Address', Addrs},
{'Vendor-Id', Vid},
- {'Product-Name', Name}];
+ {'Product-Name', Name},
+ {'Origin-State-Id', OSI}];
a('DPR', #diameter_caps{origin_host = Host,
origin_realm = Realm}) ->
@@ -631,11 +672,11 @@ handle_CEA(#diameter_packet{header = #diameter_header{version = V},
[] == Errors orelse close({errors, Errors}, S),
- {SApps, [IS], #diameter_caps{origin_host = DH} = RCaps}
- = recv_CEA(CEA, S),
+ {SApps, [IS], RCaps} = recv_CEA(CEA, S),
- #diameter_caps{origin_host = OH}
- = LCaps,
+ #diameter_caps{origin_host = {OH, DH}}
+ = Caps
+ = capz(LCaps, RCaps),
%% Ensure that we don't already have a connection to the peer in
%% question. This isn't the peer election of 3588 except in the
@@ -646,7 +687,11 @@ handle_CEA(#diameter_packet{header = #diameter_header{version = V},
register_everywhere({?MODULE, connection, OH, DH})
orelse close({'CEA', DH}, S),
- open(DPkt, SApps, RCaps, {connect, IS}, S).
+ try caps_cb(Caps) of
+ ok -> open(DPkt, SApps, Caps, {connect, IS}, S)
+ catch
+ ?FAILURE(Reason) -> close(Reason, S)
+ end.
%% recv_CEA/2
@@ -664,20 +709,39 @@ recv_CEA(CEA, #state{service = Svc} = S) ->
close({'CEA', Reason}, S)
end.
+%% caps_cb/1
+
+caps_cb(Caps) ->
+ {Ref, Ts} = eraser(capabilities_cb),
+ ccb(Ts, [Ref, Caps]).
+
+ccb([], _) ->
+ ok;
+ccb([F | Rest], T) ->
+ case diameter_lib:eval([F|T]) of
+ ok ->
+ ccb(Rest, T);
+ Res ->
+ ?THROW({{capabilities_cb, F}, rejected(Res)})
+ end.
+
+rejected({result_code, N} = T)
+ when 1000 =< N, N < 6000 ->
+ T;
+rejected(discard = T) ->
+ T;
+rejected(unknown) ->
+ {result_code, 3010}. %% DIAMETER_UNKNOWN_PEER
+
%% open/5
-open(Pkt, SupportedApps, RCaps, {Type, IS}, #state{parent = Pid,
- service = Svc}
- = S) ->
- #diameter_service{capabilities = #diameter_caps{origin_host = OH,
- inband_security_id = LS}
- = LCaps}
- = Svc,
- #diameter_caps{origin_host = DH}
- = RCaps,
+open(Pkt, SupportedApps, Caps, {Type, IS}, #state{parent = Pid} = S) ->
+ #diameter_caps{origin_host = {_,_} = H,
+ inband_security_id = {LS,_}}
+ = Caps,
tls_ack(lists:member(?TLS, LS), Type, IS, S),
- Pid ! {open, self(), {OH,DH}, {capz(LCaps, RCaps), SupportedApps, Pkt}},
+ Pid ! {open, self(), H, {Caps, SupportedApps, Pkt}},
S#state{state = 'Open'}.