diff options
Diffstat (limited to 'lib')
60 files changed, 4094 insertions, 3709 deletions
diff --git a/lib/compiler/src/sys_core_fold.erl b/lib/compiler/src/sys_core_fold.erl index cbf6e256f7..7acf08129a 100644 --- a/lib/compiler/src/sys_core_fold.erl +++ b/lib/compiler/src/sys_core_fold.erl @@ -1458,8 +1458,19 @@ sub_add_scope(Vs, #sub{s=Scope0}=Sub) -> Sub#sub{s=Scope}. sub_subst_scope(#sub{v=S0,s=Scope}=Sub) -> - S = [{-1,#c_var{name=Sv}} || Sv <- cerl_sets:to_list(Scope)]++S0, - Sub#sub{v=S}. + Initial = case S0 of + [{NegInt,_}|_] when is_integer(NegInt), NegInt < 0 -> + NegInt - 1; + _ -> + -1 + end, + S = sub_subst_scope_1(cerl_sets:to_list(Scope), Initial, S0), + Sub#sub{v=orddict:from_list(S)}. + +%% The keys in an orddict must be unique. Make them so! +sub_subst_scope_1([H|T], Key, Acc) -> + sub_subst_scope_1(T, Key-1, [{Key,#c_var{name=H}}|Acc]); +sub_subst_scope_1([], _, Acc) -> Acc. sub_is_val(#c_var{name=V}, #sub{v=S,s=Scope}) -> %% When the bottleneck in sub_del_var/2 was eliminated, this diff --git a/lib/compiler/src/v3_kernel.erl b/lib/compiler/src/v3_kernel.erl index 004c609311..1fc05109c5 100644 --- a/lib/compiler/src/v3_kernel.erl +++ b/lib/compiler/src/v3_kernel.erl @@ -1313,23 +1313,26 @@ get_vsub(V, Vsub) -> set_vsub(V, S, Vsub) -> orddict:store(V, S, Vsub). -subst_vsub(Key, New, [{K,Key}|Dict]) -> +subst_vsub(Key, New, Vsub) -> + orddict:from_list(subst_vsub_1(Key, New, Vsub)). + +subst_vsub_1(Key, New, [{K,Key}|Dict]) -> %% Fold chained substitution. - [{K,New}|subst_vsub(Key, New, Dict)]; -subst_vsub(Key, New, [{K,_}|_]=Dict) when Key < K -> + [{K,New}|subst_vsub_1(Key, New, Dict)]; +subst_vsub_1(Key, New, [{K,_}|_]=Dict) when Key < K -> %% Insert the new substitution here, and continue %% look for chained substitutions. - [{Key,New}|subst_vsub_1(Key, New, Dict)]; -subst_vsub(Key, New, [{K,_}=E|Dict]) when Key > K -> - [E|subst_vsub(Key, New, Dict)]; -subst_vsub(Key, New, []) -> [{Key,New}]. + [{Key,New}|subst_vsub_2(Key, New, Dict)]; +subst_vsub_1(Key, New, [{K,_}=E|Dict]) when Key > K -> + [E|subst_vsub_1(Key, New, Dict)]; +subst_vsub_1(Key, New, []) -> [{Key,New}]. -subst_vsub_1(V, S, [{K,V}|Dict]) -> +subst_vsub_2(V, S, [{K,V}|Dict]) -> %% Fold chained substitution. - [{K,S}|subst_vsub_1(V, S, Dict)]; -subst_vsub_1(V, S, [E|Dict]) -> - [E|subst_vsub_1(V, S, Dict)]; -subst_vsub_1(_, _, []) -> []. + [{K,S}|subst_vsub_2(V, S, Dict)]; +subst_vsub_2(V, S, [E|Dict]) -> + [E|subst_vsub_2(V, S, Dict)]; +subst_vsub_2(_, _, []) -> []. get_fsub(F, A, Fsub) -> case orddict:find({F,A}, Fsub) of diff --git a/lib/compiler/test/beam_type_SUITE.erl b/lib/compiler/test/beam_type_SUITE.erl index 07dad85c57..86146c614f 100644 --- a/lib/compiler/test/beam_type_SUITE.erl +++ b/lib/compiler/test/beam_type_SUITE.erl @@ -22,7 +22,7 @@ -export([all/0,suite/0,groups/0,init_per_suite/1,end_per_suite/1, init_per_group/2,end_per_group/2, integers/1,coverage/1,booleans/1,setelement/1,cons/1, - tuple/1,record_float/1,binary_float/1]). + tuple/1,record_float/1,binary_float/1,float_compare/1]). suite() -> [{ct_hooks,[ts_install_cth]}]. @@ -39,7 +39,8 @@ groups() -> cons, tuple, record_float, - binary_float + binary_float, + float_compare ]}]. init_per_suite(Config) -> @@ -151,5 +152,25 @@ binary_float(_Config) -> binary_negate_float(<<Float/float>>) -> <<-Float/float>>. +float_compare(_Config) -> + false = do_float_compare(-42.0), + false = do_float_compare(-42), + false = do_float_compare(0), + false = do_float_compare(0.0), + true = do_float_compare(42), + true = do_float_compare(42.0), + ok. + +do_float_compare(X) -> + %% ERL-433: Used to fail before OTP 20. Was accidentally fixed + %% in OTP 20. Add a test case to ensure it stays fixed. + + Y = X + 1.0, + case X > 0 of + T when (T =:= nil) or (T =:= false) -> T; + _T -> Y > 0 + end. + + id(I) -> I. diff --git a/lib/dialyzer/src/dialyzer_analysis_callgraph.erl b/lib/dialyzer/src/dialyzer_analysis_callgraph.erl index 2a2dcd55f0..a4b42c9367 100644 --- a/lib/dialyzer/src/dialyzer_analysis_callgraph.erl +++ b/lib/dialyzer/src/dialyzer_analysis_callgraph.erl @@ -92,11 +92,12 @@ loop(#server_state{parent = Parent} = State, send_warnings(Parent, Warnings), loop(State, Analysis, ExtCalls); {AnalPid, cserver, CServer, Plt} -> + skip_ets_transfer(AnalPid), send_codeserver_plt(Parent, CServer, Plt), loop(State, Analysis, ExtCalls); - {AnalPid, done, MiniPlt, DocPlt} -> + {AnalPid, done, Plt, DocPlt} -> send_ext_calls(Parent, ExtCalls), - send_analysis_done(Parent, MiniPlt, DocPlt); + send_analysis_done(Parent, Plt, DocPlt); {AnalPid, ext_calls, NewExtCalls} -> loop(State, Analysis, NewExtCalls); {AnalPid, ext_types, ExtTypes} -> @@ -133,26 +134,8 @@ analysis_start(Parent, Analysis, LegalWarnings) -> Files = ordsets:from_list(Analysis#analysis.files), {Callgraph, TmpCServer0} = compile_and_store(Files, State), %% Remote type postprocessing - NewCServer = - try - TmpCServer1 = dialyzer_utils:merge_types(TmpCServer0, Plt), - NewExpTypes = dialyzer_codeserver:get_temp_exported_types(TmpCServer0), - OldExpTypes0 = dialyzer_plt:get_exported_types(Plt), - RemMods = - [case Analysis#analysis.start_from of - byte_code -> list_to_atom(filename:basename(F, ".beam")); - src_code -> list_to_atom(filename:basename(F, ".erl")) - end || F <- Files], - OldExpTypes1 = dialyzer_utils:sets_filter(RemMods, OldExpTypes0), - MergedExpTypes = sets:union(NewExpTypes, OldExpTypes1), - TmpCServer2 = - dialyzer_codeserver:finalize_exported_types(MergedExpTypes, TmpCServer1), - erlang:garbage_collect(), % reduce heap size - ?timing(State#analysis_state.timing_server, "remote", - contracts_and_records(TmpCServer2, Parent)) - catch - throw:{error, _ErrorMsg} = Error -> exit(Error) - end, + Args = {Plt, Analysis, Parent}, + NewCServer = remote_type_postprocessing(TmpCServer0, Args), dump_callgraph(Callgraph, State, Analysis), %% Remove all old versions of the files being analyzed AllNodes = dialyzer_callgraph:all_nodes(Callgraph), @@ -168,46 +151,80 @@ analysis_start(Parent, Analysis, LegalWarnings) -> false -> Callgraph end, State2 = analyze_callgraph(NewCallgraph, State1), - #analysis_state{plt = MiniPlt2, + #analysis_state{plt = Plt2, doc_plt = DocPlt, codeserver = Codeserver0} = State2, - {Codeserver, MiniPlt3} = move_data(Codeserver0, MiniPlt2), + {Codeserver, Plt3} = move_data(Codeserver0, Plt2), dialyzer_callgraph:dispose_race_server(NewCallgraph), %% Since the PLT is never used, a dummy is sent: DummyPlt = dialyzer_plt:new(), send_codeserver_plt(Parent, Codeserver, DummyPlt), - MiniPlt4 = dialyzer_plt:delete_list(MiniPlt3, NonExportsList), - send_analysis_done(Parent, MiniPlt4, DocPlt). - -contracts_and_records(CodeServer, Parent) -> - Fun = contrs_and_recs(CodeServer, Parent), + dialyzer_plt:delete(DummyPlt), + Plt4 = dialyzer_plt:delete_list(Plt3, NonExportsList), + send_analysis_done(Parent, Plt4, DocPlt). + +remote_type_postprocessing(TmpCServer, Args) -> + Fun = fun() -> + exit(remote_type_postproc(TmpCServer, Args)) + end, {Pid, Ref} = erlang:spawn_monitor(Fun), - dialyzer_codeserver:give_away(CodeServer, Pid), + dialyzer_codeserver:give_away(TmpCServer, Pid), Pid ! {self(), go}, receive {'DOWN', Ref, process, Pid, Return} -> - Return + skip_ets_transfer(Pid), + case Return of + {error, _ErrorMsg} = Error -> exit(Error); + _ -> Return + end end. --spec contrs_and_recs(dialyzer_codeserver:codeserver(), pid()) -> - fun(() -> no_return()). - -contrs_and_recs(TmpCServer2, Parent) -> +remote_type_postproc(TmpCServer0, Args) -> + {Plt, Analysis, Parent} = Args, fun() -> Caller = receive {Pid, go} -> Pid end, - TmpCServer3 = dialyzer_utils:process_record_remote_types(TmpCServer2), + TmpCServer1 = dialyzer_utils:merge_types(TmpCServer0, Plt), + NewExpTypes = dialyzer_codeserver:get_temp_exported_types(TmpCServer0), + OldExpTypes0 = dialyzer_plt:get_exported_types(Plt), + #analysis{start_from = StartFrom, + timing_server = TimingServer} = Analysis, + Files = ordsets:from_list(Analysis#analysis.files), + RemMods = + [case StartFrom of + byte_code -> list_to_atom(filename:basename(F, ".beam")); + src_code -> list_to_atom(filename:basename(F, ".erl")) + end || F <- Files], + OldExpTypes1 = dialyzer_utils:sets_filter(RemMods, OldExpTypes0), + MergedExpTypes = sets:union(NewExpTypes, OldExpTypes1), + TmpCServer2 = + dialyzer_codeserver:finalize_exported_types(MergedExpTypes, + TmpCServer1), TmpServer4 = - dialyzer_contracts:process_contract_remote_types(TmpCServer3), - dialyzer_codeserver:give_away(TmpServer4, Caller), + ?timing + (TimingServer, "remote", + begin + TmpCServer3 = + dialyzer_utils:process_record_remote_types(TmpCServer2), + dialyzer_contracts:process_contract_remote_types(TmpCServer3) + end), rcv_and_send_ext_types(Caller, Parent), - exit(TmpServer4) + dialyzer_codeserver:give_away(TmpServer4, Caller), + TmpServer4 + end(). + +skip_ets_transfer(Pid) -> + receive + {'ETS-TRANSFER', _Tid, Pid, _HeriData} -> + skip_ets_transfer(Pid) + after 0 -> + ok end. -move_data(CServer, MiniPlt) -> +move_data(CServer, Plt) -> {CServer1, Records} = dialyzer_codeserver:extract_records(CServer), - MiniPlt1 = dialyzer_plt:insert_types(MiniPlt, Records), + Plt1 = dialyzer_plt:insert_types(Plt, Records), {NewCServer, ExpTypes} = dialyzer_codeserver:extract_exported_types(CServer1), - NewMiniPlt = dialyzer_plt:insert_exported_types(MiniPlt1, ExpTypes), - {NewCServer, NewMiniPlt}. + NewPlt = dialyzer_plt:insert_exported_types(Plt1, ExpTypes), + {NewCServer, NewPlt}. analyze_callgraph(Callgraph, #analysis_state{codeserver = Codeserver, doc_plt = DocPlt, @@ -217,19 +234,19 @@ analyze_callgraph(Callgraph, #analysis_state{codeserver = Codeserver, solvers = Solvers} = State) -> case State#analysis_state.analysis_type of plt_build -> - NewMiniPlt = + NewPlt = dialyzer_succ_typings:analyze_callgraph(Callgraph, Plt, Codeserver, TimingServer, Solvers, Parent), dialyzer_callgraph:delete(Callgraph), - State#analysis_state{plt = NewMiniPlt, doc_plt = DocPlt}; + State#analysis_state{plt = NewPlt, doc_plt = DocPlt}; succ_typings -> - {Warnings, NewMiniPlt, NewDocPlt} = + {Warnings, NewPlt, NewDocPlt} = dialyzer_succ_typings:get_warnings(Callgraph, Plt, DocPlt, Codeserver, TimingServer, Solvers, Parent), dialyzer_callgraph:delete(Callgraph), Warnings1 = filter_warnings(Warnings, Codeserver), send_warnings(State#analysis_state.parent, Warnings1), - State#analysis_state{plt = NewMiniPlt, doc_plt = NewDocPlt} + State#analysis_state{plt = NewPlt, doc_plt = NewDocPlt} end. %%-------------------------------------------------------------------- @@ -565,9 +582,8 @@ is_ok_fun({_Filename, _Line, {_M, _F, _A} = MFA}, Codeserver) -> is_ok_tag(Tag, {_F, _L, MorMFA}, Codeserver) -> not dialyzer_utils:is_suppressed_tag(MorMFA, Tag, Codeserver). -send_analysis_done(Parent, MiniPlt, DocPlt) -> - ok = dialyzer_plt:give_away(MiniPlt, Parent), - Parent ! {self(), done, MiniPlt, DocPlt}, +send_analysis_done(Parent, Plt, DocPlt) -> + Parent ! {self(), done, Plt, DocPlt}, ok. send_ext_calls(_Parent, none) -> diff --git a/lib/dialyzer/src/dialyzer_callgraph.erl b/lib/dialyzer/src/dialyzer_callgraph.erl index a83a0bda59..7411b1d28b 100644 --- a/lib/dialyzer/src/dialyzer_callgraph.erl +++ b/lib/dialyzer/src/dialyzer_callgraph.erl @@ -778,7 +778,6 @@ to_ps(#callgraph{} = CG, File, Args) -> ok. condensation(G) -> - erlang:garbage_collect(), % reduce heap size {Pid, Ref} = erlang:spawn_monitor(do_condensation(G, self())), receive {'DOWN', Ref, process, Pid, Result} -> {SCCInts, OutETS, InETS, MapsETS} = Result, diff --git a/lib/dialyzer/src/dialyzer_cl.erl b/lib/dialyzer/src/dialyzer_cl.erl index d72ae1dc86..0617be6435 100644 --- a/lib/dialyzer/src/dialyzer_cl.erl +++ b/lib/dialyzer/src/dialyzer_cl.erl @@ -637,8 +637,8 @@ cl_loop(State, LogCache) -> {BackendPid, cserver, CodeServer, _Plt} -> % Plt is ignored NewState = State#cl_state{code_server = CodeServer}, cl_loop(NewState, LogCache); - {BackendPid, done, NewMiniPlt, _NewDocPlt} -> - return_value(State, NewMiniPlt); + {BackendPid, done, NewPlt, _NewDocPlt} -> + return_value(State, NewPlt); {BackendPid, ext_calls, ExtCalls} -> cl_loop(State#cl_state{external_calls = ExtCalls}, LogCache); {BackendPid, ext_types, ExtTypes} -> @@ -700,7 +700,7 @@ return_value(State = #cl_state{code_server = CodeServer, output_plt = OutputPlt, plt_info = PltInfo, stored_warnings = StoredWarnings}, - MiniPlt) -> + Plt) -> %% Just for now: case CodeServer =:= none of true -> @@ -710,18 +710,9 @@ return_value(State = #cl_state{code_server = CodeServer, end, case OutputPlt =:= none of true -> - dialyzer_plt:delete(MiniPlt); + dialyzer_plt:delete(Plt); false -> - Fun = to_file_fun(OutputPlt, MiniPlt, ModDeps, PltInfo), - {Pid, Ref} = erlang:spawn_monitor(Fun), - dialyzer_plt:give_away(MiniPlt, Pid), - Pid ! go, - receive {'DOWN', Ref, process, Pid, Result} -> - case Result of - ok -> ok; - Thrown -> throw(Thrown) - end - end + dialyzer_plt:to_file(OutputPlt, Plt, ModDeps, PltInfo) end, UnknownWarnings = unknown_warnings(State), RetValue = @@ -742,16 +733,6 @@ return_value(State = #cl_state{code_server = CodeServer, {RetValue, set_warning_id(AllWarnings)} end. --spec to_file_fun(_, _, _, _) -> fun(() -> no_return()). - -to_file_fun(Filename, MiniPlt, ModDeps, PltInfo) -> - fun() -> - receive go -> ok end, - Plt = dialyzer_plt:restore_full_plt(MiniPlt), - dialyzer_plt:to_file(Filename, Plt, ModDeps, PltInfo), - exit(ok) - end. - unknown_warnings(State = #cl_state{legal_warnings = LegalWarnings}) -> Unknown = case ordsets:is_element(?WARN_UNKNOWN, LegalWarnings) of true -> diff --git a/lib/dialyzer/src/dialyzer_codeserver.erl b/lib/dialyzer/src/dialyzer_codeserver.erl index a1a7370eff..5587cf2bdf 100644 --- a/lib/dialyzer/src/dialyzer_codeserver.erl +++ b/lib/dialyzer/src/dialyzer_codeserver.erl @@ -304,9 +304,29 @@ lookup_temp_mod_records(Mod, #codeserver{temp_records = TempRecDict}) -> finalize_records(#codeserver{temp_records = TmpRecords, records = Records} = CS) -> - true = ets:delete(Records), - ets:rename(TmpRecords, dialyzer_codeserver_records), - CS#codeserver{temp_records = clean, records = TmpRecords}. + %% The annotations of the abstract code are reset as they are no + %% longer needed, which makes the ETS table compression better. + A0 = erl_anno:new(0), + AFun = fun(_) -> A0 end, + FFun = fun({F, Abs, Type}) -> + NewAbs = erl_parse:map_anno(AFun, Abs), + {F, NewAbs, Type} + end, + ArFun = fun({Arity, Fields}) -> {Arity, lists:map(FFun, Fields)} end, + List = dialyzer_utils:ets_tab2list(TmpRecords), + true = ets:delete(TmpRecords), + Fun = fun({Mod, Map}) -> + MFun = + fun({record, _}, {FileLine, ArityFields}) -> + {FileLine, lists:map(ArFun, ArityFields)}; + (_, {{M, FileLine, Abs, Args}, Type}) -> + {{M, FileLine, erl_parse:map_anno(AFun, Abs), Args}, Type} + end, + {Mod, maps:map(MFun, Map)} + end, + NewList = lists:map(Fun, List), + true = ets:insert(Records, NewList), + CS#codeserver{temp_records = clean}. -spec lookup_mod_contracts(atom(), codeserver()) -> contracts(). @@ -355,7 +375,7 @@ store_temp_contracts(Mod, SpecMap, CallbackMap, #codeserver{temp_contracts = Cn, temp_callbacks = Cb} = CS) when is_atom(Mod) -> - %% Make sure Mod is stored even if there are not callbacks or + %% Make sure Mod is stored even if there are no callbacks or %% contracts. CS1 = CS#codeserver{temp_contracts = ets_map_store(Mod, SpecMap, Cn)}, CS1#codeserver{temp_callbacks = ets_map_store(Mod, CallbackMap, Cb)}. diff --git a/lib/dialyzer/src/dialyzer_contracts.erl b/lib/dialyzer/src/dialyzer_contracts.erl index 300af7956d..b554ebc2cc 100644 --- a/lib/dialyzer/src/dialyzer_contracts.erl +++ b/lib/dialyzer/src/dialyzer_contracts.erl @@ -18,7 +18,7 @@ check_contracts/4, contracts_without_fun/3, contract_to_string/1, - get_invalid_contract_warnings/4, + get_invalid_contract_warnings/3, get_contract_args/1, get_contract_return/1, get_contract_return/2, @@ -173,22 +173,20 @@ process_contract_remote_types(CodeServer) -> lists:foreach(ModuleFun, Mods), dialyzer_codeserver:finalize_contracts(CodeServer). --type opaques_fun() :: fun((module()) -> [erl_types:erl_type()]). +-type fun_types() :: orddict:orddict(label(), erl_types:type_table()). --type fun_types() :: dict:dict(label(), erl_types:type_table()). - --spec check_contracts(orddict:orddict(mfa(), file_contract()), +-spec check_contracts(orddict:orddict(mfa(), #contract{}), dialyzer_callgraph:callgraph(), fun_types(), - opaques_fun()) -> plt_contracts(). + erl_types:opaques()) -> plt_contracts(). -check_contracts(Contracts, Callgraph, FunTypes, FindOpaques) -> +check_contracts(Contracts, Callgraph, FunTypes, ModOpaques) -> FoldFun = - fun(Label, Type, NewContracts) -> + fun({Label, Type}, NewContracts) -> case dialyzer_callgraph:lookup_name(Label, Callgraph) of {ok, {M,F,A} = MFA} -> case orddict:find(MFA, Contracts) of - {ok, {_FileLine, Contract, _Xtra}} -> - Opaques = FindOpaques(M), + {ok, Contract} -> + {M, Opaques} = lists:keyfind(M, 1, ModOpaques), case check_contract(Contract, Type, Opaques) of ok -> case erl_bif_types:is_known(M, F, A) of @@ -206,7 +204,7 @@ check_contracts(Contracts, Callgraph, FunTypes, FindOpaques) -> error -> NewContracts end end, - orddict:from_list(dict:fold(FoldFun, [], FunTypes)). + orddict:from_list(lists:foldl(FoldFun, [], orddict:to_list(FunTypes))). %% Checks all components of a contract -spec check_contract(#contract{}, erl_types:erl_type()) -> 'ok' | {'error', term()}. @@ -214,6 +212,9 @@ check_contracts(Contracts, Callgraph, FunTypes, FindOpaques) -> check_contract(Contract, SuccType) -> check_contract(Contract, SuccType, 'universe'). +-spec check_contract(#contract{}, erl_types:erl_type(), erl_types:opaques()) -> + 'ok' | {'error', term()}. + check_contract(#contract{contracts = Contracts}, SuccType, Opaques) -> try Contracts1 = [{Contract, insert_constraints(Constraints)} @@ -662,32 +663,37 @@ general_domain([], AccSig) -> -spec get_invalid_contract_warnings([module()], dialyzer_codeserver:codeserver(), - dialyzer_plt:plt(), - opaques_fun()) -> [raw_warning()]. + dialyzer_plt:plt()) -> [raw_warning()]. -get_invalid_contract_warnings(Modules, CodeServer, Plt, FindOpaques) -> - get_invalid_contract_warnings_modules(Modules, CodeServer, Plt, FindOpaques, []). +get_invalid_contract_warnings(Modules, CodeServer, Plt) -> + get_invalid_contract_warnings_modules(Modules, CodeServer, Plt, []). -get_invalid_contract_warnings_modules([Mod|Mods], CodeServer, Plt, FindOpaques, Acc) -> +get_invalid_contract_warnings_modules([Mod|Mods], CodeServer, Plt, Acc) -> Contracts1 = dialyzer_codeserver:lookup_mod_contracts(Mod, CodeServer), - Contracts2 = maps:to_list(Contracts1), - Records = dialyzer_codeserver:lookup_mod_records(Mod, CodeServer), - NewAcc = get_invalid_contract_warnings_funs(Contracts2, Plt, Records, FindOpaques, Acc), - get_invalid_contract_warnings_modules(Mods, CodeServer, Plt, FindOpaques, NewAcc); -get_invalid_contract_warnings_modules([], _CodeServer, _Plt, _FindOpaques, Acc) -> + NewAcc = + case maps:size(Contracts1) =:= 0 of + true -> Acc; + false -> + Contracts2 = maps:to_list(Contracts1), + Records = dialyzer_codeserver:lookup_mod_records(Mod, CodeServer), + Opaques = erl_types:t_opaque_from_records(Records), + get_invalid_contract_warnings_funs(Contracts2, Plt, Records, + Opaques, Acc) + end, + get_invalid_contract_warnings_modules(Mods, CodeServer, Plt, NewAcc); +get_invalid_contract_warnings_modules([], _CodeServer, _Plt, Acc) -> Acc. get_invalid_contract_warnings_funs([{MFA, {FileLine, Contract, _Xtra}}|Left], - Plt, RecDict, FindOpaques, Acc) -> + Plt, RecDict, Opaques, Acc) -> case dialyzer_plt:lookup(Plt, MFA) of none -> %% This must be a contract for a non-available function. Just accept it. - get_invalid_contract_warnings_funs(Left, Plt, RecDict, FindOpaques, Acc); + get_invalid_contract_warnings_funs(Left, Plt, RecDict, Opaques, Acc); {value, {Ret, Args}} -> Sig = erl_types:t_fun(Args, Ret), {M, _F, _A} = MFA, %% io:format("MFA ~tp~n", [MFA]), - Opaques = FindOpaques(M), {File, Line} = FileLine, WarningInfo = {File, Line, MFA}, NewAcc = @@ -741,9 +747,9 @@ get_invalid_contract_warnings_funs([{MFA, {FileLine, Contract, _Xtra}}|Left], RecDict, Acc) end end, - get_invalid_contract_warnings_funs(Left, Plt, RecDict, FindOpaques, NewAcc) + get_invalid_contract_warnings_funs(Left, Plt, RecDict, Opaques, NewAcc) end; -get_invalid_contract_warnings_funs([], _Plt, _RecDict, _FindOpaques, Acc) -> +get_invalid_contract_warnings_funs([], _Plt, _RecDict, _Opaques, Acc) -> Acc. invalid_contract_warning({M, F, A}, WarningInfo, SuccType, RecDict) -> diff --git a/lib/dialyzer/src/dialyzer_dataflow.erl b/lib/dialyzer/src/dialyzer_dataflow.erl index 46a8f01360..8367432ac5 100644 --- a/lib/dialyzer/src/dialyzer_dataflow.erl +++ b/lib/dialyzer/src/dialyzer_dataflow.erl @@ -138,7 +138,7 @@ %%-------------------------------------------------------------------- --type fun_types() :: dict:dict(label(), type()). +-type fun_types() :: orddict:orddict(label(), type()). -spec get_warnings(cerl:c_module(), dialyzer_plt:plt(), dialyzer_callgraph:callgraph(), @@ -3317,7 +3317,9 @@ state__clean_not_called(#state{fun_tab = FunTab} = State) -> state__all_fun_types(State) -> #state{fun_tab = FunTab} = state__clean_not_called(State), Tab1 = dict:erase(top, FunTab), - dict:map(fun(_Fun, {Args, Ret}) -> t_fun(Args, Ret)end, Tab1). + List = [{Fun, t_fun(Args, Ret)} || + {Fun, {Args, Ret}} <- dict:to_list(Tab1)], + orddict:from_list(List). state__fun_type(Fun, #state{fun_tab = FunTab}) -> Label = diff --git a/lib/dialyzer/src/dialyzer_gui_wx.erl b/lib/dialyzer/src/dialyzer_gui_wx.erl index bcaeca4cdc..538327d4d1 100644 --- a/lib/dialyzer/src/dialyzer_gui_wx.erl +++ b/lib/dialyzer/src/dialyzer_gui_wx.erl @@ -498,9 +498,9 @@ gui_loop(#gui_state{backend_pid = BackendPid, doc_plt = DocPlt, end, ExplanationPid = spawn_link(Fun), gui_loop(State#gui_state{expl_pid = ExplanationPid}); - {BackendPid, done, NewMiniPlt, NewDocPlt} -> + {BackendPid, done, NewPlt, NewDocPlt} -> message(State, "Analysis done"), - dialyzer_plt:delete(NewMiniPlt), + dialyzer_plt:delete(NewPlt), config_gui_stop(State), gui_loop(State#gui_state{doc_plt = NewDocPlt}); {'EXIT', BackendPid, {error, Reason}} -> diff --git a/lib/dialyzer/src/dialyzer_plt.erl b/lib/dialyzer/src/dialyzer_plt.erl index f36a008739..47994fc35b 100644 --- a/lib/dialyzer/src/dialyzer_plt.erl +++ b/lib/dialyzer/src/dialyzer_plt.erl @@ -39,6 +39,7 @@ insert_types/2, insert_exported_types/2, lookup/2, + is_contract/2, lookup_contract/2, lookup_callbacks/2, lookup_module/2, @@ -49,10 +50,7 @@ get_specs/1, get_specs/4, to_file/4, - get_mini_plt/1, - restore_full_plt/1, - delete/1, - give_away/2 + delete/1 ]). %% Debug utilities @@ -60,6 +58,8 @@ -export_type([plt/0, plt_info/0]). +-include_lib("stdlib/include/ms_transform.hrl"). + %%---------------------------------------------------------------------- -type mod_deps() :: dialyzer_callgraph:mod_deps(). @@ -75,20 +75,16 @@ %%---------------------------------------------------------------------- --record(plt, {info = table_new() :: dict:dict(), - types = table_new() :: erl_types:mod_records(), - contracts = table_new() :: dict:dict(), - callbacks = table_new() :: dict:dict(), - exported_types = sets:new() :: sets:set()}). - --record(mini_plt, {info :: ets:tid(), - types :: ets:tid(), - contracts :: ets:tid(), - callbacks :: ets:tid(), - exported_types :: ets:tid() - }). +-record(plt, {info :: ets:tid(), %% {mfa() | integer(), ret_args_types()} + types :: ets:tid(), %% {module(), erl_types:type_table()} + contracts :: ets:tid(), %% {mfa(), #contract{}} + callbacks :: ets:tid(), %% {module(), + %% [{mfa(), + %% dialyzer_contracts:file_contract()}] + exported_types :: ets:tid() %% {module(), sets:set()} + }). --opaque plt() :: #plt{} | #mini_plt{}. +-opaque plt() :: #plt{}. -include("dialyzer.hrl"). @@ -110,7 +106,17 @@ -spec new() -> plt(). new() -> - #plt{}. + [ETSInfo, ETSContracts] = + [ets:new(Name, [public]) || + Name <- [plt_info, plt_contracts]], + [ETSTypes, ETSCallbacks, ETSExpTypes] = + [ets:new(Name, [compressed, public]) || + Name <- [plt_types, plt_callbacks, plt_exported_types]], + #plt{info = ETSInfo, + types = ETSTypes, + contracts = ETSContracts, + callbacks = ETSCallbacks, + exported_types = ETSExpTypes}. -spec delete_module(plt(), atom()) -> plt(). @@ -121,58 +127,55 @@ delete_module(#plt{info = Info, types = Types, #plt{info = table_delete_module(Info, Mod), types = table_delete_module2(Types, Mod), contracts = table_delete_module(Contracts, Mod), - callbacks = table_delete_module(Callbacks, Mod), + callbacks = table_delete_module2(Callbacks, Mod), exported_types = table_delete_module1(ExpTypes, Mod)}. -spec delete_list(plt(), [mfa() | integer()]) -> plt(). -delete_list(#mini_plt{info = Info, - contracts = Contracts}=Plt, List) -> - Plt#mini_plt{info = ets_table_delete_list(Info, List), - contracts = ets_table_delete_list(Contracts, List)}; -delete_list(#plt{info = Info, types = Types, - contracts = Contracts, - callbacks = Callbacks, - exported_types = ExpTypes}, List) -> - #plt{info = table_delete_list(Info, List), - types = Types, - contracts = table_delete_list(Contracts, List), - callbacks = Callbacks, - exported_types = ExpTypes}. +delete_list(#plt{info = Info, + contracts = Contracts}=Plt, List) -> + Plt#plt{info = ets_table_delete_list(Info, List), + contracts = ets_table_delete_list(Contracts, List)}. -spec insert_contract_list(plt(), dialyzer_contracts:plt_contracts()) -> plt(). insert_contract_list(#plt{contracts = Contracts} = PLT, List) -> - NewContracts = dict:merge(fun(_MFA, _Old, New) -> New end, - Contracts, dict:from_list(List)), - PLT#plt{contracts = NewContracts}; -insert_contract_list(#mini_plt{contracts = Contracts} = PLT, List) -> true = ets:insert(Contracts, List), PLT. -spec insert_callbacks(plt(), dialyzer_codeserver:codeserver()) -> plt(). insert_callbacks(#plt{callbacks = Callbacks} = Plt, Codeserver) -> - List = dialyzer_codeserver:get_callbacks(Codeserver), - Plt#plt{callbacks = table_insert_list(Callbacks, List)}. + CallbacksList = dialyzer_codeserver:get_callbacks(Codeserver), + CallbacksByModule = + [{M, [Cb || {{M1,_,_},_} = Cb <- CallbacksList, M1 =:= M]} || + M <- lists:usort([M || {{M,_,_},_} <- CallbacksList])], + true = ets:insert(Callbacks, CallbacksByModule), + Plt. + +-spec is_contract(plt(), mfa()) -> boolean(). + +is_contract(#plt{contracts = ETSContracts}, + {M, F, _} = MFA) when is_atom(M), is_atom(F) -> + ets:member(ETSContracts, MFA). -spec lookup_contract(plt(), mfa_patt()) -> 'none' | {'value', #contract{}}. -lookup_contract(#mini_plt{contracts = ETSContracts}, +lookup_contract(#plt{contracts = ETSContracts}, {M, F, _} = MFA) when is_atom(M), is_atom(F) -> ets_table_lookup(ETSContracts, MFA). -spec lookup_callbacks(plt(), module()) -> 'none' | {'value', [{mfa(), dialyzer_contracts:file_contract()}]}. -lookup_callbacks(#mini_plt{callbacks = ETSCallbacks}, Mod) when is_atom(Mod) -> +lookup_callbacks(#plt{callbacks = ETSCallbacks}, Mod) when is_atom(Mod) -> ets_table_lookup(ETSCallbacks, Mod). -type ret_args_types() :: {erl_types:erl_type(), [erl_types:erl_type()]}. -spec insert_list(plt(), [{mfa() | integer(), ret_args_types()}]) -> plt(). -insert_list(#mini_plt{info = Info} = PLT, List) -> +insert_list(#plt{info = Info} = PLT, List) -> true = ets:insert(Info, List), PLT. @@ -184,31 +187,31 @@ lookup(Plt, {M, F, _} = MFA) when is_atom(M), is_atom(F) -> lookup(Plt, Label) when is_integer(Label) -> lookup_1(Plt, Label). -lookup_1(#mini_plt{info = Info}, MFAorLabel) -> +lookup_1(#plt{info = Info}, MFAorLabel) -> ets_table_lookup(Info, MFAorLabel). -spec insert_types(plt(), ets:tid()) -> plt(). -insert_types(MiniPLT, Records) -> - ets:rename(Records, plt_types), - MiniPLT#mini_plt{types = Records}. +insert_types(PLT, Records) -> + ok = dialyzer_utils:ets_move(Records, PLT#plt.types), + PLT. -spec insert_exported_types(plt(), ets:tid()) -> plt(). -insert_exported_types(MiniPLT, ExpTypes) -> - ets:rename(ExpTypes, plt_exported_types), - MiniPLT#mini_plt{exported_types = ExpTypes}. +insert_exported_types(PLT, ExpTypes) -> + ok = dialyzer_utils:ets_move(ExpTypes, PLT#plt.exported_types), + PLT. -spec get_module_types(plt(), atom()) -> 'none' | {'value', erl_types:type_table()}. get_module_types(#plt{types = Types}, M) when is_atom(M) -> - table_lookup(Types, M). + ets_table_lookup(Types, M). -spec get_exported_types(plt()) -> sets:set(). -get_exported_types(#plt{exported_types = ExpTypes}) -> - ExpTypes. +get_exported_types(#plt{exported_types = ETSExpTypes}) -> + sets:from_list([E || {E} <- table_to_list(ETSExpTypes)]). -type mfa_types() :: {mfa(), erl_types:erl_type(), [erl_types:erl_type()]}. @@ -225,8 +228,7 @@ all_modules(#plt{info = Info, contracts = Cs}) -> -spec contains_mfa(plt(), mfa()) -> boolean(). contains_mfa(#plt{info = Info, contracts = Contracts}, MFA) -> - (table_lookup(Info, MFA) =/= none) - orelse (table_lookup(Contracts, MFA) =/= none). + ets:member(Info, MFA) orelse ets:member(Contracts, MFA). -spec get_default_plt() -> file:filename(). @@ -249,32 +251,60 @@ from_file(FileName) -> from_file(FileName, false). from_file(FileName, ReturnInfo) -> + Plt = new(), + Fun = fun() -> from_file1(Plt, FileName, ReturnInfo) end, + case subproc(Fun) of + {ok, Return} -> + Return; + {error, Msg} -> + delete(Plt), + plt_error(Msg) + end. + +from_file1(Plt, FileName, ReturnInfo) -> case get_record_from_file(FileName) of {ok, Rec} -> case check_version(Rec) of error -> Msg = io_lib:format("Old PLT file ~ts\n", [FileName]), - plt_error(Msg); + {error, Msg}; ok -> + #file_plt{info = FileInfo, + contracts = FileContracts, + callbacks = FileCallbacks, + types = FileTypes, + exported_types = FileExpTypes} = Rec, Types = [{Mod, maps:from_list(dict:to_list(Types))} || - {Mod, Types} <- dict:to_list(Rec#file_plt.types)], - Plt = #plt{info = Rec#file_plt.info, - types = dict:from_list(Types), - contracts = Rec#file_plt.contracts, - callbacks = Rec#file_plt.callbacks, - exported_types = Rec#file_plt.exported_types}, + {Mod, Types} <- dict:to_list(FileTypes)], + CallbacksList = dict:to_list(FileCallbacks), + CallbacksByModule = + [{M, [Cb || {{M1,_,_},_} = Cb <- CallbacksList, M1 =:= M]} || + M <- lists:usort([M || {{M,_,_},_} <- CallbacksList])], + #plt{info = ETSInfo, + types = ETSTypes, + contracts = ETSContracts, + callbacks = ETSCallbacks, + exported_types = ETSExpTypes} = Plt, + [true, true, true] = + [ets:insert(ETS, Data) || + {ETS, Data} <- [{ETSInfo, dict:to_list(FileInfo)}, + {ETSTypes, Types}, + {ETSContracts, dict:to_list(FileContracts)}]], + true = ets:insert(ETSCallbacks, CallbacksByModule), + true = ets:insert(ETSExpTypes, [{ET} || + ET <- sets:to_list(FileExpTypes)]), case ReturnInfo of - false -> Plt; + false -> {ok, Plt}; true -> PltInfo = {Rec#file_plt.file_md5_list, Rec#file_plt.mod_deps}, - {Plt, PltInfo} + {ok, {Plt, PltInfo}} end end; {error, Reason} -> Msg = io_lib:format("Could not read PLT file ~ts: ~p\n", [FileName, Reason]), - plt_error(Msg) + {error, Msg} end. -type err_rsn() :: 'not_valid' | 'no_such_file' | 'read_error'. @@ -283,6 +313,10 @@ from_file(FileName, ReturnInfo) -> | {'error', err_rsn()}. included_files(FileName) -> + Fun = fun() -> included_files1(FileName) end, + subproc(Fun). + +included_files1(FileName) -> case get_record_from_file(FileName) of {ok, #file_plt{file_md5_list = Md5}} -> {ok, [File || {File, _} <- Md5]}; @@ -315,6 +349,9 @@ get_record_from_file(FileName) -> -spec merge_plts([plt()]) -> plt(). +%% One of the PLTs of the list is augmented with the contents of the +%% other PLTs, and returned. The other PLTs are deleted. + merge_plts(List) -> {InfoList, TypesList, ExpTypesList, ContractsList, CallbacksList} = group_fields(List), @@ -327,6 +364,12 @@ merge_plts(List) -> -spec merge_disj_plts([plt()]) -> plt(). +%% One of the PLTs of the list is augmented with the contents of the +%% other PLTs, and returned. The other PLTs are deleted. +%% +%% The keys are compared when checking for disjointness. Sometimes the +%% key is a module(), sometimes an mfa(). It boils down to checking if +%% any module occurs more than once. merge_disj_plts(List) -> {InfoList, TypesList, ExpTypesList, ContractsList, CallbacksList} = group_fields(List), @@ -367,17 +410,36 @@ find_duplicates(List) -> -spec to_file(file:filename(), plt(), mod_deps(), {[file_md5()], mod_deps()}) -> 'ok'. -to_file(FileName, - #plt{info = Info, types = Types, contracts = Contracts, - callbacks = Callbacks, exported_types = ExpTypes}, +%% Write the PLT to file, and deletes the PLT. +to_file(FileName, Plt, ModDeps, MD5_OldModDeps) -> + Fun = fun() -> to_file1(FileName, Plt, ModDeps, MD5_OldModDeps) end, + Return = subproc(Fun), + delete(Plt), + case Return of + ok -> ok; + Thrown -> throw(Thrown) + end. + +to_file1(FileName, + #plt{info = ETSInfo, types = ETSTypes, contracts = ETSContracts, + callbacks = ETSCallbacks, exported_types = ETSExpTypes}, ModDeps, {MD5, OldModDeps}) -> NewModDeps = dict:merge(fun(_Key, OldVal, NewVal) -> ordsets:union(OldVal, NewVal) end, OldModDeps, ModDeps), ImplMd5 = compute_implementation_md5(), + CallbacksList = + [Cb || + {_M, Cbs} <- tab2list(ETSCallbacks), + Cb <- Cbs], + Callbacks = dict:from_list(CallbacksList), + Info = dict:from_list(tab2list(ETSInfo)), + Types = tab2list(ETSTypes), + Contracts = dict:from_list(tab2list(ETSContracts)), + ExpTypes = sets:from_list([E || {E} <- tab2list(ETSExpTypes)]), FileTypes = dict:from_list([{Mod, dict:from_list(maps:to_list(MTypes))} || - {Mod, MTypes} <- dict:to_list(Types)]), + {Mod, MTypes} <- Types]), Record = #file_plt{version = ?VSN, file_md5_list = MD5, info = Info, @@ -393,7 +455,7 @@ to_file(FileName, {error, Reason} -> Msg = io_lib:format("Could not write PLT file ~ts: ~w\n", [FileName, Reason]), - throw({dialyzer_error, Msg}) + {dialyzer_error, Msg} end. -type md5_diff() :: [{'differ', atom()} | {'removed', atom()}]. @@ -406,6 +468,10 @@ to_file(FileName, | {'old_version', [file_md5()]}. check_plt(FileName, RemoveFiles, AddFiles) -> + Fun = fun() -> check_plt1(FileName, RemoveFiles, AddFiles) end, + subproc(Fun). + +check_plt1(FileName, RemoveFiles, AddFiles) -> case get_record_from_file(FileName) of {ok, #file_plt{file_md5_list = Md5, mod_deps = ModDeps} = Rec} -> case check_version(Rec) of @@ -514,67 +580,13 @@ init_md5_list_1([], DiffList, Acc) -> init_md5_list_1(Md5List, [], Acc) -> {ok, lists:reverse(Acc, Md5List)}. --spec get_mini_plt(plt()) -> plt(). +-spec delete(plt()) -> 'ok'. -get_mini_plt(#plt{info = Info, - types = Types, - contracts = Contracts, - callbacks = Callbacks, - exported_types = ExpTypes}) -> - [ETSInfo, ETSContracts] = - [ets:new(Name, [public]) || - Name <- [plt_info, plt_contracts]], - [ETSTypes, ETSCallbacks, ETSExpTypes] = - [ets:new(Name, [compressed, public]) || - Name <- [plt_types, plt_callbacks, plt_exported_types]], - CallbackList = dict:to_list(Callbacks), - CallbacksByModule = - [{M, [Cb || {{M1,_,_},_} = Cb <- CallbackList, M1 =:= M]} || - M <- lists:usort([M || {{M,_,_},_} <- CallbackList])], - [true, true, true] = - [ets:insert(ETS, dict:to_list(Data)) || - {ETS, Data} <- [{ETSInfo, Info}, - {ETSTypes, Types}, - {ETSContracts, Contracts}]], - true = ets:insert(ETSCallbacks, CallbacksByModule), - true = ets:insert(ETSExpTypes, [{ET} || ET <- sets:to_list(ExpTypes)]), - #mini_plt{info = ETSInfo, +delete(#plt{info = ETSInfo, types = ETSTypes, contracts = ETSContracts, callbacks = ETSCallbacks, - exported_types = ETSExpTypes}; -get_mini_plt(undefined) -> - undefined. - --spec restore_full_plt(plt()) -> plt(). - -restore_full_plt(#mini_plt{info = ETSInfo, - types = ETSTypes, - contracts = ETSContracts, - callbacks = ETSCallbacks, - exported_types = ETSExpTypes} = MiniPlt) -> - Info = dict:from_list(tab2list(ETSInfo)), - Contracts = dict:from_list(tab2list(ETSContracts)), - Types = dict:from_list(tab2list(ETSTypes)), - Callbacks = - dict:from_list([Cb || {_M, Cbs} <- tab2list(ETSCallbacks), Cb <- Cbs]), - ExpTypes = sets:from_list([E || {E} <- tab2list(ETSExpTypes)]), - ok = delete(MiniPlt), - #plt{info = Info, - types = Types, - contracts = Contracts, - callbacks = Callbacks, - exported_types = ExpTypes}; -restore_full_plt(undefined) -> - undefined. - --spec delete(plt()) -> 'ok'. - -delete(#mini_plt{info = ETSInfo, - types = ETSTypes, - contracts = ETSContracts, - callbacks = ETSCallbacks, - exported_types = ETSExpTypes}) -> + exported_types = ETSExpTypes}) -> true = ets:delete(ETSContracts), true = ets:delete(ETSTypes), true = ets:delete(ETSInfo), @@ -582,35 +594,15 @@ delete(#mini_plt{info = ETSInfo, true = ets:delete(ETSExpTypes), ok. --spec give_away(plt(), pid()) -> 'ok'. - -give_away(#mini_plt{info = ETSInfo, - types = ETSTypes, - contracts = ETSContracts, - callbacks = ETSCallbacks, - exported_types = ETSExpTypes}, - Pid) -> - true = ets:give_away(ETSContracts, Pid, any), - true = ets:give_away(ETSTypes, Pid, any), - true = ets:give_away(ETSInfo, Pid, any), - true = ets:give_away(ETSCallbacks, Pid, any), - true = ets:give_away(ETSExpTypes, Pid, any), - ok. - -%% Somewhat slower than ets:tab2list(), but uses less memory. -tab2list(T) -> - tab2list(ets:first(T), T, []). +tab2list(Tab) -> + dialyzer_utils:ets_tab2list(Tab). -tab2list('$end_of_table', T, A) -> - case ets:first(T) of % no safe_fixtable()... - '$end_of_table' -> A; - Key -> tab2list(Key, T, A) - end; -tab2list(Key, T, A) -> - Vs = ets:lookup(T, Key), - Key1 = ets:next(T, Key), - ets:delete(T, Key), - tab2list(Key1, T, Vs ++ A). +subproc(Fun) -> + F = fun() -> exit(Fun()) end, + {Pid, Ref} = erlang:spawn_monitor(F), + receive {'DOWN', Ref, process, Pid, Return} -> + Return + end. %%--------------------------------------------------------------------------- %% Edoc @@ -619,7 +611,8 @@ tab2list(Key, T, A) -> get_specs(#plt{info = Info}) -> %% TODO: Should print contracts as well. - L = lists:sort([{MFA, Val} || {{_,_,_} = MFA, Val} <- table_to_list(Info)]), + L = lists:sort([{MFA, Val} || + {{_,_,_} = MFA, Val} <- table_to_list(Info)]), lists:flatten(create_specs(L, [])). beam_file_to_module(Filename) -> @@ -629,7 +622,7 @@ beam_file_to_module(Filename) -> get_specs(#plt{info = Info}, M, F, A) when is_atom(M), is_atom(F) -> MFA = {M, F, A}, - case table_lookup(Info, MFA) of + case ets_table_lookup(Info, MFA) of none -> none; {value, Val} -> lists:flatten(create_specs([{MFA, Val}], [])) end. @@ -666,22 +659,24 @@ plt_error(Msg) -> %%--------------------------------------------------------------------------- %% Ets table -table_new() -> - dict:new(). - table_to_list(Plt) -> - dict:to_list(Plt). + ets:tab2list(Plt). -table_delete_module(Plt, Mod) -> - dict:filter(fun({M, _F, _A}, _Val) -> M =/= Mod; - (_, _) -> true - end, Plt). +table_delete_module(Tab, Mod) -> + MS = ets:fun2ms(fun({{M, _F, _A}, _Val}) -> M =:= Mod; + ({_, _}) -> false + end), + _NumDeleted = ets:select_delete(Tab, MS), + Tab. -table_delete_module1(Plt, Mod) -> - sets:filter(fun({M, _F, _A}) -> M =/= Mod end, Plt). +table_delete_module1(Tab, Mod) -> + MS = ets:fun2ms(fun({{M, _F, _A}}) -> M =:= Mod end), + _NumDeleted = ets:select_delete(Tab, MS), + Tab. -table_delete_module2(Plt, Mod) -> - dict:filter(fun(M, _Val) -> M =/= Mod end, Plt). +table_delete_module2(Tab, Mod) -> + true = ets:delete(Tab, Mod), + Tab. ets_table_delete_list(Tab, [H|T]) -> ets:delete(Tab, H), @@ -689,25 +684,6 @@ ets_table_delete_list(Tab, [H|T]) -> ets_table_delete_list(Tab, []) -> Tab. -table_delete_list(Plt, [H|T]) -> - table_delete_list(dict:erase(H, Plt), T); -table_delete_list(Plt, []) -> - Plt. - -table_insert_list(Plt, [{Key, Val}|Left]) -> - table_insert_list(table_insert(Plt, Key, Val), Left); -table_insert_list(Plt, []) -> - Plt. - -table_insert(Plt, Key, {_File, #contract{}, _Xtra} = C) -> - dict:store(Key, C, Plt). - -table_lookup(Plt, Obj) -> - case dict:find(Obj, Plt) of - error -> none; - {ok, Val} -> {value, Val} - end. - ets_table_lookup(Plt, Obj) -> try ets:lookup_element(Plt, Obj, 2) of Val -> {value, Val} @@ -715,25 +691,28 @@ ets_table_lookup(Plt, Obj) -> _:_ -> none end. -table_lookup_module(Plt, Mod) -> - List = dict:fold(fun(Key, Val, Acc) -> - case Key of - {Mod, _F, _A} -> [{Key, element(1, Val), - element(2, Val)}|Acc]; - _ -> Acc - end - end, [], Plt), +table_lookup_module(Tab, Mod) -> + MS = ets:fun2ms(fun({{M, F, A}, V}) when M =:= Mod -> + {{M, F, A}, V} end), + List = [begin + {V1, V2} = V, + {MFA, V1, V2} + end || {MFA, V} <- ets:select(Tab, MS)], case List =:= [] of true -> none; false -> {value, List} end. -table_all_modules(Plt) -> - Fold = - fun({M, _F, _A}, _Val, Acc) -> sets:add_element(M, Acc); - (_, _, Acc) -> Acc - end, - dict:fold(Fold, sets:new(), Plt). +table_all_modules(Tab) -> + Ks = ets:match(Tab, {'$1', '_'}, 100), + all_mods(Ks, sets:new()). + +all_mods('$end_of_table', S) -> + S; +all_mods({ListsOfKeys, Cont}, S) -> + S1 = lists:foldl(fun([{M, _F, _A}], S0) -> sets:add_element(M, S0) + end, S, ListsOfKeys), + all_mods(ets:match(Cont), S1). table_merge([H|T]) -> table_merge(T, H). @@ -741,7 +720,7 @@ table_merge([H|T]) -> table_merge([], Acc) -> Acc; table_merge([Plt|Plts], Acc) -> - NewAcc = dict:merge(fun(_Key, Val, Val) -> Val end, Plt, Acc), + NewAcc = merge_tables(Plt, Acc), table_merge(Plts, NewAcc). table_disj_merge([H|T]) -> @@ -752,24 +731,18 @@ table_disj_merge([], Acc) -> table_disj_merge([Plt|Plts], Acc) -> case table_is_disjoint(Plt, Acc) of true -> - NewAcc = dict:merge(fun(_Key, _Val1, _Val2) -> gazonk end, - Plt, Acc), + NewAcc = merge_tables(Plt, Acc), table_disj_merge(Plts, NewAcc); false -> throw({dialyzer_error, not_disjoint_plts}) end. -table_is_disjoint(T1, T2) -> - K1 = dict:fetch_keys(T1), - K2 = dict:fetch_keys(T2), - lists:all(fun(E) -> not lists:member(E, K2) end, K1). - sets_merge([H|T]) -> sets_merge(T, H). sets_merge([], Acc) -> Acc; sets_merge([Plt|Plts], Acc) -> - NewAcc = sets:union(Plt, Acc), + NewAcc = merge_tables(Plt, Acc), sets_merge(Plts, NewAcc). sets_disj_merge([H|T]) -> @@ -778,13 +751,39 @@ sets_disj_merge([H|T]) -> sets_disj_merge([], Acc) -> Acc; sets_disj_merge([Plt|Plts], Acc) -> - case sets:is_disjoint(Plt, Acc) of + case table_is_disjoint(Plt, Acc) of true -> - NewAcc = sets:union(Plt, Acc), + NewAcc = merge_tables(Plt, Acc), sets_disj_merge(Plts, NewAcc); false -> throw({dialyzer_error, not_disjoint_plts}) end. +table_is_disjoint(T1, T2) -> + tab_is_disj(ets:first(T1), T1, T2). + +tab_is_disj('$end_of_table', _T1, _T2) -> + true; +tab_is_disj(K1, T1, T2) -> + case ets:member(T2, K1) of + false -> + tab_is_disj(ets:next(T1, K1), T1, T2); + true -> + false + end. + +merge_tables(T1, T2) -> + tab_merge(ets:first(T1), T1, T2). + +tab_merge('$end_of_table', T1, T2) -> + true = ets:delete(T1), + T2; +tab_merge(K1, T1, T2) -> + Vs = ets:lookup(T1, K1), + NextK1 = ets:next(T1, K1), + true = ets:delete(T1, K1), + true = ets:insert(T2, Vs), + tab_merge(NextK1, T1, T2). + %%--------------------------------------------------------------------------- %% Debug utilities. @@ -812,7 +811,8 @@ pp_non_returning() -> lists:foreach(fun({{M, F, _}, Type}) -> io:format("~w:~w~s.\n", [M, F, dialyzer_utils:format_sig(Type)]) - end, lists:sort(None)). + end, lists:sort(None)), + delete(Plt). -spec pp_mod(atom()) -> 'ok'. @@ -828,4 +828,5 @@ pp_mod(Mod) when is_atom(Mod) -> end, lists:sort(List)); none -> io:format("dialyzer: Found no module named '~s' in the PLT\n", [Mod]) - end. + end, + delete(Plt). diff --git a/lib/dialyzer/src/dialyzer_succ_typings.erl b/lib/dialyzer/src/dialyzer_succ_typings.erl index 66c6c5e9ed..48115bb683 100644 --- a/lib/dialyzer/src/dialyzer_succ_typings.erl +++ b/lib/dialyzer/src/dialyzer_succ_typings.erl @@ -97,14 +97,13 @@ init_state_and_get_success_typings(Callgraph, Plt, Codeserver, TimingServer, Solvers, Parent) -> {SCCs, Callgraph1} = ?timing(TimingServer, "order", dialyzer_callgraph:finalize(Callgraph)), - State = #st{callgraph = Callgraph1, plt = dialyzer_plt:get_mini_plt(Plt), + State = #st{callgraph = Callgraph1, plt = Plt, codeserver = Codeserver, parent = Parent, timing_server = TimingServer, solvers = Solvers}, get_refined_success_typings(SCCs, State). get_refined_success_typings(SCCs, #st{callgraph = Callgraph, timing_server = TimingServer} = State) -> - erlang:garbage_collect(), case find_succ_typings(SCCs, State) of {fixpoint, State1} -> State1; {not_fixpoint, NotFixpoint1, State1} -> @@ -139,18 +138,15 @@ get_warnings(Callgraph, Plt, DocPlt, Codeserver, init_state_and_get_success_typings(Callgraph, Plt, Codeserver, TimingServer, Solvers, Parent), Mods = dialyzer_callgraph:modules(InitState#st.callgraph), - MiniPlt = InitState#st.plt, - FindOpaques = lookup_and_find_opaques_fun(Codeserver), + Plt = InitState#st.plt, CWarns = - dialyzer_contracts:get_invalid_contract_warnings(Mods, Codeserver, - MiniPlt, FindOpaques), - MiniDocPlt = dialyzer_plt:get_mini_plt(DocPlt), + dialyzer_contracts:get_invalid_contract_warnings(Mods, Codeserver, Plt), ModWarns = ?timing(TimingServer, "warning", - get_warnings_from_modules(Mods, InitState, MiniDocPlt)), + get_warnings_from_modules(Mods, InitState, DocPlt)), {postprocess_warnings(CWarns ++ ModWarns, Codeserver), - MiniPlt, - dialyzer_plt:restore_full_plt(MiniDocPlt)}. + Plt, + DocPlt}. get_warnings_from_modules(Mods, State, DocPlt) -> #st{callgraph = Callgraph, codeserver = Codeserver, @@ -162,13 +158,13 @@ get_warnings_from_modules(Mods, State, DocPlt) -> collect_warnings(M, {Codeserver, Callgraph, Plt, DocPlt}) -> ModCode = dialyzer_codeserver:lookup_mod_code(M, Codeserver), - Records = dialyzer_codeserver:lookup_mod_records(M, Codeserver), Contracts = dialyzer_codeserver:lookup_mod_contracts(M, Codeserver), AllFuns = collect_fun_info([ModCode]), %% Check if there are contracts for functions that do not exist Warnings1 = dialyzer_contracts:contracts_without_fun(Contracts, AllFuns, Callgraph), Attrs = cerl:module_attrs(ModCode), + Records = dialyzer_codeserver:lookup_mod_records(M, Codeserver), {Warnings2, FunTypes} = dialyzer_dataflow:get_warnings(ModCode, Plt, Callgraph, Codeserver, Records), @@ -251,24 +247,22 @@ lookup_names(Labels, {_Codeserver, Callgraph, _Plt, _Solvers}) -> refine_one_module(M, {CodeServer, Callgraph, Plt, _Solvers}) -> ModCode = dialyzer_codeserver:lookup_mod_code(M, CodeServer), AllFuns = collect_fun_info([ModCode]), - Records = dialyzer_codeserver:lookup_mod_records(M, CodeServer), FunTypes = get_fun_types_from_plt(AllFuns, Callgraph, Plt), + Records = dialyzer_codeserver:lookup_mod_records(M, CodeServer), NewFunTypes = dialyzer_dataflow:get_fun_types(ModCode, Plt, Callgraph, CodeServer, Records), - Contracts1 = dialyzer_codeserver:lookup_mod_contracts(M, CodeServer), - Contracts = orddict:from_list(maps:to_list(Contracts1)), - FindOpaques = find_opaques_fun(Records), - DecoratedFunTypes = - decorate_succ_typings(Contracts, Callgraph, NewFunTypes, FindOpaques), - %% ?debug("NewFunTypes ~tp\n ~n", [dict:to_list(NewFunTypes)]), - %% ?debug("refine DecoratedFunTypes ~tp\n ~n", [dict:to_list(DecoratedFunTypes)]), + {FunMFAContracts, ModOpaques} = + prepare_decoration(NewFunTypes, Callgraph, CodeServer), + DecoratedFunTypes = decorate_succ_typings(FunMFAContracts, ModOpaques), + %% ?Debug("NewFunTypes ~tp\n ~n", [NewFunTypes]), + %% ?debug("refine DecoratedFunTypes ~tp\n ~n", [DecoratedFunTypes]), debug_pp_functions("Refine", NewFunTypes, DecoratedFunTypes, Callgraph), case reached_fixpoint(FunTypes, DecoratedFunTypes) of true -> []; {false, NotFixpoint} -> ?debug("Not fixpoint\n", []), - Plt = insert_into_plt(dict:from_list(NotFixpoint), Callgraph, Plt), + Plt = insert_into_plt(orddict:from_list(NotFixpoint), Callgraph, Plt), [FunLbl || {FunLbl,_Type} <- NotFixpoint] end. @@ -282,22 +276,20 @@ reached_fixpoint_strict(OldTypes, NewTypes) -> end. reached_fixpoint(OldTypes0, NewTypes0, Strict) -> - MapFun = fun(_Key, Type) -> + MapFun = fun({Key, Type}) -> case is_failed_or_not_called_fun(Type) of - true -> failed_fun; - false -> erl_types:t_limit(Type, ?TYPE_LIMIT) + true -> {Key, failed_fun}; + false -> {Key, erl_types:t_limit(Type, ?TYPE_LIMIT)} end end, - OldTypes = dict:map(MapFun, OldTypes0), - NewTypes = dict:map(MapFun, NewTypes0), + OldTypes = lists:map(MapFun, orddict:to_list(OldTypes0)), + NewTypes = lists:map(MapFun, orddict:to_list(NewTypes0)), compare_types(OldTypes, NewTypes, Strict). is_failed_or_not_called_fun(Type) -> erl_types:any_none([erl_types:t_fun_range(Type)|erl_types:t_fun_args(Type)]). -compare_types(Dict1, Dict2, Strict) -> - List1 = lists:keysort(1, dict:to_list(Dict1)), - List2 = lists:keysort(1, dict:to_list(Dict2)), +compare_types(List1, List2, Strict) -> compare_types_1(List1, List2, Strict, []). compare_types_1([{X, _Type1}|Left1], [{X, failed_fun}|Left2], @@ -344,10 +336,6 @@ find_succ_typings(SCCs, #st{codeserver = Codeserver, callgraph = Callgraph, find_succ_types_for_scc(SCC0, {Codeserver, Callgraph, Plt, Solvers}) -> SCC = [MFA || {_, _, _} = MFA <- SCC0], - Contracts1 = [{MFA, dialyzer_codeserver:lookup_mfa_contract(MFA, Codeserver)} - || MFA <- SCC], - Contracts2 = [{MFA, Contract} || {MFA, {ok, Contract}} <- Contracts1], - Contracts3 = orddict:from_list(Contracts2), Label = dialyzer_codeserver:get_next_core_label(Codeserver), AllFuns = lists:append( [begin @@ -355,7 +343,6 @@ find_succ_types_for_scc(SCC0, {Codeserver, Callgraph, Plt, Solvers}) -> dialyzer_codeserver:lookup_mfa_code(MFA, Codeserver), collect_fun_info([Fun]) end || MFA <- SCC]), - erlang:garbage_collect(), PropTypes = get_fun_types_from_plt(AllFuns, Callgraph, Plt), %% Assume that the PLT contains the current propagated types FunTypes = dialyzer_typesig:analyze_scc(SCC, Label, Callgraph, @@ -363,27 +350,28 @@ find_succ_types_for_scc(SCC0, {Codeserver, Callgraph, Plt, Solvers}) -> Solvers), AllFunSet = sets:from_list([X || {X, _} <- AllFuns]), FilteredFunTypes = - dict:filter(fun(X, _) -> sets:is_element(X, AllFunSet) end, FunTypes), - FindOpaques = lookup_and_find_opaques_fun(Codeserver), - DecoratedFunTypes = - decorate_succ_typings(Contracts3, Callgraph, FilteredFunTypes, FindOpaques), + orddict:filter(fun(F, _T) -> sets:is_element(F, AllFunSet) + end, FunTypes), + {FunMFAContracts, ModOpaques} = + prepare_decoration(FilteredFunTypes, Callgraph, Codeserver), + DecoratedFunTypes = decorate_succ_typings(FunMFAContracts, ModOpaques), %% Check contracts + Contracts = orddict:from_list([{MFA, Contract} || + {_, {MFA, Contract}} <- FunMFAContracts]), PltContracts = - dialyzer_contracts:check_contracts(Contracts3, Callgraph, - DecoratedFunTypes, FindOpaques), - %% ?debug("FilteredFunTypes ~tp\n ~n", [dict:to_list(FilteredFunTypes)]), - %% ?debug("SCC DecoratedFunTypes ~tp\n ~n", [dict:to_list(DecoratedFunTypes)]), + dialyzer_contracts:check_contracts(Contracts, Callgraph, + DecoratedFunTypes, + ModOpaques), + %% ?debug("FilteredFunTypes ~tp\n ~n", [FilteredFunTypes]), + %% ?debug("SCC DecoratedFunTypes ~tp\n ~n", [DecoratedFunTypes]), debug_pp_functions("SCC", FilteredFunTypes, DecoratedFunTypes, Callgraph), - ContractFixpoint = - lists:all(fun({MFA, _C}) -> - %% Check the non-deleted PLT - case dialyzer_plt:lookup_contract(Plt, MFA) of - none -> false; - {value, _} -> true - end - end, PltContracts), + NewPltContracts = [MC || + {MFA, _C}=MC <- PltContracts, + %% Check the non-deleted PLT + not dialyzer_plt:is_contract(Plt, MFA)], + ContractFixpoint = NewPltContracts =:= [], Plt = insert_into_plt(DecoratedFunTypes, Callgraph, Plt), - Plt = dialyzer_plt:insert_contract_list(Plt, PltContracts), + Plt = dialyzer_plt:insert_contract_list(Plt, NewPltContracts), case (ContractFixpoint andalso reached_fixpoint_strict(PropTypes, DecoratedFunTypes)) of true -> []; @@ -392,42 +380,49 @@ find_succ_types_for_scc(SCC0, {Codeserver, Callgraph, Plt, Solvers}) -> [Fun || {Fun, _Arity} <- AllFuns] end. -decorate_succ_typings(Contracts, Callgraph, FunTypes, FindOpaques) -> - F = fun(Label, Type) -> +prepare_decoration(FunTypes, Callgraph, Codeserver) -> + F = fun({Label, _Type}=LabelType, Acc) -> case dialyzer_callgraph:lookup_name(Label, Callgraph) of {ok, MFA} -> - case orddict:find(MFA, Contracts) of + case dialyzer_codeserver:lookup_mfa_contract(MFA, Codeserver) of {ok, {_FileLine, Contract, _Xtra}} -> - Args = dialyzer_contracts:get_contract_args(Contract), - Ret = dialyzer_contracts:get_contract_return(Contract), - C = erl_types:t_fun(Args, Ret), - {M, _, _} = MFA, - Opaques = FindOpaques(M), - erl_types:t_decorate_with_opaque(Type, C, Opaques); - error -> Type + [{LabelType, {MFA, Contract}}|Acc]; + error -> [{LabelType, no}|Acc] end; - error -> Type + error -> [{LabelType, no}|Acc] end end, - dict:map(F, FunTypes). - -lookup_and_find_opaques_fun(Codeserver) -> - fun(Module) -> - Records = dialyzer_codeserver:lookup_mod_records(Module, Codeserver), - (find_opaques_fun(Records))(Module) - end. + Contracts = lists:foldl(F, [], orddict:to_list(FunTypes)), + ModOpaques = + [{M, lookup_opaques(M, Codeserver)} || + M <- lists:usort([M || {_LabelType, {{M, _, _}, _Con}} <- Contracts])], + {Contracts, orddict:from_list(ModOpaques)}. + +decorate_succ_typings(FunTypesContracts, ModOpaques) -> + F = fun({{Label, Type}, {{M, _, _}, Contract}}) -> + Args = dialyzer_contracts:get_contract_args(Contract), + Ret = dialyzer_contracts:get_contract_return(Contract), + C = erl_types:t_fun(Args, Ret), + {M, Opaques} = lists:keyfind(M, 1, ModOpaques), + R = erl_types:t_decorate_with_opaque(Type, C, Opaques), + {Label, R}; + ({LabelType, no}) -> + LabelType + end, + orddict:from_list(lists:map(F, FunTypesContracts)). -find_opaques_fun(Records) -> - fun(_Module) -> erl_types:t_opaque_from_records(Records) end. +lookup_opaques(Module, Codeserver) -> + Records = dialyzer_codeserver:lookup_mod_records(Module, Codeserver), + erl_types:t_opaque_from_records(Records). get_fun_types_from_plt(FunList, Callgraph, Plt) -> - get_fun_types_from_plt(FunList, Callgraph, Plt, dict:new()). + get_fun_types_from_plt(FunList, Callgraph, Plt, []). get_fun_types_from_plt([{FunLabel, Arity}|Left], Callgraph, Plt, Map) -> Type = lookup_fun_type(FunLabel, Arity, Callgraph, Plt), - get_fun_types_from_plt(Left, Callgraph, Plt, dict:store(FunLabel, Type, Map)); + get_fun_types_from_plt(Left, Callgraph, Plt, [{FunLabel, Type}|Map]); get_fun_types_from_plt([], _Callgraph, _Plt, Map) -> - Map. + orddict:from_list(Map). collect_fun_info(Trees) -> collect_fun_info(Trees, []). @@ -463,7 +458,7 @@ insert_into_plt(SuccTypes0, Callgraph, Plt) -> dialyzer_plt:insert_list(Plt, SuccTypes). format_succ_types(SuccTypes, Callgraph) -> - format_succ_types(dict:to_list(SuccTypes), Callgraph, []). + format_succ_types(SuccTypes, Callgraph, []). format_succ_types([{Label, Type0}|Left], Callgraph, Acc) -> Type = erl_types:t_limit(Type0, ?TYPE_LIMIT+1), @@ -486,10 +481,8 @@ debug_pp_succ_typings(SuccTypes) -> ?debug("\n", []), ok. -debug_pp_functions(Header, FunTypes, DecoratedFunTypes, Callgraph) -> +debug_pp_functions(Header, FTypes, DTypes, Callgraph) -> ?debug("FunTypes (~s)\n", [Header]), - FTypes = lists:keysort(1, dict:to_list(FunTypes)), - DTypes = lists:keysort(1, dict:to_list(DecoratedFunTypes)), Fun = fun({{Label, Type},{Label, DecoratedType}}) -> Name = lookup_name(Label, Callgraph), ?debug("~tw (~w): ~ts\n", diff --git a/lib/dialyzer/src/dialyzer_typesig.erl b/lib/dialyzer/src/dialyzer_typesig.erl index a0a69cb2ea..c4d8f45447 100644 --- a/lib/dialyzer/src/dialyzer_typesig.erl +++ b/lib/dialyzer/src/dialyzer_typesig.erl @@ -96,10 +96,12 @@ -type typesig_funmap() :: #{type_var() => type_var()}. --type prop_types() :: dict:dict(label(), erl_types:erl_type()). +-type prop_types() :: orddict:orddict(label(), erl_types:erl_type()). +-type dict_prop_types() :: dict:dict(label(), erl_types:erl_type()). -record(state, {callgraph :: dialyzer_callgraph:callgraph() | 'undefined', + cserver :: dialyzer_codeserver:codeserver(), cs = [] :: [constr()], cmap = maps:new() :: #{type_var() => constr()}, fun_map = maps:new() :: typesig_funmap(), @@ -112,8 +114,8 @@ self_rec :: 'false' | erl_types:erl_type(), plt :: dialyzer_plt:plt() | 'undefined', - prop_types = dict:new() :: prop_types(), - records = maps:new() :: types(), + prop_types = dict:new() :: dict_prop_types(), + mod_records = [] :: [{module(), types()}], scc = [] :: ordsets:ordset(type_var()), mfas :: [mfa()], solvers = [] :: [solver()] @@ -138,9 +140,11 @@ -ifdef(DEBUG). -define(debug(__String, __Args), io:format(__String, __Args)). -define(mk_fun_var(Fun, Vars), mk_fun_var(?LINE, Fun, Vars)). +-define(pp_map(S, M), pp_map(S, M)). -else. -define(debug(__String, __Args), ok). -define(mk_fun_var(Fun, Vars), mk_fun_var(Fun, Vars)). +-define(pp_map(S, M), ok). -endif. %% ============================================================================ @@ -177,15 +181,13 @@ analyze_scc(SCC, NextLabel, CallGraph, CServer, Plt, PropTypes, Solvers0) -> State1 = new_state(SCC, NextLabel, CallGraph, CServer, Plt, PropTypes, Solvers), DefSet = add_def_list(maps:values(State1#state.name_map), sets:new()), - ModRecs = [{M, dialyzer_codeserver:lookup_mod_records(M, CServer)} || - M <- lists:usort([M || {M, _, _} <- SCC])], - State2 = traverse_scc(SCC, CServer, DefSet, ModRecs, State1), + State2 = traverse_scc(SCC, CServer, DefSet, State1), State3 = state__finalize(State2), Funs = state__scc(State3), pp_constrs_scc(Funs, State3), constraints_to_dot_scc(Funs, State3), T = solve(Funs, State3), - dict:from_list(maps:to_list(T)). + orddict:from_list(maps:to_list(T)). solvers([]) -> [v2]; solvers(Solvers) -> Solvers. @@ -196,15 +198,14 @@ solvers(Solvers) -> Solvers. %% %% ============================================================================ -traverse_scc([{M,_,_}=MFA|Left], Codeserver, DefSet, ModRecs, AccState) -> +traverse_scc([{M,_,_}=MFA|Left], Codeserver, DefSet, AccState) -> + TmpState1 = state__set_module(AccState, M), Def = dialyzer_codeserver:lookup_mfa_code(MFA, Codeserver), - {M, Rec} = lists:keyfind(M, 1, ModRecs), - TmpState1 = state__set_rec_dict(AccState, Rec), DummyLetrec = cerl:c_letrec([Def], cerl:c_atom(foo)), TmpState2 = state__new_constraint_context(TmpState1), {NewAccState, _} = traverse(DummyLetrec, DefSet, TmpState2), - traverse_scc(Left, Codeserver, DefSet, ModRecs, NewAccState); -traverse_scc([], _Codeserver, _DefSet, _ModRecs, AccState) -> + traverse_scc(Left, Codeserver, DefSet, NewAccState); +traverse_scc([], _Codeserver, _DefSet, AccState) -> AccState. traverse(Tree, DefinedVars, State) -> @@ -470,12 +471,11 @@ traverse(Tree, DefinedVars, State) -> true -> %% Check if a record is constructed. Arity = length(Fields), - Records = State2#state.records, - case lookup_record(Records, cerl:atom_val(Tag), Arity) of - error -> {State2, TupleType}; - {ok, RecType} -> - State3 = state__store_conj(TupleType, sub, RecType, State2), - {State3, TupleType} + case lookup_record(State2, cerl:atom_val(Tag), Arity) of + {error, State3} -> {State3, TupleType}; + {ok, RecType, State3} -> + State4 = state__store_conj(TupleType, sub, RecType, State3), + {State4, TupleType} end; false -> {State2, TupleType} end; @@ -1440,7 +1440,6 @@ get_bif_constr({erlang, is_record, 2}, Dst, [Var, Tag] = Args, _State) -> mk_constraint(Var, sub, ArgV)]); get_bif_constr({erlang, is_record, 3}, Dst, [Var, Tag, Arity] = Args, State) -> %% TODO: Revise this to make it precise for Tag and Arity. - Records = State#state.records, ArgFun = fun(Map) -> case t_is_any_atom(true, lookup_type(Dst, Map)) of @@ -1457,10 +1456,10 @@ get_bif_constr({erlang, is_record, 3}, Dst, [Var, Tag, Arity] = Args, State) -> GenRecord = t_tuple([TagType|AnyElems]), case t_atom_vals(TagType) of [TagVal] -> - case lookup_record(Records, TagVal, ArityVal - 1) of - {ok, Type} -> + case lookup_record(State, TagVal, ArityVal - 1) of + {ok, Type, _NewState} -> Type; - error -> GenRecord + {error, _NewState} -> GenRecord end; _ -> GenRecord end; @@ -1917,8 +1916,8 @@ check_solutions([{S1,Map1,_Time1}|Maps], Fun, S, Map) -> check_solutions(Maps, Fun, S1, Map1); false -> ?debug("Constraint solvers do not agree on ~w\n", [Fun]), - pp_map(atom_to_list(S), Map), - pp_map(atom_to_list(S1), Map1), + ?pp_map(atom_to_list(S), Map), + ?pp_map(atom_to_list(S1), Map1), io:format("A bug was found. Please report it, and use the option " "`--solver v1' until the bug has been fixed.\n"), throw(error) @@ -1967,7 +1966,7 @@ v2_solve(#constraint_ref{id = Id}, Map, V2State) -> v2_solve_reference(Id, Map, V2State0) -> ?debug("Checking ref to fun: ~tw\n", [debug_lookup_name(Id)]), - pp_map("Map", Map), + ?pp_map("Map", Map), pp_constr_data("solve_ref", V2State0), Map1 = restore_local_map(V2State0, Id, Map), State = V2State0#v2_state.state, @@ -2023,7 +2022,7 @@ v2_solve_self_recursive(Cs, Map, Id, RecType0, V2State0) -> Error end; {ok, NewMap, V2State, U} -> - pp_map("recursive finished", NewMap), + ?pp_map("recursive finished", NewMap), NewRecType = unsafe_lookup_type(Id, NewMap), case is_equal(NewRecType, RecType0) of true -> @@ -2041,7 +2040,7 @@ enter_var_type(Var, Type, Map0) -> v2_solve_disjunct(Disj, Map, V2State0) -> #constraint_list{type = disj, id = _Id, list = Cs, masks = Masks} = Disj, ?debug("disjunct Id=~w~n", [_Id]), - pp_map("Map", Map), + ?pp_map("Map", Map), pp_constr_data("disjunct", V2State0), case get_flags(V2State0, Disj) of {V2State1, failed_list} -> {error, V2State1}; % cannot happen @@ -2069,7 +2068,7 @@ v2_solve_disjunct(Disj, Map, V2State0) -> U1 = [V || V <- U0, var_occurs_everywhere(V, Masks, NotFailed)], NewMap = join_maps(U1, MapL, Map), - pp_map("NewMap", NewMap), + ?pp_map("NewMap", NewMap), U = updated_vars_only(U1, Map, NewMap), ?debug("disjunct finished _Id=~w\n", [_Id]), {ok, NewMap, V2State, U} @@ -2092,7 +2091,7 @@ v2_solve_disj([I|Is], [C|Cs], I, Map0, V2State0, UL, MapL, Eval, Uneval, {ok, Map, V2State1, U} -> ?debug("disj I=~w U=~w~n", [I, U]), V2State = save_local_map(V2State1, Id, U, Map), - pp_map("DMap", Map), + ?pp_map("DMap", Map), v2_solve_disj(Is, Cs, I+1, Map0, V2State, [U|UL], [Map|MapL], [I|Eval], Uneval, Failed0) end; @@ -2118,9 +2117,9 @@ save_local_map(#v2_state{constr_data = ConData}=V2State, Id, U, Map) -> end, ?debug("save local map Id=~w:\n", [Id]), Part = lists:ukeymerge(1, lists:keysort(1, Part0), Part1), - pp_map("New Part", maps:from_list(Part0)), - pp_map("Old Part", maps:from_list(Part1)), - pp_map(" => Part", maps:from_list(Part)), + ?pp_map("New Part", maps:from_list(Part0)), + ?pp_map("Old Part", maps:from_list(Part1)), + ?pp_map(" => Part", maps:from_list(Part)), V2State#v2_state{constr_data = maps:put(Id, {Part,[]}, ConData)}. restore_local_map(#v2_state{constr_data = ConData}, Id, Map0) -> @@ -2131,10 +2130,10 @@ restore_local_map(#v2_state{constr_data = ConData}, Id, Map0) -> {ok, {Part0,U}} -> Part = [KV || {K,_V} = KV <- Part0, not lists:member(K, U)], ?debug("restore local map Id=~w U=~w\n", [Id, U]), - pp_map("Part", maps:from_list(Part)), - pp_map("Map0", Map0), + ?pp_map("Part", maps:from_list(Part)), + ?pp_map("Map0", Map0), Map = lists:foldl(fun({K,V}, D) -> maps:put(K, V, D) end, Map0, Part), - pp_map("Map", Map), + ?pp_map("Map", Map), Map end. @@ -2290,7 +2289,7 @@ pp_constr_data(_Tag, #v2_state{constr_data = D}) -> case _PartU of {_Part, _U} -> io:format("Id: ~w Vars: ~w\n", [_Id, _U]), - [pp_map("Part", maps:from_list(_Part)) || _Part =/= []]; + [?pp_map("Part", maps:from_list(_Part)) || _Part =/= []]; failed -> io:format("Id: ~w failed list\n", [_Id]) end @@ -2390,7 +2389,7 @@ solve_self_recursive(Cs, Map, MapDict, Id, RecType0, State) -> ?debug("OldRecType ~ts\n", [format_type(RecType0)]), RecType = t_limit(RecType0, ?TYPE_LIMIT), Map1 = enter_type(RecVar, RecType, erase_type(t_var_name(Id), Map)), - pp_map("Map1", Map1), + ?pp_map("Map1", Map1), case solve_ref_or_list(Cs, Map1, MapDict, State) of {error, _} = Error -> case t_is_none(RecType0) of @@ -2403,7 +2402,7 @@ solve_self_recursive(Cs, Map, MapDict, Id, RecType0, State) -> Error end; {ok, NewMapDict, NewMap} -> - pp_map("NewMap", NewMap), + ?pp_map("NewMap", NewMap), NewRecType = unsafe_lookup_type(Id, NewMap), case is_equal(NewRecType, RecType0) of true -> @@ -2702,18 +2701,13 @@ is_same(Key, Map1, Map2) -> is_equal(Type1, Type2) -> t_is_equal(Type1, Type2). -pp_map(_S, _Map) -> - ?debug("\t~s: ~tp\n", - [_S, [{X, lists:flatten(format_type(Y))} || - {X, Y} <- lists:keysort(1, maps:to_list(_Map))]]). - %% ============================================================================ %% %% The State. %% %% ============================================================================ -new_state(MFAs, NextLabel, CallGraph, CServer, Plt, PropTypes, Solvers) -> +new_state(MFAs, NextLabel, CallGraph, CServer, Plt, PropTypes0, Solvers) -> List_SCC = [begin {Var, Label} = dialyzer_codeserver:lookup_mfa_var_label(MFA, CServer), @@ -2731,12 +2725,14 @@ new_state(MFAs, NextLabel, CallGraph, CServer, Plt, PropTypes, Solvers) -> end; _Many -> false end, + PropTypes = dict:from_list(PropTypes0), #state{callgraph = CallGraph, name_map = NameMap, next_label = NextLabel, prop_types = PropTypes, plt = Plt, scc = ordsets:from_list(SCC), - mfas = MFAs, self_rec = SelfRec, solvers = Solvers}. + mfas = MFAs, self_rec = SelfRec, solvers = Solvers, + cserver = CServer}. -state__set_rec_dict(State, RecDict) -> - State#state{records = RecDict}. +state__set_module(State, Module) -> + State#state{module = Module}. state__set_in_match(State, Bool) -> State#state{in_match = Bool}. @@ -2975,6 +2971,11 @@ mk_fun_var(Line, Fun, Types) -> Deps = [t_var_name(Var) || Var <- t_collect_vars(t_product(Types))], #fun_var{'fun' = Fun, deps = ordsets:from_list(Deps), origin = Line}. +pp_map(S, Map) -> + ?debug("\t~s: ~p\n", + [S, [{X, lists:flatten(format_type(Y))} || + {X, Y} <- lists:keysort(1, maps:to_list(Map))]]). + -else. -spec mk_fun_var(fun((_) -> erl_types:erl_type()), [erl_types:erl_type()]) -> #fun_var{}. @@ -3348,15 +3349,25 @@ fold_literal_maybe_match(Tree0, State) -> true -> dialyzer_utils:refold_pattern(Tree1) end. -lookup_record(Records, Tag, Arity) -> - case erl_types:lookup_record(Tag, Arity, Records) of +lookup_record(State, Tag, Arity) -> + #state{module = M, mod_records = ModRecs, cserver = CServer} = State, + {State1, Rec} = + case lists:keyfind(M, 1, ModRecs) of + {M, Rec0} -> + {State, Rec0}; + false -> + Rec0 = dialyzer_codeserver:lookup_mod_records(M, CServer), + NewModRecs = [{M, Rec0}|ModRecs], + {State#state{mod_records = NewModRecs}, Rec0} + end, + case erl_types:lookup_record(Tag, Arity, Rec) of {ok, Fields} -> RecType = t_tuple([t_from_term(Tag)| [FieldType || {_FieldName, _Abstr, FieldType} <- Fields]]), - {ok, RecType}; + {ok, RecType, State1}; error -> - error + {error, State1} end. is_literal_record(Tree) -> diff --git a/lib/dialyzer/src/dialyzer_utils.erl b/lib/dialyzer/src/dialyzer_utils.erl index e5941d0ab8..511a6d66bf 100644 --- a/lib/dialyzer/src/dialyzer_utils.erl +++ b/lib/dialyzer/src/dialyzer_utils.erl @@ -39,6 +39,8 @@ sets_filter/2, src_compiler_opts/0, refold_pattern/1, + ets_tab2list/1, + ets_move/2, parallelism/0, family/1 ]). @@ -340,7 +342,19 @@ process_record_remote_types(CServer) -> {FieldsList, C3} = lists:mapfoldl(FieldFun, C2, orddict:to_list(Fields)), {{Key, {FileLine, orddict:from_list(FieldsList)}}, C3}; - _Other -> {{Key, Value}, C2} + {type, Name, NArgs} -> + %% Make sure warnings about unknown types are output + %% also for types unused by specs. + Site = {type, {Module, Name, NArgs}}, + L = erl_anno:new(0), + Args = lists:duplicate(NArgs, {var, L, '_'}), + UserType = {user_type, L, Name, Args}, + {_NewType, C3} = + erl_types:t_from_form(UserType, ExpTypes, Site, + RecordTable, VarTable, C2), + {{Key, Value}, C3}; + {opaque, _Name, _NArgs} -> + {{Key, Value}, C2} end end, Cache = erl_types:cache__new(), @@ -378,7 +392,10 @@ process_opaque_types(AllModules, CServer, TempExpTypes) -> erl_types:t_from_form(Form, TempExpTypes, Site, RecordTable, VarTable, C2), {{Key, {F, Type}}, C3}; - _Other -> {{Key, Value}, C2} + {type, _Name, _NArgs} -> + {{Key, Value}, C2}; + {record, _RecName} -> + {{Key, Value}, C2} end end, C0 = erl_types:cache__new(), @@ -974,6 +991,35 @@ label(Tree) -> %%------------------------------------------------------------------------------ +-spec ets_tab2list(ets:tid()) -> list(). + +%% Deletes the contents of the table. Use: +%% ets_tab2list(T), ets:delete(T) +%% instead of: +%% ets:tab2list(T), ets:delete(T) +%% to save some memory at the expense of somewhat longer execution time. +ets_tab2list(T) -> + F = fun(Vs, A) -> Vs ++ A end, + ets_take(ets:first(T), T, F, []). + +-spec ets_move(From :: ets:tid(), To :: ets:tid()) -> 'ok'. + +ets_move(T1, T2) -> + F = fun(Es, A) -> true = ets:insert(T2, Es), A end, + [] = ets_take(ets:first(T1), T1, F, []), + ok. + +ets_take('$end_of_table', T, F, A) -> + case ets:first(T) of % no safe_fixtable()... + '$end_of_table' -> A; + Key -> ets_take(Key, T, F, A) + end; +ets_take(Key, T, F, A) -> + Vs = ets:lookup(T, Key), + Key1 = ets:next(T, Key), + true = ets:delete(T, Key), + ets_take(Key1, T, F, F(Vs, A)). + -spec parallelism() -> integer(). parallelism() -> diff --git a/lib/dialyzer/src/typer.erl b/lib/dialyzer/src/typer.erl index 43e03be740..bf5484e5f6 100644 --- a/lib/dialyzer/src/typer.erl +++ b/lib/dialyzer/src/typer.erl @@ -158,10 +158,9 @@ get_type_info(#analysis{callgraph = CallGraph, StrippedCallGraph = remove_external(CallGraph, TrustPLT), %% io:format("--- Analyzing callgraph... "), try - NewMiniPlt = dialyzer_succ_typings:analyze_callgraph(StrippedCallGraph, - TrustPLT, - CodeServer), - NewPlt = dialyzer_plt:restore_full_plt(NewMiniPlt), + NewPlt = dialyzer_succ_typings:analyze_callgraph(StrippedCallGraph, + TrustPLT, + CodeServer), Analysis#analysis{callgraph = StrippedCallGraph, trust_plt = NewPlt} catch error:What -> diff --git a/lib/dialyzer/test/options2_SUITE_data/results/unused_unknown_type b/lib/dialyzer/test/options2_SUITE_data/results/unused_unknown_type new file mode 100644 index 0000000000..110d896c76 --- /dev/null +++ b/lib/dialyzer/test/options2_SUITE_data/results/unused_unknown_type @@ -0,0 +1,2 @@ + +:0: Unknown type unknown:type1/0:0: Unknown type unknown:type2/0:0: Unknown type unknown:type3/0
\ No newline at end of file diff --git a/lib/dialyzer/test/options2_SUITE_data/src/unused_unknown_type.erl b/lib/dialyzer/test/options2_SUITE_data/src/unused_unknown_type.erl new file mode 100644 index 0000000000..90df7d528a --- /dev/null +++ b/lib/dialyzer/test/options2_SUITE_data/src/unused_unknown_type.erl @@ -0,0 +1,10 @@ +-module(unused_unknown_type). + +-export_type([unused/0]). + +-type unused() :: unknown:type1(). + +-record(unused_rec, {a :: unknown:type2()}). + +-record(rec, {a}). +-type unused_rec() :: #rec{a :: unknown:type3()}. diff --git a/lib/dialyzer/test/plt_SUITE.erl b/lib/dialyzer/test/plt_SUITE.erl index 92c63bdb0c..ebe79b2a6d 100644 --- a/lib/dialyzer/test/plt_SUITE.erl +++ b/lib/dialyzer/test/plt_SUITE.erl @@ -9,14 +9,14 @@ -export([suite/0, all/0, build_plt/1, beam_tests/1, update_plt/1, local_fun_same_as_callback/1, remove_plt/1, run_plt_check/1, run_succ_typings/1, - bad_dialyzer_attr/1]). + bad_dialyzer_attr/1, merge_plts/1]). suite() -> [{timetrap, ?plt_timeout}]. all() -> [build_plt, beam_tests, update_plt, run_plt_check, remove_plt, run_succ_typings, local_fun_same_as_callback, - bad_dialyzer_attr]. + bad_dialyzer_attr, merge_plts]. build_plt(Config) -> OutDir = ?config(priv_dir, Config), @@ -170,6 +170,7 @@ update_plt(Config) -> {init_plt, Plt}] ++ Opts), ok. + %%% If a behaviour module contains an non-exported function with the same name %%% as one of the behaviour's callbacks, the callback info was inadvertently %%% deleted from the PLT as the dialyzer_plt:delete_list/2 function was cleaning @@ -297,6 +298,87 @@ bad_dialyzer_attr(Config) -> ok. +merge_plts(Config) -> + %% A few checks of merging PLTs. + fun() -> + {Mod1, Mod2} = types(), + {BeamFiles, Plt1, Plt2} = create_plts(Mod1, Mod2, Config), + + {dialyzer_error, + "Could not merge PLTs since they are not disjoint"++_} = + (catch run_dialyzer(succ_typings, BeamFiles, + [{plts, [Plt1, Plt1]}])), + [{warn_contract_types,_,_}] = + run_dialyzer(succ_typings, BeamFiles, + [{warnings, [unknown]}, + {plts, [Plt1, Plt2]}]) + end(), + + fun() -> + {Mod1, Mod2} = callbacks(), + {BeamFiles, Plt1, Plt2} = create_plts(Mod1, Mod2, Config), + + {dialyzer_error, + "Could not merge PLTs since they are not disjoint"++_} = + (catch run_dialyzer(succ_typings, BeamFiles, + [{plts, [Plt1, Plt1]}])), + [] = + run_dialyzer(succ_typings, BeamFiles, + [{warnings, [unknown]}, + {plts, [Plt1, Plt2]}]) + end(), + + ok. + +types() -> + Mod1 = <<"-module(merge_plts_1). + -export([f/0]). + -export_type([t/0]). + -type t() :: merge_plts_2:t(). + -spec f() -> t(). + f() -> 1. % Not an atom(). + ">>, + Mod2 = <<"-module(merge_plts_2). + -export_type([t/0]). + -type t() :: atom(). + ">>, + {Mod1, Mod2}. + +callbacks() -> % A very shallow test. + Mod1 = <<"-module(merge_plts_1). + -callback t() -> merge_plts_2:t(). + ">>, + Mod2 = <<"-module(merge_plts_2). + -export_type([t/0]). + -type t() :: atom(). + ">>, + {Mod1, Mod2}. + +create_plts(Mod1, Mod2, Config) -> + PrivDir = ?config(priv_dir, Config), + Plt1 = filename:join(PrivDir, "merge_plts_1.plt"), + Plt2 = filename:join(PrivDir, "merge_plts_2.plt"), + ErlangBeam = erlang_beam(), + + {ok, BeamFile1} = compile(Config, Mod1, merge_plts_1, []), + [] = run_dialyzer(plt_build, [ErlangBeam,BeamFile1], [{output_plt,Plt1}]), + + {ok, BeamFile2} = compile(Config, Mod2, merge_plts_2, []), + [] = run_dialyzer(plt_build, [BeamFile2], [{output_plt, Plt2}]), + {[BeamFile1, BeamFile2], Plt1, Plt2}. + +%% End of merge_plts(). + +erlang_beam() -> + case code:where_is_file("erlang.beam") of + non_existing -> + filename:join([code:root_dir(), + "erts", "preloaded", "ebin", + "erlang.beam"]); + EBeam -> + EBeam + end. + compile(Config, Prog, Module, CompileOpts) -> Source = lists:concat([Module, ".erl"]), PrivDir = ?config(priv_dir,Config), diff --git a/lib/diameter/doc/src/diameter.xml b/lib/diameter/doc/src/diameter.xml index 72181a42b0..2cbe48ecce 100644 --- a/lib/diameter/doc/src/diameter.xml +++ b/lib/diameter/doc/src/diameter.xml @@ -21,7 +21,7 @@ <copyright> <year>2011</year> -<year>2016</year> +<year>2017</year> <holder>Ericsson AB. All Rights Reserved.</holder> </copyright> <legalnotice> @@ -300,6 +300,17 @@ corresponding list of filters. Defaults to <c>none</c>.</p> </item> +<tag><c>{peer, &app_peer_ref;}</c></tag> +<item> +<p> +Peer to which the request in question can be sent, preempting the +selection of peers having advertised support for the Diameter +application in question. +Multiple options can be specified, and their order is +respected in the candidate lists passed to a subsequent +&app_pick_peer; callback.</p> +</item> + <tag><c>{timeout, &dict_Unsigned32;}</c></tag> <item> <p> diff --git a/lib/diameter/doc/src/diameter_codec.xml b/lib/diameter/doc/src/diameter_codec.xml index 91e96058dd..0117c1c88a 100644 --- a/lib/diameter/doc/src/diameter_codec.xml +++ b/lib/diameter/doc/src/diameter_codec.xml @@ -13,7 +13,8 @@ <erlref> <header> <copyright> -<year>2012</year><year>2016</year> +<year>2012</year> +<year>2017</year> <holder>Ericsson AB. All Rights Reserved.</holder> </copyright> <legalnotice> @@ -53,17 +54,17 @@ communicated to &man_app; callbacks. Similarly, outgoing Diameter messages are encoded into binary() before being passed to the appropriate &man_transport; module for transmission. -The functions in this module implement this encode/decode.</p> +The functions documented here implement the default encode/decode.</p> -<note> +<warning> <p> -Calls to this module are made by diameter itself as a consequence of -configuration passed to &mod_start_service;. -The encode/decode functions may also be useful for other purposes (eg. -test) but the diameter user does not need to call them explicitly when +The diameter user does not need to call functions here explicitly when sending and receiving messages using &mod_call; and the callback -interface documented in &man_app;.</p> -</note> +interface documented in &man_app;: diameter itself provides encode/decode +as a consequence of configuration passed to &mod_start_service;, and +the results may differ from those returned by the functions documented +here, depending on configuration.</p> +</warning> <p> The &header; and &packet; records below diff --git a/lib/diameter/doc/src/diameter_dict.xml b/lib/diameter/doc/src/diameter_dict.xml index 9584d682c2..94016d9466 100644 --- a/lib/diameter/doc/src/diameter_dict.xml +++ b/lib/diameter/doc/src/diameter_dict.xml @@ -16,7 +16,8 @@ <header> <copyright> -<year>2011</year><year>2016</year> +<year>2011</year> +<year>2017</year> <holder>Ericsson AB. All Rights Reserved.</holder> </copyright> <legalnotice> @@ -307,11 +308,11 @@ The P flag has been deprecated by &the_rfc;.</p> <p> Specifies AVPs for which module Mod provides encode/decode functions. The section contents consists of AVP names. -For each such name, <c>Mod:Name(encode|decode, Type, Data)</c> is +For each such name, <c>Mod:Name(encode|decode, Type, Data, Opts)</c> is expected to provide encode/decode for values of the AVP, where Name is the name of the AVP, Type is it's type as declared in the -<c>@avp_types</c> section of the dictionary and Data is the value to -encode/decode.</p> +<c>@avp_types</c> section of the dictionary, Data is the value to +encode/decode, and Opts is a term that is passed through encode/decode.</p> <p> Example:</p> @@ -328,8 +329,8 @@ Framed-IP-Address <item> <p> Like <c>@custom_types</c> but requires the specified module to export -<c>Mod:Type(encode|decode, Name, Data)</c> rather than -<c>Mod:Name(encode|decode, Type, Data)</c>.</p> +<c>Mod:Type(encode|decode, Name, Data, Opts)</c> rather than +<c>Mod:Name(encode|decode, Type, Data, Opts)</c>.</p> <p> Example:</p> diff --git a/lib/diameter/include/diameter_gen.hrl b/lib/diameter/include/diameter_gen.hrl index 6531e9528c..fb6370fe54 100644 --- a/lib/diameter/include/diameter_gen.hrl +++ b/lib/diameter/include/diameter_gen.hrl @@ -20,716 +20,36 @@ %% %% This file contains code that's included by encode/decode modules -%% generated by diameter_codegen.erl. This code does most of the work, the -%% generated code being kept simple. +%% generated by diameter_codegen.erl. This code used to do most of the +%% work, but now passes it off to module diameter_gen. %% --define(THROW(T), throw({?MODULE, T})). +%% encode_avps/3 -%% Tag common to generated dictionaries. --define(TAG, diameter_gen). +encode_avps(Name, Vals, Opts) -> + diameter_gen:encode_avps(Name, Vals, Opts#{module => ?MODULE}). -%% Key to a value in the process dictionary that determines whether or -%% not an unrecognized AVP setting the M-bit should be regarded as an -%% error or not. See is_strict/0. This is only used to relax M-bit -%% interpretation inside Grouped AVPs not setting the M-bit. The -%% service_opt() strict_mbit can be used to disable the check -%% globally. --define(STRICT_KEY, strict). +%% decode_avps/2 -%% Key that says whether or not we should do a best-effort decode -%% within Failed-AVP. --define(FAILED_KEY, failed). +decode_avps(Name, Recs, Opts) -> + diameter_gen:decode_avps(Name, Recs, Opts#{module => ?MODULE}). --type parent_name() :: atom(). %% parent = Message or AVP --type parent_record() :: tuple(). %% --type avp_name() :: atom(). --type avp_record() :: tuple(). --type avp_values() :: [{avp_name(), term()}]. +%% avp/5 --type non_grouped_avp() :: #diameter_avp{}. --type grouped_avp() :: nonempty_improper_list(#diameter_avp{}, [avp()]). --type avp() :: non_grouped_avp() | grouped_avp(). +avp(T, Data, Name, Opts, Mod) -> + Mod:avp(T, Data, Name, Opts#{module := Mod}). -%% Use a (hopefully) unique key when manipulating the process -%% dictionary. +%% grouped_avp/4 -putr(K,V) -> - put({?TAG, K}, V). +grouped_avp(T, Name, Data, Opts) -> + diameter_gen:grouped_avp(T, Name, Data, Opts). -getr(K) -> - case get({?TAG, K}) of - undefined -> - V = erase({?MODULE, K}), %% written in old code - V == undefined orelse putr(K,V), - V; - V -> - V - end. +%% empty_group/2 -eraser(K) -> - erase({?TAG, K}). +empty_group(Name, Opts) -> + diameter_gen:empty_group(Name, Opts). -%% --------------------------------------------------------------------------- -%% # encode_avps/2 -%% --------------------------------------------------------------------------- +%% empty/2 --spec encode_avps(parent_name(), parent_record() | avp_values()) - -> binary() - | no_return(). - -encode_avps(Name, Vals) - when is_list(Vals) -> - encode_avps(Name, '#set-'(Vals, newrec(Name))); - -encode_avps(Name, Rec) -> - try - list_to_binary(encode(Name, Rec)) - catch - throw: {?MODULE, Reason} -> - diameter_lib:log({encode, error}, - ?MODULE, - ?LINE, - {Reason, Name, Rec}), - erlang:error(list_to_tuple(Reason ++ [Name])); - error: Reason -> - Stack = erlang:get_stacktrace(), - diameter_lib:log({encode, failure}, - ?MODULE, - ?LINE, - {Reason, Name, Rec, Stack}), - erlang:error({encode_failure, Reason, Name, Stack}) - end. - -%% encode/2 - -encode(Name, Rec) -> - lists:flatmap(fun(A) -> encode(Name, A, '#get-'(A, Rec)) end, - '#info-'(element(1, Rec), fields)). - -%% encode/3 - -encode(Name, AvpName, Values) -> - e(Name, AvpName, avp_arity(Name, AvpName), Values). - -%% e/4 - -e(_, AvpName, 1, undefined) -> - ?THROW([mandatory_avp_missing, AvpName]); - -e(Name, AvpName, 1, Value) -> - e(Name, AvpName, [Value]); - -e(_, _, {0,_}, []) -> - []; - -e(_, AvpName, _, T) - when not is_list(T) -> - ?THROW([repeated_avp_as_non_list, AvpName, T]); - -e(_, AvpName, {Min, _}, L) - when length(L) < Min -> - ?THROW([repeated_avp_insufficient_arity, AvpName, Min, L]); - -e(_, AvpName, {_, Max}, L) - when Max < length(L) -> - ?THROW([repeated_avp_excessive_arity, AvpName, Max, L]); - -e(Name, AvpName, _, Values) -> - e(Name, AvpName, Values). - -%% e/3 - -e(Name, 'AVP', Values) -> - [pack_AVP(Name, A) || A <- Values]; - -e(_, AvpName, Values) -> - e(AvpName, Values). - -%% e/2 - -e(AvpName, Values) -> - H = avp_header(AvpName), - [diameter_codec:pack_avp(H, avp(encode, V, AvpName)) || V <- Values]. - -%% pack_AVP/2 - -%% No value: assume AVP data is already encoded. The normal case will -%% be when this is passed back from #diameter_packet.errors as a -%% consequence of a failed decode. Any AVP can be encoded this way -%% however, which side-steps any arity checks for known AVP's and -%% could potentially encode something unfortunate. -pack_AVP(_, #diameter_avp{value = undefined} = A) -> - diameter_codec:pack_avp(A); - -%% Missing name for value encode. -pack_AVP(_, #diameter_avp{name = N, value = V}) - when N == undefined; - N == 'AVP' -> - ?THROW([value_with_nameless_avp, N, V]); - -%% Or not. Ensure that 'AVP' is the appropriate field. Note that if we -%% don't know this AVP at all then the encode will fail. -pack_AVP(Name, #diameter_avp{name = AvpName, - value = Data}) -> - 0 == avp_arity(Name, AvpName) - orelse ?THROW([known_avp_as_AVP, Name, AvpName, Data]), - e(AvpName, [Data]). - -%% --------------------------------------------------------------------------- -%% # decode_avps/2 -%% --------------------------------------------------------------------------- - --spec decode_avps(parent_name(), [#diameter_avp{}]) - -> {parent_record(), [avp()], Failed} - when Failed :: [{5000..5999, #diameter_avp{}}]. - -decode_avps(Name, Recs) -> - {Avps, {Rec, Failed}} - = lists:foldl(fun(T,A) -> decode(Name, T, A) end, - {[], {newrec(Name), []}}, - Recs), - {Rec, Avps, Failed ++ missing(Rec, Name, Failed)}. -%% Append 5005 errors so that errors are reported in the order -%% encountered. Failed-AVP should typically contain the first -%% encountered error accordg to the RFC. - -newrec(Name) -> - '#new-'(name2rec(Name)). - -%% 3588: -%% -%% DIAMETER_MISSING_AVP 5005 -%% The request did not contain an AVP that is required by the Command -%% Code definition. If this value is sent in the Result-Code AVP, a -%% Failed-AVP AVP SHOULD be included in the message. The Failed-AVP -%% AVP MUST contain an example of the missing AVP complete with the -%% Vendor-Id if applicable. The value field of the missing AVP -%% should be of correct minimum length and contain zeros. - -missing(Rec, Name, Failed) -> - Avps = lists:foldl(fun({_, #diameter_avp{code = C, vendor_id = V}}, A) -> - sets:add_element({C,V}, A) - end, - sets:new(), - Failed), - [{5005, A} || F <- '#info-'(element(1, Rec), fields), - not has_arity(avp_arity(Name, F), '#get-'(F, Rec)), - #diameter_avp{code = C, vendor_id = V} - = A <- [empty_avp(F)], - not sets:is_element({C,V}, Avps)]. - -%% Maximum arities have already been checked in building the record. - -has_arity({Min, _}, L) -> - has_prefix(Min, L); -has_arity(N, V) -> - N /= 1 orelse V /= undefined. - -%% Compare a non-negative integer and the length of a list without -%% computing the length. -has_prefix(0, _) -> - true; -has_prefix(_, []) -> - false; -has_prefix(N, L) -> - has_prefix(N-1, tl(L)). - -%% empty_avp/1 - -empty_avp(Name) -> - {Code, Flags, VId} = avp_header(Name), - {Name, Type} = avp_name(Code, VId), - #diameter_avp{name = Name, - code = Code, - vendor_id = VId, - is_mandatory = 0 /= (Flags band 2#01000000), - need_encryption = 0 /= (Flags band 2#00100000), - data = empty_value(Name), - type = Type}. - -%% 3588, ch 7: -%% -%% The Result-Code AVP describes the error that the Diameter node -%% encountered in its processing. In case there are multiple errors, -%% the Diameter node MUST report only the first error it encountered -%% (detected possibly in some implementation dependent order). The -%% specific errors that can be described by this AVP are described in -%% the following section. - -%% decode/3 - -decode(Name, #diameter_avp{code = Code, vendor_id = Vid} = Avp, Acc) -> - decode(Name, avp_name(Code, Vid), Avp, Acc). - -%% decode/4 - -%% AVP is defined in the dictionary ... -decode(Name, {AvpName, Type}, Avp, Acc) -> - d(Name, Avp#diameter_avp{name = AvpName, type = Type}, Acc); - -%% ... or not. -decode(Name, 'AVP', Avp, Acc) -> - decode_AVP(Name, Avp, Acc). - -%% 6733, 4.4: -%% -%% Receivers of a Grouped AVP that does not have the 'M' (mandatory) -%% bit set and one or more of the encapsulated AVPs within the group -%% has the 'M' (mandatory) bit set MAY simply be ignored if the -%% Grouped AVP itself is unrecognized. The rule applies even if the -%% encapsulated AVP with its 'M' (mandatory) bit set is further -%% encapsulated within other sub-groups, i.e., other Grouped AVPs -%% embedded within the Grouped AVP. -%% -%% The first sentence is slightly mangled, but take it to mean this: -%% -%% An unrecognized AVP of type Grouped that does not set the 'M' bit -%% MAY be ignored even if one of its encapsulated AVPs sets the 'M' -%% bit. -%% -%% The text above is a change from RFC 3588, which instead says this: -%% -%% Further, if any of the AVPs encapsulated within a Grouped AVP has -%% the 'M' (mandatory) bit set, the Grouped AVP itself MUST also -%% include the 'M' bit set. -%% -%% Both of these texts have problems. If the AVP is unknown then its -%% type is unknown since the type isn't sent over the wire, so the -%% 6733 text becomes a non-statement: don't know that the AVP not -%% setting the M-bit is of type Grouped, therefore can't know that its -%% data consists of encapsulated AVPs, therefore can't but ignore that -%% one of these might set the M-bit. It should be no worse if we know -%% the AVP to have type Grouped. -%% -%% Similarly, for the 3588 text: if we receive an AVP that doesn't set -%% the M-bit and don't know that the AVP has type Grouped then we -%% can't realize that its data contains an AVP that sets the M-bit, so -%% can't regard the AVP as erroneous on this account. Again, it should -%% be no worse if the type is known to be Grouped, but in this case -%% the RFC forces us to regard the AVP as erroneous. This is -%% inconsistent, and the 3588 text has never been enforced. -%% -%% So, if an AVP doesn't set the M-bit then we're free to ignore it, -%% regardless of the AVP's type. If we know the type to be Grouped -%% then we must ignore the M-bit on an encapsulated AVP. That means -%% packing such an encapsulated AVP into an 'AVP' field if need be, -%% not regarding the lack of a specific field as an error as is -%% otherwise the case. (The lack of an AVP-specific field being how we -%% defined the RFC's "unrecognized", which is slightly stronger than -%% "not defined".) - -%% d/3 - -d(Name, Avp, Acc) -> - #diameter_avp{name = AvpName, - data = Data, - type = Type, - is_mandatory = M} - = Avp, - - %% Use the process dictionary is to keep track of whether or not - %% to ignore an M-bit on an encapsulated AVP. Not ideal, but the - %% alternative requires widespread changes to be able to pass the - %% value around through the entire decode. The solution here is - %% simple in comparison, both to implement and to understand. - - Strict = relax(Type, M), - - %% Use the process dictionary again to keep track of whether we're - %% decoding within Failed-AVP and should ignore decode errors - %% altogether. - - Failed = relax(Name), %% Not AvpName or else a failed Failed-AVP - %% decode is packed into 'AVP'. - Mod = dict(Failed), %% Dictionary to decode in. - - %% On decode, a Grouped AVP is represented as a #diameter_avp{} - %% list with AVP as head and component AVPs as tail. On encode, - %% data can be a list of component AVPs. - - try Mod:avp(decode, Data, AvpName) of - V -> - {Avps, T} = Acc, - {H, A} = ungroup(V, Avp), - {[H | Avps], pack_avp(Name, A, T)} - catch - throw: {?TAG, {grouped, Error, ComponentAvps}} -> - g(is_failed(), Error, Name, trim(Avp), Acc, ComponentAvps); - error: Reason -> - d(is_failed(), Reason, Name, trim(Avp), Acc) - after - reset(?STRICT_KEY, Strict), - reset(?FAILED_KEY, Failed) - end. - -%% trim/1 -%% -%% Remove any extra bit that was added in diameter_codec to induce a -%% 5014 error. - -trim(#diameter_avp{data = <<0:1, Bin/binary>>} = Avp) -> - Avp#diameter_avp{data = Bin}; - -trim(Avps) - when is_list(Avps) -> - lists:map(fun trim/1, Avps); - -trim(Avp) -> - Avp. - -%% dict/1 -%% -%% Retrieve the dictionary for the best-effort decode of Failed-AVP, -%% as put by diameter_codec:decode/2. See that function for the -%% explanation. - -dict(true) -> - case get({diameter_codec, dictionary}) of - undefined -> - ?MODULE; - Mod -> - Mod - end; - -dict(_) -> - ?MODULE. - -%% g/5 - -%% Ignore decode errors within Failed-AVP (best-effort) ... -g(true, [_Error | Rec], Name, Avp, Acc, _ComponentAvps) -> - decode_AVP(Name, Avp#diameter_avp{value = Rec}, Acc); -g(true, _Error, Name, Avp, Acc, _ComponentAvps) -> - decode_AVP(Name, Avp, Acc); - -%% ... or not. -g(false, [Error | _Rec], _Name, Avp, Acc, ComponentAvps) -> - g(Error, Avp, Acc, ComponentAvps); -g(false, Error, _Name, Avp, Acc, ComponentAvps) -> - g(Error, Avp, Acc, ComponentAvps). - -%% g/4 - -g({RC, ErrorData}, Avp, Acc, ComponentAvps) -> - {Avps, {Rec, Errors}} = Acc, - E = Avp#diameter_avp{data = [ErrorData]}, - {[[Avp | trim(ComponentAvps)] | Avps], {Rec, [{RC, E} | Errors]}}. - -%% d/5 - -%% Ignore a decode error within Failed-AVP ... -d(true, _, Name, Avp, Acc) -> - decode_AVP(Name, Avp, Acc); - -%% ... or not. Failures here won't be visible since they're a "normal" -%% occurrence if the peer sends a faulty AVP that we need to respond -%% sensibly to. Log the occurrence for traceability, but the peer will -%% also receive info in the resulting answer message. -d(false, Reason, Name, Avp, {Avps, Acc}) -> - Stack = diameter_lib:get_stacktrace(), - diameter_lib:log(decode_error, - ?MODULE, - ?LINE, - {Name, Avp#diameter_avp.name, Stack}), - {Rec, Failed} = Acc, - {[Avp|Avps], {Rec, [rc(Reason, Avp) | Failed]}}. - -%% relax/2 - -%% Set false in the process dictionary as soon as we see a Grouped AVP -%% that doesn't set the M-bit, so that is_strict() can say whether or -%% not to ignore the M-bit on an encapsulated AVP. -relax('Grouped', M) -> - case getr(?STRICT_KEY) of - undefined when not M -> - putr(?STRICT_KEY, M); - _ -> - false - end; -relax(_, _) -> - false. - -is_strict() -> - diameter_codec:getopt(strict_mbit) - andalso false /= getr(?STRICT_KEY). - -%% relax/1 -%% -%% Set true in the process dictionary as soon as we see Failed-AVP. -%% Matching on 'Failed-AVP' assumes that this is the RFC AVP. -%% Strictly, this doesn't need to be the case. - -relax('Failed-AVP') -> - putr(?FAILED_KEY, true); - -relax(_) -> - is_failed(). - -%% is_failed/0 -%% -%% Is the AVP currently being decoded nested within Failed-AVP? Note -%% that this is only true when Failed-AVP is the parent. In -%% particular, it's not true when Failed-AVP itself is being decoded -%% (unless nested). - -is_failed() -> - true == getr(?FAILED_KEY). - -%% is_failed/1 - -is_failed(Name) -> - 'Failed-AVP' == Name orelse is_failed(). - -%% reset/2 - -reset(Key, undefined) -> - eraser(Key); -reset(_, _) -> - ok. - -%% decode_AVP/3 -%% -%% Don't know this AVP: see if it can be packed in an 'AVP' field -%% undecoded. Note that the type field is 'undefined' in this case. - -decode_AVP(Name, Avp, {Avps, Acc}) -> - {[trim(Avp) | Avps], pack_AVP(Name, Avp, Acc)}. - -%% rc/1 - -%% diameter_types will raise an error of this form to communicate -%% DIAMETER_INVALID_AVP_LENGTH (5014). A module specified to a -%% @custom_types tag in a dictionary file can also raise an error of -%% this form. -rc({'DIAMETER', 5014 = RC, _}, #diameter_avp{name = AvpName} = Avp) -> - {RC, Avp#diameter_avp{data = empty_value(AvpName)}}; - -%% 3588: -%% -%% DIAMETER_INVALID_AVP_VALUE 5004 -%% The request contained an AVP with an invalid value in its data -%% portion. A Diameter message indicating this error MUST include -%% the offending AVPs within a Failed-AVP AVP. -rc(_, Avp) -> - {5004, Avp}. - -%% ungroup/2 - --spec ungroup(term(), #diameter_avp{}) - -> {avp(), #diameter_avp{}}. - -%% The decoded value in the Grouped case is as returned by grouped_avp/3: -%% a record and a list of component AVP's. -ungroup(V, #diameter_avp{type = 'Grouped'} = Avp) -> - {Rec, As} = V, - A = Avp#diameter_avp{value = Rec}, - {[A|As], A}; - -%% Otherwise it's just a plain value. -ungroup(V, #diameter_avp{} = Avp) -> - A = Avp#diameter_avp{value = V}, - {A, A}. - -%% pack_avp/3 - -pack_avp(Name, #diameter_avp{name = AvpName} = Avp, Acc) -> - pack_avp(Name, avp_arity(Name, AvpName), Avp, Acc). - -%% pack_avp/4 - -pack_avp(Name, 0, Avp, Acc) -> - pack_AVP(Name, Avp, Acc); - -pack_avp(_, Arity, Avp, Acc) -> - pack(Arity, Avp#diameter_avp.name, Avp, Acc). - -%% pack_AVP/3 - -%% Length failure was induced because of a header/payload length -%% mismatch. The AVP Length is reset to match the received data if -%% this AVP is encoded in an answer message, since the length is -%% computed. -%% -%% Data is a truncated header if command_code = undefined, otherwise -%% payload bytes. The former is padded to the length of a header if -%% the AVP reaches an outgoing encode in diameter_codec. -%% -%% RFC 6733 says that an AVP returned with 5014 can contain a minimal -%% payload for the AVP's type, but in this case we don't know the -%% type. - -pack_AVP(_, #diameter_avp{data = <<0:1, Data/binary>>} = Avp, Acc) -> - {Rec, Failed} = Acc, - {Rec, [{5014, Avp#diameter_avp{data = Data}} | Failed]}; - -pack_AVP(Name, #diameter_avp{is_mandatory = M, name = AvpName} = Avp, Acc) -> - case pack_arity(Name, AvpName, M) of - 0 -> - {Rec, Failed} = Acc, - {Rec, [{if M -> 5001; true -> 5008 end, Avp} | Failed]}; - Arity -> - pack(Arity, 'AVP', Avp, Acc) - end. - -%% Give Failed-AVP special treatment since (1) it'll contain any -%% unrecognized mandatory AVP's and (2) the RFC 3588 grammar failed to -%% allow for Failed-AVP in an answer-message. - -pack_arity(Name, AvpName, M) -> - - %% Not testing just Name /= 'Failed-AVP' means we're changing the - %% packing of AVPs nested within Failed-AVP, but the point of - %% ignoring errors within Failed-AVP is to decode as much as - %% possible, and failing because a mandatory AVP couldn't be - %% packed into a dedicated field defeats that point. Note - %% is_failed/1 since is_failed/0 will return false when packing - %% 'AVP' within Failed-AVP. - - pack_arity(is_failed(Name) - orelse {Name, AvpName} == {'answer-message', 'Failed-AVP'} - orelse not M - orelse not is_strict(), - Name). - -pack_arity(true, Name) -> - avp_arity(Name, 'AVP'); - -pack_arity(false, _) -> - 0. - -%% 3588: -%% -%% DIAMETER_AVP_UNSUPPORTED 5001 -%% The peer received a message that contained an AVP that is not -%% recognized or supported and was marked with the Mandatory bit. A -%% Diameter message with this error MUST contain one or more Failed- -%% AVP AVP containing the AVPs that caused the failure. -%% -%% DIAMETER_AVP_NOT_ALLOWED 5008 -%% A message was received with an AVP that MUST NOT be present. The -%% Failed-AVP AVP MUST be included and contain a copy of the -%% offending AVP. - -%% pack/4 - -pack(Arity, FieldName, Avp, {Rec, _} = Acc) -> - pack('#get-'(FieldName, Rec), Arity, FieldName, Avp, Acc). - -%% pack/5 - -pack(undefined, 1, FieldName, Avp, Acc) -> - p(FieldName, fun(V) -> V end, Avp, Acc); - -%% 3588: -%% -%% DIAMETER_AVP_OCCURS_TOO_MANY_TIMES 5009 -%% A message was received that included an AVP that appeared more -%% often than permitted in the message definition. The Failed-AVP -%% AVP MUST be included and contain a copy of the first instance of -%% the offending AVP that exceeded the maximum number of occurrences -%% - -pack(_, 1, _, Avp, {Rec, Failed}) -> - {Rec, [{5009, Avp} | Failed]}; -pack(L, {_, Max}, FieldName, Avp, Acc) -> - case '*' /= Max andalso has_prefix(Max, L) of - true -> - {Rec, Failed} = Acc, - {Rec, [{5009, Avp} | Failed]}; - false -> - p(FieldName, fun(V) -> [V|L] end, Avp, Acc) - end. - -%% p/4 - -p(F, Fun, Avp, {Rec, Failed}) -> - {'#set-'({F, Fun(value(F, Avp))}, Rec), Failed}. - -value('AVP', Avp) -> - Avp; -value(_, Avp) -> - Avp#diameter_avp.value. - -%% --------------------------------------------------------------------------- -%% # grouped_avp/3 -%% --------------------------------------------------------------------------- - --spec grouped_avp(decode, avp_name(), bitstring()) - -> {avp_record(), [avp()]}; - (encode, avp_name(), avp_record() | avp_values()) - -> binary() - | no_return(). - -%% Length error induced by diameter_codec:collect_avps/1: the AVP -%% length in the header was too short (insufficient for the extracted -%% header) or too long (past the end of the message). An empty payload -%% is sufficient according to the RFC text for 5014. -grouped_avp(decode, _Name, <<0:1, _/binary>>) -> - throw({?TAG, {grouped, {5014, []}, []}}); - -grouped_avp(decode, Name, Data) -> - grouped_decode(Name, diameter_codec:collect_avps(Data)); - -grouped_avp(encode, Name, Data) -> - encode_avps(Name, Data). - -%% grouped_decode/2 -%% -%% Note that Grouped is the only AVP type that doesn't just return a -%% decoded value, also returning the list of component diameter_avp -%% records. - -%% Length error in trailing component AVP. -grouped_decode(_Name, {Error, Acc}) -> - {5014, Avp} = Error, - throw({?TAG, {grouped, Error, [Avp | Acc]}}); - -%% 7.5. Failed-AVP AVP - -%% In the case where the offending AVP is embedded within a Grouped AVP, -%% the Failed-AVP MAY contain the grouped AVP, which in turn contains -%% the single offending AVP. The same method MAY be employed if the -%% grouped AVP itself is embedded in yet another grouped AVP and so on. -%% In this case, the Failed-AVP MAY contain the grouped AVP hierarchy up -%% to the single offending AVP. This enables the recipient to detect -%% the location of the offending AVP when embedded in a group. - -%% An error in decoding a component AVP throws the first fauly -%% component, which the catch in d/3 wraps in the Grouped AVP in -%% question. A partially decoded record is only used when ignoring -%% errors in Failed-AVP. -grouped_decode(Name, ComponentAvps) -> - {Rec, Avps, Es} = decode_avps(Name, ComponentAvps), - [] == Es orelse throw({?TAG, {grouped, [{_,_} = hd(Es) | Rec], Avps}}), - {Rec, Avps}. - -%% --------------------------------------------------------------------------- -%% # empty_group/1 -%% --------------------------------------------------------------------------- - -empty_group(Name) -> - list_to_binary(empty_body(Name)). - -empty_body(Name) -> - [z(F, avp_arity(Name, F)) || F <- '#info-'(name2rec(Name), fields)]. - -z(Name, 1) -> - z(Name); -z(_, {0,_}) -> - []; -z(Name, {Min, _}) -> - lists:duplicate(Min, z(Name)). - -z('AVP') -> - <<0:64/integer>>; %% minimal header -z(Name) -> - Bin = diameter_codec:pack_avp(avp_header(Name), empty_value(Name)), - << <<0>> || <<_>> <= Bin >>. - -%% --------------------------------------------------------------------------- -%% # empty/1 -%% --------------------------------------------------------------------------- - -empty(AvpName) -> - avp(encode, zero, AvpName). +empty(Name, Opts) -> + diameter_gen:empty(Name, Opts). diff --git a/lib/diameter/src/base/diameter.erl b/lib/diameter/src/base/diameter.erl index 253f64133c..bd92e16fba 100644 --- a/lib/diameter/src/base/diameter.erl +++ b/lib/diameter/src/base/diameter.erl @@ -406,4 +406,5 @@ call(SvcName, App, Message) -> :: {extra, list()} | {filter, peer_filter()} | {timeout, 'Unsigned32'()} + | {peer, peer_ref()} | detach. diff --git a/lib/diameter/src/base/diameter_capx.erl b/lib/diameter/src/base/diameter_capx.erl index 07a678c617..62b05644b2 100644 --- a/lib/diameter/src/base/diameter_capx.erl +++ b/lib/diameter/src/base/diameter_capx.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2015. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -94,6 +94,9 @@ recv_CER(CER, Svc, Dict) -> recv_CEA(CEA, Svc, Dict) -> try_it([fun rCEA/3, CEA, Svc, Dict]). +-spec make_caps(#diameter_caps{}, [{atom(), term()}]) + -> tried(#diameter_caps{}). + make_caps(Caps, Opts) -> try_it([fun mk_caps/2, Caps, Opts]). @@ -110,31 +113,20 @@ try_it([Fun | Args]) -> %% mk_caps/2 mk_caps(Caps0, Opts) -> - {Caps, _} = lists:foldl(fun set_cap/2, - {Caps0, #diameter_caps{_ = false}}, - Opts), - Caps. - --define(SC(K,F), - set_cap({K, Val}, {Caps, #diameter_caps{F = false} = C}) -> - {Caps#diameter_caps{F = cap(K, copy(Val))}, - C#diameter_caps{F = true}}). - -?SC('Origin-Host', origin_host); -?SC('Origin-Realm', origin_realm); -?SC('Host-IP-Address', host_ip_address); -?SC('Vendor-Id', vendor_id); -?SC('Product-Name', product_name); -?SC('Origin-State-Id', origin_state_id); -?SC('Supported-Vendor-Id', supported_vendor_id); -?SC('Auth-Application-Id', auth_application_id); -?SC('Inband-Security-Id', inband_security_id); -?SC('Acct-Application-Id', acct_application_id); -?SC('Vendor-Specific-Application-Id', vendor_specific_application_id); -?SC('Firmware-Revision', firmware_revision); - -set_cap({Key, _}, _) -> - ?THROW({duplicate, Key}). + Fields = diameter_gen_base_rfc3588:'#info-'(diameter_base_CER, fields), + Defs = lists:zip(Fields, tl(tuple_to_list(Caps0))), + Unset = maps:from_list([{F, true} || F <- lists:droplast(Fields)]), %% no 'AVP' + {Caps, _} = lists:foldl(fun set_cap/2, {Defs, Unset}, Opts), + #diameter_caps{} = list_to_tuple([diameter_caps | [V || {_,V} <- Caps]]). + +set_cap({F,V}, {Caps, Unset}) -> + case Unset of + #{F := true} -> + {lists:keyreplace(F, 1, Caps, {F, cap(F, copy(V))}), + maps:remove(F, Unset)}; + _ -> + ?THROW({duplicate, F}) + end. cap(K, V) when K == 'Origin-Host'; @@ -349,7 +341,7 @@ cs(LS, RS) -> cea_from_cer(CER, Dict) -> RecName = Dict:msg2rec('CEA'), [_ | Values] = Dict:'#get-'(CER), - Dict:'#set-'(Values, Dict:'#new-'(RecName)). + Dict:'#new-'([RecName | Values]). %% rCEA/3 @@ -424,7 +416,48 @@ bcaps(N, Caps) -> %% common_applications/3 %% %% Identify the (local) applications to be supported on the connection -%% in question. +%% in question. The RFC says this: +%% +%% 2.4 Application Identifiers +%% +%% Relay and redirect agents MUST advertise the Relay Application ID, +%% while all other Diameter nodes MUST advertise locally supported +%% applications. +%% +%% Taken literally, every Diameter node should then advertise support +%% for the Diameter common messages application, with id 0, since no +%% node can perform capabilities exchange without it. Expecting this, +%% or regarding the support as implicit, renders the Result-Code 5010 +%% (DIAMETER_NO_COMMON_APPLICATION) meaningless however, since every +%% node would regard the common application as being in common with +%% the peer. In practice, nodes may or may not advertise support for +%% Diameter common messages. +%% +%% That only explicitly advertised applications should be considered +%% when computing the intersection with the peer is supported here: +%% +%% 5.3. Capabilities Exchange +%% +%% The receiver of the Capabilities-Exchange-Request (CER) MUST +%% determine common applications by computing the intersection of its +%% own set of supported Application Ids against all of the +%% Application-Id AVPs (Auth-Application-Id, Acct-Application-Id, and +%% Vendor-Specific-Application-Id) present in the CER. +%% +%% The same section also has the following about capabilities exchange +%% messages. +%% +%% The receiver only issues commands to its peers that have advertised +%% support for the Diameter application that defines the command. +%% +%% This statement is also difficult to interpret literally since it +%% would disallow D[WP]R and more when Diameter common messages isn't +%% advertised. In practice, diameter lets requests be sent as long as +%% there's a dictionary configured to support it, peer selection by +%% advertised application being possible to preempt by passing +%% candidate peers directly to diameter:call/4. The peer can always +%% answer 3001 (DIAMETER_COMMAND_UNSUPPORTED) or 3007 +%% (DIAMETER_APPLICATION_UNSUPPORTED) if this is objectionable. common_applications(LCaps, RCaps, #diameter_service{applications = Apps}) -> LA = app_union(LCaps), diff --git a/lib/diameter/src/base/diameter_codec.erl b/lib/diameter/src/base/diameter_codec.erl index 1ea5357924..82fa796e69 100644 --- a/lib/diameter/src/base/diameter_codec.erl +++ b/lib/diameter/src/base/diameter_codec.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2015. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -20,11 +20,8 @@ -module(diameter_codec). --export([encode/2, - decode/2, - decode/3, - setopts/1, - getopt/1, +-export([encode/2, encode/3, + decode/2, decode/3, decode/4, collect_avps/1, decode_header/1, sequence_numbers/1, @@ -33,13 +30,17 @@ msg_id/1]). %% Towards generated encoders (from diameter_gen.hrl). --export([pack_avp/1, +-export([pack_data/2, pack_avp/2]). -include_lib("diameter/include/diameter.hrl"). -include("diameter_internal.hrl"). --define(MASK(N,I), ((I) band (1 bsl (N)))). +-define(PAD(Len), ((4 - (Len rem 4)) rem 4)). +-define(BIT(B,I), (if B -> I; true -> 0 end)). +-define(BIT(B), ?BIT(B,1)). +-define(FLAGS(R,P,E,T), ?BIT(R):1, ?BIT(P):1, ?BIT(E):1, ?BIT(T):1, 0:4). +-define(FLAG(B,D), (if is_boolean(B) -> B; true -> 0 /= (D) end)). -type u32() :: 0..16#FFFFFFFF. -type u24() :: 0..16#FFFFFF. @@ -62,62 +63,29 @@ %% +-+-+-+-+-+-+-+-+-+-+-+-+- %%% --------------------------------------------------------------------------- -%%% # setopts/1 -%%% # getopt/1 +%%% # encode/2 %%% --------------------------------------------------------------------------- -%% These functions are a compromise in the same vein as the use of the -%% process dictionary in diameter_gen.hrl in generated codec modules. -%% Instead of rewriting the entire dictionary generation to pass -%% encode/decode options around, the calling process sets them by -%% calling setopts/1. At current, the only option is whether or not to -%% decode binaries as strings, which is used by diameter_types. - -setopts(Opts) - when is_list(Opts) -> - lists:foreach(fun setopt/1, Opts). - -%% The default string_decode true is for backwards compatibility. -setopt({K, false = B}) - when K == string_decode; - K == strict_mbit -> - setopt(K, B); - -%% Regard anything but the generated RFC 3588 dictionary as modern. -%% This affects the interpretation of defaults during the decode -%% of values of type DiameterURI, this having changed from RFC 3588. -%% (So much for backwards compatibility.) -setopt({common_dictionary, diameter_gen_base_rfc3588}) -> - setopt(rfc, 3588); - -setopt(_) -> - ok. - -setopt(Key, Value) -> - put({diameter, Key}, Value). - -getopt(Key) -> - case get({diameter, Key}) of - undefined when Key == string_decode; - Key == strict_mbit -> - true; - undefined when Key == rfc -> - 6733; - V -> - V - end. +%% The representative encode documented in diameter_codec(3). As of +%% the options that affect encode (eg. ordered_encode), it's no longer +%% *the* encode. + +encode(Mod, Msg) -> + encode(Mod, #{ordered_encode => true}, Msg). %%% --------------------------------------------------------------------------- -%%% # encode/2 +%%% # encode/3 %%% --------------------------------------------------------------------------- --spec encode(module(), Msg :: term()) +-spec encode(module(), + map(), + Msg :: term()) -> #diameter_packet{} | no_return(). -encode(Mod, #diameter_packet{} = Pkt) -> +encode(Mod, Opts, #diameter_packet{} = Pkt) -> try - e(Mod, Pkt) + enc(Mod, Opts, Pkt) catch exit: {Reason, Stack, #diameter_header{} = H} = T -> %% Exit with a header in the reason to let the caller @@ -130,91 +98,97 @@ encode(Mod, #diameter_packet{} = Pkt) -> exit({?MODULE, encode, T}) end; -encode(Mod, Msg) -> +encode(Mod, Opts, Msg) -> Seq = diameter_session:sequence(), Hdr = #diameter_header{version = ?DIAMETER_VERSION, end_to_end_id = Seq, hop_by_hop_id = Seq}, - encode(Mod, #diameter_packet{header = Hdr, - msg = Msg}). + encode(Mod, Opts, #diameter_packet{header = Hdr, + msg = Msg}). + +%% enc/3 -e(_, #diameter_packet{msg = [#diameter_header{} = Hdr | As]} = Pkt) -> - try encode_avps(reorder(As)) of +enc(_, Opts, #diameter_packet{msg = [#diameter_header{} = Hdr | As]} + = Pkt) -> + try encode_avps(reorder(As), Opts) of Avps -> - Length = size(Avps) + 20, + Bin = list_to_binary(Avps), + Len = 20 + size(Bin), #diameter_header{version = Vsn, + is_request = R, + is_proxiable = P, + is_error = E, + is_retransmitted = T, cmd_code = Code, application_id = Aid, hop_by_hop_id = Hid, end_to_end_id = Eid} = Hdr, - Flags = make_flags(0, Hdr), - Pkt#diameter_packet{header = Hdr, - bin = <<Vsn:8, Length:24, - Flags:8, Code:24, + bin = <<Vsn:8, Len:24, + ?FLAGS(R,P,E,T), Code:24, Aid:32, Hid:32, Eid:32, - Avps/binary>>} + Bin/binary>>} catch error: Reason -> exit({Reason, diameter_lib:get_stacktrace(), Hdr}) end; -e(Mod, #diameter_packet{header = Hdr0, msg = Msg} = Pkt) -> +enc(Mod, Opts, #diameter_packet{header = Hdr0, msg = Msg} = Pkt) -> + MsgName = rec2msg(Mod, Msg), + {Code, Flags, Aid} = msg_header(Mod, MsgName, Hdr0), + #diameter_header{version = Vsn, + is_request = R, + is_proxiable = P, + is_error = E, + is_retransmitted = T, hop_by_hop_id = Hid, end_to_end_id = Eid} = Hdr0, - MsgName = rec2msg(Mod, Msg), - {Code, Flags0, Aid} = msg_header(Mod, MsgName, Hdr0), - Flags = make_flags(Flags0, Hdr0), - Hdr = Hdr0#diameter_header{cmd_code = Code, - application_id = Aid, - is_request = 0 /= ?MASK(7, Flags), - is_proxiable = 0 /= ?MASK(6, Flags), - is_error = 0 /= ?MASK(5, Flags), - is_retransmitted = 0 /= ?MASK(4, Flags)}, + RB = ?FLAG(R, Flags band 2#10000000), + PB = ?FLAG(P, Flags band 2#01000000), + EB = ?FLAG(E, Flags band 2#00100000), + TB = ?FLAG(T, Flags band 2#00010000), + Values = values(Msg), - try encode_avps(Mod, MsgName, Values) of + try encode_avps(Mod, MsgName, Values, Opts) of Avps -> - Length = size(Avps) + 20, - Pkt#diameter_packet{header = Hdr#diameter_header{length = Length}, - bin = <<Vsn:8, Length:24, - Flags:8, Code:24, + Bin = list_to_binary(Avps), + Len = 20 + size(Bin), + + Hdr = Hdr0#diameter_header{length = Len, + cmd_code = Code, + application_id = Aid, + is_request = RB, + is_proxiable = PB, + is_error = EB, + is_retransmitted = TB}, + + Pkt#diameter_packet{header = Hdr, + bin = <<Vsn:8, Len:24, + ?FLAGS(RB, PB, EB, TB), Code:24, Aid:32, Hid:32, Eid:32, - Avps/binary>>} + Bin/binary>>} catch error: Reason -> + Hdr = Hdr0#diameter_header{cmd_code = Code, + application_id = Aid, + is_request = RB, + is_proxiable = PB, + is_error = EB, + is_retransmitted = TB}, exit({Reason, diameter_lib:get_stacktrace(), Hdr}) end. -%% make_flags/2 - -make_flags(Flags0, #diameter_header{is_request = R, - is_proxiable = P, - is_error = E, - is_retransmitted = T}) -> - {Flags, 3} = lists:foldl(fun(B,{F,N}) -> {mf(B,F,N), N-1} end, - {Flags0, 7}, - [R,P,E,T]), - Flags. - -mf(undefined, F, _) -> - F; -mf(B, F, N) -> %% reset the affected bit - (F bxor (F band (1 bsl N))) bor bit(B, N). - -bit(true, N) -> 1 bsl N; -bit(false, _) -> 0. - %% values/1 values([H|T]) @@ -223,7 +197,7 @@ values([H|T]) values(Avps) -> Avps. -%% encode_avps/3 +%% encode_avps/4 %% Specifying values as a #diameter_avp list bypasses arity and other %% checks: the values are expected to be already encoded and the AVP's @@ -231,12 +205,12 @@ values(Avps) -> %% these have to be able to resend whatever comes. %% Message as a list of #diameter_avp{} ... -encode_avps(_, _, [#diameter_avp{} | _] = Avps) -> - encode_avps(reorder(Avps)); +encode_avps(_, _, [#diameter_avp{} | _] = Avps, Opts) -> + encode_avps(reorder(Avps), Opts); %% ... or as a tuple list or record. -encode_avps(Mod, MsgName, Values) -> - Mod:encode_avps(MsgName, Values). +encode_avps(Mod, MsgName, Values, Opts) -> + Mod:encode_avps(MsgName, Values, Opts). %% reorder/1 %% @@ -277,10 +251,10 @@ reorder([H | T], Acc) -> reorder([], _) -> false. -%% encode_avps/1 +%% encode_avps/2 -encode_avps(Avps) -> - list_to_binary(lists:map(fun pack_avp/1, Avps)). +encode_avps(Avps, Opts) -> + [pack_avp(A, Opts) || A <- Avps]. %% msg_header/3 @@ -308,38 +282,50 @@ rec2msg(Mod, Rec) -> %%% # decode/2 %%% --------------------------------------------------------------------------- +%% The representative default decode documented in diameter_codec(3). +%% As of the options that affect decode (eg. string_decode), it's no +%% longer *the* decode. + +decode(Mod, Pkt) -> + Opts = #{string_decode => true, + strict_mbit => true, + rfc => 6733}, + decode(Mod, Opts, Pkt). + +%%% --------------------------------------------------------------------------- +%%% # decode/3 +%%% --------------------------------------------------------------------------- + %% Unsuccessfully decoded AVPs will be placed in #diameter_packet.errors. --spec decode(module() | {module(), module()}, #diameter_packet{} | binary()) +-spec decode(module() | {module(), module()}, + map(), + #diameter_packet{} | binary()) -> #diameter_packet{}. %% An Answer setting the E-bit. The application dictionary is needed -%% for the best-effort decode of Failed-AVP, and the best way to make -%% this available to the AVP decode in diameter_gen.hrl, without -%% having to rewrite the entire codec generation, is to place it in -%% the process dictionary. It's the code in diameter_gen.hrl (that's -%% included by every generated codec module) that looks for the entry. -%% Not ideal, but it solves the problem relatively simply. -decode({Mod, Mod}, Pkt) -> - decode(Mod, Pkt); -decode({Mod, AppMod}, Pkt) -> - Key = {?MODULE, dictionary}, - put(Key, AppMod), - try - decode(Mod, Pkt) - after - erase(Key) - end; +%% for the best-effort decode of Failed-AVP. +decode({Mod, AppMod}, Opts, Pkt) -> + decode(Mod, AppMod, Opts, Pkt); %% Or not: a request, or an answer not setting the E-bit. -decode(Mod, Pkt) -> - decode(Mod:id(), Mod, Pkt). +decode(Mod, Opts, Pkt) -> + decode(Mod, Mod, Opts, Pkt). + +%% decode/4 + +decode(Id, Mod, Opts, Pkt) + when is_integer(Id) -> + decode(Id, Mod, Mod, Opts, Pkt); -%% decode/3 +decode(Mod, AppMod, Opts, Pkt) -> + decode(Mod:id(), Mod, AppMod, Opts, Pkt). + +%% decode/5 %% Relay application: just extract the avp's without any decoding of %% their data since we don't know the application in question. -decode(?APP_ID_RELAY, _, #diameter_packet{} = Pkt) -> +decode(?APP_ID_RELAY, _, _, _, #diameter_packet{} = Pkt) -> case collect_avps(Pkt) of {E, As} -> Pkt#diameter_packet{avps = As, @@ -349,7 +335,7 @@ decode(?APP_ID_RELAY, _, #diameter_packet{} = Pkt) -> end; %% Otherwise decode using the dictionary. -decode(_, Mod, #diameter_packet{header = Hdr} = Pkt) -> +decode(_, Mod, AppMod, Opts, #diameter_packet{header = Hdr} = Pkt) -> #diameter_header{cmd_code = CmdCode, is_request = IsRequest, is_error = IsError} @@ -361,29 +347,33 @@ decode(_, Mod, #diameter_packet{header = Hdr} = Pkt) -> Mod:msg_name(CmdCode, IsRequest) end, - decode_avps(MsgName, Mod, Pkt, collect_avps(Pkt)); + decode_avps(MsgName, Mod, AppMod, Opts, Pkt, collect_avps(Pkt)); -decode(Id, Mod, Bin) +decode(Id, Mod, AppMod, Opts, Bin) when is_binary(Bin) -> - decode(Id, Mod, #diameter_packet{header = decode_header(Bin), bin = Bin}). + decode(Id, Mod, AppMod, Opts, #diameter_packet{header = decode_header(Bin), + bin = Bin}). -%% decode_avps/4 +%% decode_avps/6 -decode_avps(MsgName, Mod, Pkt, {E, Avps}) -> +decode_avps(MsgName, Mod, AppMod, Opts, Pkt, {E, Avps}) -> ?LOG(invalid_avp_length, Pkt#diameter_packet.header), #diameter_packet{errors = Failed} = P - = decode_avps(MsgName, Mod, Pkt, Avps), + = decode_avps(MsgName, Mod, AppMod, Opts, Pkt, Avps), P#diameter_packet{errors = [E | Failed]}; -decode_avps('', _, Pkt, Avps) -> %% unknown message ... +decode_avps('', _, _, _, Pkt, Avps) -> %% unknown message ... ?LOG(unknown_message, Pkt#diameter_packet.header), Pkt#diameter_packet{avps = lists:reverse(Avps), errors = [3001]}; %% DIAMETER_COMMAND_UNSUPPORTED %% msg = undefined identifies this case. -decode_avps(MsgName, Mod, Pkt, Avps) -> %% ... or not - {Rec, As, Errors} = Mod:decode_avps(MsgName, Avps), +decode_avps(MsgName, Mod, AppMod, Opts, Pkt, Avps) -> %% ... or not + {Rec, As, Errors} = Mod:decode_avps(MsgName, + Avps, + Opts#{dictionary => AppMod, + failed_avp => false}), ?LOGC([] /= Errors, decode_errors, Pkt#diameter_packet.header), Pkt#diameter_packet{msg = Rec, errors = Errors, @@ -399,14 +389,12 @@ decode_avps(MsgName, Mod, Pkt, Avps) -> %% ... or not decode_header(<<Version:8, MsgLength:24, - CmdFlags:1/binary, + R:1, P:1, E:1, T:1, _:4, CmdCode:24, ApplicationId:32, HopByHopId:32, EndToEndId:32, _/binary>>) -> - <<R:1, P:1, E:1, T:1, _:4>> - = CmdFlags, %% 3588 (ch 3) says that reserved bits MUST be set to 0 and ignored %% by the receiver. @@ -518,7 +506,7 @@ msg_id(#diameter_packet{header = #diameter_header{} = Hdr}) -> msg_id(#diameter_header{application_id = A, cmd_code = C, is_request = R}) -> - {A, C, if R -> 1; true -> 0 end}; + {A, C, ?BIT(R)}; msg_id(<<_:32, Rbit:1, _:7, CmdCode:24, ApplId:32, _/binary>>) -> {ApplId, CmdCode, Rbit}. @@ -537,24 +525,14 @@ msg_id(<<_:32, Rbit:1, _:7, CmdCode:24, ApplId:32, _/binary>>) -> when Avp :: #diameter_avp{}, Error :: {5014, #diameter_avp{}}. -collect_avps(#diameter_packet{bin = Bin}) -> - <<_:20/binary, Avps/binary>> = Bin, - collect_avps(Avps); +collect_avps(#diameter_packet{bin = <<_:20/binary, Avps/binary>>}) -> + collect_avps(Avps, 0, []); collect_avps(Bin) when is_binary(Bin) -> collect_avps(Bin, 0, []). -collect_avps(<<>>, _, Acc) -> - Acc; -collect_avps(Bin, N, Acc) -> - try split_avp(Bin) of - {Rest, AVP} -> - collect_avps(Rest, N+1, [AVP#diameter_avp{index = N} | Acc]) - catch - ?FAILURE(Error) -> - {Error, Acc} - end. +%% collect_avps/3 %% 0 1 2 3 %% 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 @@ -568,32 +546,65 @@ collect_avps(Bin, N, Acc) -> %% | Data ... %% +-+-+-+-+-+-+-+-+ -%% split_avp/1 +collect_avps(<<Code:32, V:1, M:1, P:1, _:5, Len:24, I:V/unit:32, Rest/binary>>, + N, + Acc) -> + collect_avps(Code, + if 1 == V -> I; 0 == V -> undefined end, + 1 == M, + 1 == P, + Len - 8 - V*4, %% Might be negative, which ensures + ?PAD(Len), %% failure of the Data match below. + Rest, + N, + Acc); -split_avp(Bin) -> - {Code, V, M, P, Len, HdrLen} = split_head(Bin), - - <<_:HdrLen/binary, Rest/binary>> = Bin, - {Data, B} = split_data(Rest, Len - HdrLen), - - {B, #diameter_avp{code = Code, - vendor_id = V, - is_mandatory = 1 == M, - need_encryption = 1 == P, - data = Data}}. - -%% split_head/1 - -split_head(<<Code:32, 1:1, M:1, P:1, _:5, Len:24, V:32, _/binary>>) -> - {Code, V, M, P, Len, 12}; +collect_avps(<<>>, _, Acc) -> + Acc; -split_head(<<Code:32, 0:1, M:1, P:1, _:5, Len:24, _/binary>>) -> - {Code, undefined, M, P, Len, 8}; +%% Header is truncated. pack_avp/1 will pad this at encode if sent in +%% a Failed-AVP. +collect_avps(Bin, _, Acc) -> + {{5014, #diameter_avp{data = Bin}}, Acc}. + +%% collect_avps/9 + +%% Duplicate the diameter_avp creation in each branch below to avoid +%% modifying the record, which profiling has shown to be a relatively +%% costly part of building the list. + +collect_avps(Code, VendorId, M, P, Len, Pad, Rest, N, Acc) -> + case Rest of + <<Data:Len/binary, _:Pad/binary, T/binary>> -> + Avp = #diameter_avp{code = Code, + vendor_id = VendorId, + is_mandatory = M, + need_encryption = P, + data = Data, + index = N}, + collect_avps(T, N+1, [Avp | Acc]); + _ -> + %% Length in header points past the end of the message, or + %% doesn't span the header. As stated in the 6733 text + %% above, it's sufficient to return a zero-filled minimal + %% payload if this is a request. Do this (in cases that we + %% know the type) by inducing a decode failure and letting + %% the dictionary's decode (in diameter_gen) deal with it. + %% + %% Note that the extra bit can only occur in the trailing + %% AVP of a message or Grouped AVP, since a faulty AVP + %% Length is otherwise indistinguishable from a correct + %% one here, as we don't know the types of the AVPs being + %% extracted. + Avp = #diameter_avp{code = Code, + vendor_id = VendorId, + is_mandatory = M, + need_encryption = P, + data = {5014, Rest}, + index = N}, + [Avp | Acc] + end. -%% Header is truncated. -split_head(Bin) -> - ?THROW({5014, #diameter_avp{data = Bin}}). -%% Note that pack_avp/1 will pad this at encode if sent in a Failed-AVP. %% 3588: %% @@ -626,35 +637,8 @@ split_head(Bin) -> %% the minimum value mean we might not know the identity of the AVP and %% (2) the last sentence covers this case. -%% split_data/3 - -split_data(Bin, Len) -> - Pad = (4 - (Len rem 4)) rem 4, - - %% Len might be negative here, but that ensures the failure of the - %% binary match. - - case Bin of - <<Data:Len/binary, _:Pad/binary, Rest/binary>> -> - {Data, Rest}; - _ -> - %% Header length points past the end of the message, or - %% doesn't span the header. As stated in the 6733 text - %% above, it's sufficient to return a zero-filled minimal - %% payload if this is a request. Do this (in cases that we - %% know the type) by inducing a decode failure and letting - %% the dictionary's decode (in diameter_gen) deal with it. - %% - %% Note that the extra bit can only occur in the trailing - %% AVP of a message or Grouped AVP, since a faulty AVP - %% Length is otherwise indistinguishable from a correct - %% one here, since we don't know the types of the AVPs - %% being extracted. - {<<0:1, Bin/binary>>, <<>>} - end. - %%% --------------------------------------------------------------------------- -%%% # pack_avp/1 +%%% # pack_avp/2 %%% --------------------------------------------------------------------------- %% The normal case here is data as an #diameter_avp{} list or an @@ -664,104 +648,96 @@ split_data(Bin, Len) -> %% Decoded Grouped AVP with decoded components: ignore components %% since they're already encoded in the Grouped AVP. -pack_avp([#diameter_avp{} = Grouped | _Components]) -> - pack_avp(Grouped); +pack_avp([#diameter_avp{} = Grouped | _Components], Opts) -> + pack_avp(Grouped, Opts); %% Grouped AVP whose components need packing. It's intentional that %% this isn't equivalent to [Grouped | Components]: here the %% components need to be encoded before wrapping with the Grouped AVP, %% and the list is flat, nesting being accomplished in the data %% fields. -pack_avp(#diameter_avp{data = [#diameter_avp{} | _] = Components} = Grouped) -> - pack_avp(Grouped#diameter_avp{data = encode_avps(Components)}); +pack_avp(#diameter_avp{data = [#diameter_avp{} | _] = Components} + = Grouped, + Opts) -> + pack_data(Grouped, encode_avps(Components, Opts)); %% Data as a type/value tuple ... -pack_avp(#diameter_avp{data = {Type, Value}} = A) +pack_avp(#diameter_avp{data = {Type, Value}} = A, Opts) when is_atom(Type) -> - pack_avp(A#diameter_avp{data = diameter_types:Type(encode, Value)}); + pack_data(A, diameter_types:Type(encode, Value, Opts)); %% ... with a header in various forms ... -pack_avp(#diameter_avp{data = {{_,_,_} = T, {Type, Value}}}) -> - pack_avp(T, iolist_to_binary(diameter_types:Type(encode, Value))); +pack_avp(#diameter_avp{data = {T, {Type, Value}}}, Opts) -> + pack_data(T, diameter_types:Type(encode, Value, Opts)); -pack_avp(#diameter_avp{data = {{_,_,_} = T, Bin}}) - when is_binary(Bin) -> - pack_avp(T, Bin); +pack_avp(#diameter_avp{data = {T, Data}}, _) -> + pack_data(T, Data); -pack_avp(#diameter_avp{data = {Dict, Name, Value}} = A) -> - {Code, _Flags, Vid} = Hdr = Dict:avp_header(Name), - {Name, Type} = Dict:avp_name(Code, Vid), - pack_avp(A#diameter_avp{data = {Hdr, {Type, Value}}}); +pack_avp(#diameter_avp{data = {Dict, Name, Data}}, Opts) -> + pack_data(Dict:avp_header(Name), Dict:avp(encode, Data, Name, Opts)); %% ... with a truncated header ... -pack_avp(#diameter_avp{code = undefined, data = B}) +pack_avp(#diameter_avp{code = undefined, data = B}, _) when is_binary(B) -> %% Reset the AVP Length of an AVP Header resulting from a 5014 %% error. The RFC doesn't explicitly say to do this but the %% receiver can't correctly extract this and following AVP's %% without a correct length. On the downside, the header doesn't - %% reveal if the received header has been padded. - Pad = 8*header_length(B) - bit_size(B), - Len = size(<<H:5/binary, _:24, T/binary>> = <<B/binary, 0:Pad>>), - <<H/binary, Len:24, T/binary>>; - -%% ... when ignoring errors in Failed-AVP ... -%% ... during a relay encode ... -pack_avp(#diameter_avp{data = <<0:1, B/binary>>} = A) -> - pack_avp(A#diameter_avp{data = B}); - -%% ... or as an iolist. -pack_avp(#diameter_avp{code = Code, - vendor_id = V, - is_mandatory = M, - need_encryption = P, - data = Data}) -> - Flags = lists:foldl(fun flag_avp/2, 0, [{V /= undefined, 2#10000000}, - {M, 2#01000000}, - {P, 2#00100000}]), - pack_avp({Code, Flags, V}, iolist_to_binary(Data)). - -header_length(<<_:32, 1:1, _/bitstring>>) -> + %% reveal if the received header has been padded. Discard bytes + %% from the length header for this reason, to avoid creating a sub + %% binary for no useful reason. + Len = header_length(B), + Sz = min(5, size(B)), + <<B:Sz/binary, 0:(5-Sz)/unit:8, Len:24, 0:(Len-8)/unit:8>>; + +%% Ignoring errors in Failed-AVP or during a relay encode. +pack_avp(#diameter_avp{data = {5014, Data}} = A, _) -> + pack_data(A, Data); + +pack_avp(#diameter_avp{data = Data} = A, _) -> + pack_data(A, Data). + +header_length(<<_:32, 1:1, _/bits>>) -> 12; header_length(_) -> 8. -flag_avp({true, B}, F) -> - F bor B; -flag_avp({false, _}, F) -> - F. - %%% --------------------------------------------------------------------------- -%%% # pack_avp/2 +%%% # pack_data/2 %%% --------------------------------------------------------------------------- -pack_avp({Code, Flags, VendorId}, Bin) - when is_binary(Bin) -> - Sz = size(Bin), - pack_avp(Code, Flags, VendorId, Sz, pad(Sz rem 4, Bin)). - -pad(0, Bin) -> - Bin; -pad(N, Bin) -> - P = 8*(4-N), - <<Bin/binary, 0:P>>. -%% Note that padding is not included in the length field as mandated by -%% the RFC. - -%% pack_avp/5 +pack_data(#diameter_avp{code = Code, + vendor_id = V, + is_mandatory = M, + need_encryption = P}, + Data) -> + Flags = ?BIT(V /= undefined, 2#10000000) + bor ?BIT(M, 2#01000000) + bor ?BIT(P, 2#00100000), + pack(Code, Flags, V, Data); + +pack_data({Code, Flags, VendorId}, Data) -> + pack(Code, Flags, VendorId, Data). + +%% pack/4 + +pack(Code, Flags, VendorId, Data) -> + Sz = iolist_size(Data), + pack(Code, Flags, Sz, VendorId, Data, ?PAD(Sz)). +%% Padding is not included in the length field, as mandated by the RFC. + +%% pack/6 %% %% Prepend the vendor id as required. -pack_avp(Code, Flags, Vid, Sz, Bin) +pack(Code, Flags, Sz, _Vid, Data, Pad) when 0 == Flags band 2#10000000 -> - undefined = Vid, %% sanity check - pack_avp(Code, Flags, Sz, Bin); + pack(Code, Flags, Sz, 0, 0, Data, Pad); -pack_avp(Code, Flags, Vid, Sz, Bin) -> - pack_avp(Code, Flags, Sz+4, <<Vid:32, Bin/binary>>). +pack(Code, Flags, Sz, Vid, Data, Pad) -> + pack(Code, Flags, Sz+4, Vid, 1, Data, Pad). -%% pack_avp/4 +%% pack/7 -pack_avp(Code, Flags, Sz, Bin) -> - Length = Sz + 8, - <<Code:32, Flags:8, Length:24, Bin/binary>>. +pack(Code, Flags, Sz, VId, V, Data, Pad) -> + [<<Code:32, Flags:8, (8+Sz):24, VId:V/unit:32>>, Data, <<0:Pad/unit:8>>]. diff --git a/lib/diameter/src/base/diameter_config.erl b/lib/diameter/src/base/diameter_config.erl index e10804c931..34018ae6d3 100644 --- a/lib/diameter/src/base/diameter_config.erl +++ b/lib/diameter/src/base/diameter_config.erl @@ -277,7 +277,7 @@ start_link() -> start_link(T) -> proc_lib:start_link(?MODULE, init, [T], infinity, []). - + state() -> call(state). @@ -535,12 +535,12 @@ stop(SvcName) -> %% restrict applications so that that there's one while the service %% has many. -add(SvcName, Type, Opts) -> +add(SvcName, Type, Opts0) -> %% Ensure acceptable transport options. This won't catch all %% possible errors (faulty callbacks for example) but it catches %% many. diameter_service:merge_service/2 depends on usable %% capabilities for example. - ok = transport_opts(Opts), + Opts = transport_opts(Opts0), Ref = make_ref(), true = diameter_reg:add_new(?TRANSPORT_KEY(Ref)), @@ -560,7 +560,17 @@ add(SvcName, Type, Opts) -> end. transport_opts(Opts) -> - lists:foreach(fun(T) -> opt(T) orelse ?THROW({invalid, T}) end, Opts). + lists:map(fun topt/1, Opts). + +topt(T) -> + case opt(T) of + {value, X} -> + X; + true -> + T; + false -> + ?THROW({invalid, T}) + end. opt({transport_module, M}) -> is_atom(M); @@ -600,8 +610,15 @@ opt({watchdog_timer, Tmo}) -> opt({watchdog_config, L}) -> is_list(L) andalso lists:all(fun wdopt/1, L); -opt({spawn_opt, Opts}) -> - is_list(Opts); +opt({spawn_opt, {M,F,A}}) + when is_atom(M), is_atom(F), is_list(A) -> + true; +opt({spawn_opt = K, Opts}) -> + if is_list(Opts) -> + {value, {K, spawn_opts(Opts)}}; + true -> + false + end; opt({pool_size, N}) -> is_integer(N) andalso 0 < N; @@ -676,7 +693,7 @@ stop_transport(SvcName, Refs) -> make_config(SvcName, Opts) -> AppOpts = [T || {application, _} = T <- Opts], - Apps = init_apps(AppOpts), + Apps = [init_app(T) || T <- AppOpts], [] == Apps andalso ?THROW(no_apps), @@ -725,9 +742,13 @@ opt(incoming_maxlen, N) when 0 =< N, N < 1 bsl 24 -> N; +opt(spawn_opt, {M,F,A} = T) + when is_atom(M), is_atom(F), is_list(A) -> + T; + opt(spawn_opt, L) when is_list(L) -> - L; + spawn_opts(L); opt(K, false = B) when K == share_peers; @@ -789,6 +810,9 @@ opt(sequence = K, F) -> opt(K, _) -> ?THROW({value, K}). +spawn_opts(L) -> + [T || T <- L, T /= link, T /= monitor]. + sequence({H,N} = T) when 0 =< N, N =< 32, 0 =< H, 0 == H bsr (32-N) -> T; @@ -822,10 +846,7 @@ encode_CER(Opts) -> ?THROW(Reason) end. -init_apps(Opts) -> - lists:foldl(fun app_acc/2, [], lists:reverse(Opts)). - -app_acc({application, Opts} = T, Acc) -> +init_app({application, Opts} = T) -> is_list(Opts) orelse ?THROW(T), [Dict, Mod] = get_opt([dictionary, module], Opts), @@ -834,15 +855,14 @@ app_acc({application, Opts} = T, Acc) -> M = get_opt(call_mutates_state, Opts, false, [true]), A = get_opt(answer_errors, Opts, discard, [callback, report]), P = get_opt(request_errors, Opts, answer_3xxx, [answer, callback]), - [#diameter_app{alias = Alias, - dictionary = Dict, - id = cb(Dict, id), - module = init_mod(Mod), - init_state = ModS, - mutable = M, - options = [{answer_errors, A}, - {request_errors, P}]} - | Acc]. + #diameter_app{alias = Alias, + dictionary = Dict, + id = cb(Dict, id), + module = init_mod(Mod), + init_state = ModS, + mutable = M, + options = [{answer_errors, A}, + {request_errors, P}]}. init_mod(#diameter_callback{} = R) -> init_mod([diameter_callback, R]); diff --git a/lib/diameter/src/base/diameter_dict.erl b/lib/diameter/src/base/diameter_dict.erl deleted file mode 100644 index 7db294a1b1..0000000000 --- a/lib/diameter/src/base/diameter_dict.erl +++ /dev/null @@ -1,154 +0,0 @@ -%% -%% %CopyrightBegin% -%% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%% -%% %CopyrightEnd% -%% - -%% -%% This module provide OTP's dict interface built on top of ets. -%% -%% Note that while the interface is the same as dict the semantics -%% aren't quite. A Dict here is just a table identifier (although -%% this fact can't be used if you want dict/ets-based implementations -%% to be interchangeable) so changes made to the Dict modify the -%% underlying table. For merge/3, the first argument table is modified. -%% -%% The underlying ets table implementing a dict is deleted when the -%% process from which new() was invoked exits and the dict is only -%% writable from this process. -%% -%% The reason for this is to be able to swap dict/ets-based -%% implementations: the former is easier to debug, the latter is -%% faster for larger tables. It's also just a nice interface even -%% when there's no need for swapability. -%% - --module(diameter_dict). - --export([append/3, - append_list/3, - erase/2, - fetch/2, - fetch_keys/1, - filter/2, - find/2, - fold/3, - from_list/1, - is_key/2, - map/2, - merge/3, - new/0, - store/3, - to_list/1, - update/3, - update/4, - update_counter/3]). - -%%% ---------------------------------------------------------- -%%% EXPORTED INTERNAL FUNCTIONS -%%% ---------------------------------------------------------- - -append(Key, Value, Dict) -> - append_list(Key, [Value], Dict). - -append_list(Key, ValueList, Dict) - when is_list(ValueList) -> - update(Key, fun(V) -> V ++ ValueList end, ValueList, Dict). - -erase(Key, Dict) -> - ets:delete(Dict, Key), - Dict. - -fetch(Key, Dict) -> - {ok, V} = find(Key, Dict), - V. - -fetch_keys(Dict) -> - ets:foldl(fun({K,_}, Acc) -> [K | Acc] end, [], Dict). - -filter(Pred, Dict) -> - lists:foreach(fun({K,V}) -> filter(Pred(K,V), K, Dict) end, to_list(Dict)), - Dict. - -find(Key, Dict) -> - case ets:lookup(Dict, Key) of - [{Key, V}] -> - {ok, V}; - [] -> - error - end. - -fold(Fun, Acc0, Dict) -> - ets:foldl(fun({K,V}, Acc) -> Fun(K, V, Acc) end, Acc0, Dict). - -from_list(List) -> - lists:foldl(fun store/2, new(), List). - -is_key(Key, Dict) -> - ets:member(Dict, Key). - -map(Fun, Dict) -> - lists:foreach(fun({K,V}) -> store(K, Fun(K,V), Dict) end, to_list(Dict)), - Dict. - -merge(Fun, Dict1, Dict2) -> - fold(fun(K2,V2,_) -> - update(K2, fun(V1) -> Fun(K2, V1, V2) end, V2, Dict1) - end, - Dict1, - Dict2). - -new() -> - ets:new(?MODULE, [set]). - -store(Key, Value, Dict) -> - store({Key, Value}, Dict). - -to_list(Dict) -> - ets:tab2list(Dict). - -update(Key, Fun, Dict) -> - store(Key, Fun(fetch(Key, Dict)), Dict). - -update(Key, Fun, Initial, Dict) -> - store(Key, map(Key, Fun, Dict, Initial), Dict). - -update_counter(Key, Increment, Dict) - when is_integer(Increment) -> - update(Key, fun(V) -> V + Increment end, Increment, Dict). - -%%% --------------------------------------------------------- -%%% INTERNAL FUNCTIONS -%%% --------------------------------------------------------- - -store({_,_} = T, Dict) -> - ets:insert(Dict, T), - Dict. - -filter(true, _, _) -> - ok; -filter(false, K, Dict) -> - erase(K, Dict). - -map(Key, Fun, Dict, Error) -> - case find(Key, Dict) of - {ok, V} -> - Fun(V); - error -> - Error - end. - diff --git a/lib/diameter/src/base/diameter_gen.erl b/lib/diameter/src/base/diameter_gen.erl new file mode 100644 index 0000000000..e832832876 --- /dev/null +++ b/lib/diameter/src/base/diameter_gen.erl @@ -0,0 +1,709 @@ +%% +%% %CopyrightBegin% +%% +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%% +%% %CopyrightEnd% +%% + +%% +%% This file contains code that encode/decode modules generated by +%% diameter_codegen.erl calls to implement the functionality. This +%% code does most of the work, the generated code being kept simple. +%% + +-module(diameter_gen). + +-export([encode_avps/3, + decode_avps/3, + grouped_avp/4, + empty_group/2, + empty/2]). + +-include_lib("diameter/include/diameter.hrl"). + +-define(THROW(T), throw({?MODULE, T})). + +-type parent_name() :: atom(). %% parent = Message or AVP +-type parent_record() :: tuple(). %% +-type avp_name() :: atom(). +-type avp_record() :: tuple(). +-type avp_values() :: [{avp_name(), term()}]. + +-type non_grouped_avp() :: #diameter_avp{}. +-type grouped_avp() :: nonempty_improper_list(#diameter_avp{}, [avp()]). +-type avp() :: non_grouped_avp() | grouped_avp(). + +%% --------------------------------------------------------------------------- +%% # encode_avps/3 +%% --------------------------------------------------------------------------- + +-spec encode_avps(parent_name(), parent_record() | avp_values(), map()) + -> iolist() + | no_return(). + +encode_avps(Name, Vals, #{module := Mod} = Opts) -> + try + encode(Name, Vals, Opts, Mod) + catch + throw: {?MODULE, Reason} -> + diameter_lib:log({encode, error}, + ?MODULE, + ?LINE, + {Reason, Name, Vals, Mod}), + erlang:error(list_to_tuple(Reason ++ [Name])); + error: Reason -> + Stack = erlang:get_stacktrace(), + diameter_lib:log({encode, failure}, + ?MODULE, + ?LINE, + {Reason, Name, Vals, Mod, Stack}), + erlang:error({encode_failure, Reason, Name, Stack}) + end. + +%% encode/4 + +encode(Name, Vals, #{ordered_encode := false} = Opts, Mod) + when is_list(Vals) -> + lists:map(fun({F,V}) -> encode(Name, F, V, Opts, Mod) end, Vals); + +encode(Name, Vals, Opts, Mod) + when is_list(Vals) -> + encode(Name, Mod:'#set-'(Vals, newrec(Mod, Name)), Opts, Mod); + +encode(Name, Rec, Opts, Mod) -> + [encode(Name, F, V, Opts, Mod) || {F,V} <- Mod:'#get-'(Rec)]. + +%% encode/5 + +encode(Name, AvpName, Values, Opts, Mod) -> + enc(Name, AvpName, Mod:avp_arity(Name, AvpName), Values, Opts, Mod). + +%% enc/6 + +enc(_, AvpName, 1, undefined, _, _) -> + ?THROW([mandatory_avp_missing, AvpName]); + +enc(Name, AvpName, 1, Value, Opts, Mod) -> + enc(Name, AvpName, [Value], Opts, Mod); + +enc(_, _, {0,_}, [], _, _) -> + []; + +enc(_, AvpName, _, T, _, _) + when not is_list(T) -> + ?THROW([repeated_avp_as_non_list, AvpName, T]); + +enc(_, AvpName, {Min, _}, L, _, _) + when length(L) < Min -> + ?THROW([repeated_avp_insufficient_arity, AvpName, Min, L]); + +enc(_, AvpName, {_, Max}, L, _, _) + when Max < length(L) -> + ?THROW([repeated_avp_excessive_arity, AvpName, Max, L]); + +enc(Name, AvpName, _, Values, Opts, Mod) -> + enc(Name, AvpName, Values, Opts, Mod). + +%% enc/5 + +enc(Name, 'AVP', Values, Opts, Mod) -> + [enc_AVP(Name, A, Opts, Mod) || A <- Values]; + +enc(_, AvpName, Values, Opts, Mod) -> + enc(AvpName, Values, Opts, Mod). + +%% enc/4 + +enc(AvpName, Values, Opts, Mod) -> + H = Mod:avp_header(AvpName), + [diameter_codec:pack_data(H, Mod:avp(encode, V, AvpName, Opts)) + || V <- Values]. + +%% enc_AVP/4 + +%% No value: assume AVP data is already encoded. The normal case will +%% be when this is passed back from #diameter_packet.errors as a +%% consequence of a failed decode. Any AVP can be encoded this way +%% however, which side-steps any arity checks for known AVP's and +%% could potentially encode something unfortunate. +enc_AVP(_, #diameter_avp{value = undefined} = A, Opts, _) -> + diameter_codec:pack_avp(A, Opts); + +%% Missing name for value encode. +enc_AVP(_, #diameter_avp{name = N, value = V}, _, _) + when N == undefined; + N == 'AVP' -> + ?THROW([value_with_nameless_avp, N, V]); + +%% Or not. Ensure that 'AVP' is the appropriate field. Note that if we +%% don't know this AVP at all then the encode will fail. +enc_AVP(Name, #diameter_avp{name = AvpName, value = Data}, Opts, Mod) -> + 0 == Mod:avp_arity(Name, AvpName) + orelse ?THROW([known_avp_as_AVP, Name, AvpName, Data]), + enc(AvpName, [Data], Opts, Mod); + +%% The backdoor ... +enc_AVP(_, {AvpName, Value}, Opts, Mod) -> + enc(AvpName, [Value], Opts, Mod); + +%% ... and the side door. +enc_AVP(_Name, {_Dict, _AvpName, _Data} = T, Opts, _) -> + diameter_codec:pack_avp(#diameter_avp{data = T}, Opts). + +%% --------------------------------------------------------------------------- +%% # decode_avps/3 +%% --------------------------------------------------------------------------- + +-spec decode_avps(parent_name(), [#diameter_avp{}], map()) + -> {parent_record(), [avp()], Failed} + when Failed :: [{5000..5999, #diameter_avp{}}]. + +decode_avps(Name, Recs, #{module := Mod} = Opts) -> + {Avps, {Rec, Failed}} + = mapfoldl(fun(T,A) -> decode(Name, Opts, Mod, T, A) end, + {newrec(Mod, Name), []}, + Recs), + {Rec, Avps, Failed ++ missing(Rec, Name, Failed, Opts, Mod)}. +%% Append 5005 errors so that errors are reported in the order +%% encountered. Failed-AVP should typically contain the first +%% encountered error accordg to the RFC. + +%% mapfoldl/3 +%% +%% Like lists:mapfoldl/3, but don't reverse the list. + +mapfoldl(F, Acc, List) -> + mapfoldl(F, Acc, List, []). + +mapfoldl(F, Acc0, [T|Rest], List) -> + {B, Acc} = F(T, Acc0), + mapfoldl(F, Acc, Rest, [B|List]); +mapfoldl(_, Acc, [], List) -> + {List, Acc}. + +%% 3588: +%% +%% DIAMETER_MISSING_AVP 5005 +%% The request did not contain an AVP that is required by the Command +%% Code definition. If this value is sent in the Result-Code AVP, a +%% Failed-AVP AVP SHOULD be included in the message. The Failed-AVP +%% AVP MUST contain an example of the missing AVP complete with the +%% Vendor-Id if applicable. The value field of the missing AVP +%% should be of correct minimum length and contain zeros. + +missing(Rec, Name, Failed, Opts, Mod) -> + Avps = lists:foldl(fun({_, #diameter_avp{code = C, vendor_id = V}}, A) -> + maps:put({C,V}, true, A) + end, + maps:new(), + Failed), + missing(Mod:avp_arity(Name), tl(tuple_to_list(Rec)), Avps, Opts, Mod, []). + +missing([{Name, Arity} | As], [Value | Vs], Avps, Opts, Mod, Acc) -> + missing(As, + Vs, + Avps, + Opts, + Mod, + case + [H || missing_arity(Arity, Value), + {C,_,V} = H <- [Mod:avp_header(Name)], + not maps:is_key({C,V}, Avps)] + of + [H] -> + [{5005, empty_avp(Name, H, Opts, Mod)} | Acc]; + [] -> + Acc + end); + +missing([], [], _, _, _, Acc) -> + Acc. + +%% Maximum arities have already been checked in building the record. + +missing_arity(1, V) -> + V == undefined; +missing_arity({0, _}, _) -> + false; +missing_arity({1, _}, L) -> + [] == L; +missing_arity({Min, _}, L) -> + not has_prefix(Min, L). + +%% Compare a non-negative integer and the length of a list without +%% computing the length. +has_prefix(0, _) -> + true; +has_prefix(_, []) -> + false; +has_prefix(N, [_|L]) -> + has_prefix(N-1, L). + +%% empty_avp/4 + +empty_avp(Name, {Code, Flags, VId}, Opts, Mod) -> + {Name, Type} = Mod:avp_name(Code, VId), + #diameter_avp{name = Name, + code = Code, + vendor_id = VId, + is_mandatory = 0 /= (Flags band 2#01000000), + need_encryption = 0 /= (Flags band 2#00100000), + data = Mod:empty_value(Name, Opts), + type = Type}. + +%% 3588, ch 7: +%% +%% The Result-Code AVP describes the error that the Diameter node +%% encountered in its processing. In case there are multiple errors, +%% the Diameter node MUST report only the first error it encountered +%% (detected possibly in some implementation dependent order). The +%% specific errors that can be described by this AVP are described in +%% the following section. + +%% decode/5 + +decode(Name, + Opts, + Mod, + #diameter_avp{code = Code, vendor_id = Vid} + = Avp, + Acc) -> + decode(Name, Opts, Mod, Mod:avp_name(Code, Vid), Avp, Acc). + +%% decode/6 + +%% AVP not in dictionary. +decode(Name, Opts, Mod, 'AVP', Avp, Acc) -> + decode_AVP(Name, Avp, Opts, Mod, Acc); + +%% 6733, 4.4: +%% +%% Receivers of a Grouped AVP that does not have the 'M' (mandatory) +%% bit set and one or more of the encapsulated AVPs within the group +%% has the 'M' (mandatory) bit set MAY simply be ignored if the +%% Grouped AVP itself is unrecognized. The rule applies even if the +%% encapsulated AVP with its 'M' (mandatory) bit set is further +%% encapsulated within other sub-groups, i.e., other Grouped AVPs +%% embedded within the Grouped AVP. +%% +%% The first sentence is slightly mangled, but take it to mean this: +%% +%% An unrecognized AVP of type Grouped that does not set the 'M' bit +%% MAY be ignored even if one of its encapsulated AVPs sets the 'M' +%% bit. +%% +%% The text above is a change from RFC 3588, which instead says this: +%% +%% Further, if any of the AVPs encapsulated within a Grouped AVP has +%% the 'M' (mandatory) bit set, the Grouped AVP itself MUST also +%% include the 'M' bit set. +%% +%% Both of these texts have problems. If the AVP is unknown then its +%% type is unknown since the type isn't sent over the wire, so the +%% 6733 text becomes a non-statement: don't know that the AVP not +%% setting the M-bit is of type Grouped, therefore can't know that its +%% data consists of encapsulated AVPs, therefore can't but ignore that +%% one of these might set the M-bit. It should be no worse if we know +%% the AVP to have type Grouped. +%% +%% Similarly, for the 3588 text: if we receive an AVP that doesn't set +%% the M-bit and don't know that the AVP has type Grouped then we +%% can't realize that its data contains an AVP that sets the M-bit, so +%% can't regard the AVP as erroneous on this account. Again, it should +%% be no worse if the type is known to be Grouped, but in this case +%% the RFC forces us to regard the AVP as erroneous. This is +%% inconsistent, and the 3588 text has never been enforced. +%% +%% So, if an AVP doesn't set the M-bit then we're free to ignore it, +%% regardless of the AVP's type. If we know the type to be Grouped +%% then we must ignore the M-bit on an encapsulated AVP. That means +%% packing such an encapsulated AVP into an 'AVP' field if need be, +%% not regarding the lack of a specific field as an error as is +%% otherwise the case. (The lack of an AVP-specific field being how we +%% defined the RFC's "unrecognized", which is slightly stronger than +%% "not defined".) + +decode(Name, Opts0, Mod, {AvpName, Type}, Avp, Acc) -> + #diameter_avp{data = Data, is_mandatory = M} + = Avp, + + %% Whether or not to ignore an M-bit on an encapsulated AVP, or on + %% all AVPs with the service_opt() strict_mbit. + Opts1 = set_strict(Type, M, Opts0), + + %% Whether or not we're decoding within Failed-AVP and should + %% ignore decode errors. + #{dictionary := AppMod, failed_avp := Failed} + = Opts + = set_failed(Name, Opts1), %% Not AvpName or else a failed Failed-AVP + %% decode is packed into 'AVP'. + + %% Reset the dictionary for best-effort decode of Failed-AVP. + DecMod = if Failed -> + AppMod; + true -> + Mod + end, + + %% On decode, a Grouped AVP is represented as a #diameter_avp{} + %% list with AVP as head and component AVPs as tail. On encode, + %% data can be a list of component AVPs. + + try avp_decode(Data, AvpName, Opts, DecMod, Mod) of + {Rec, As} when Type == 'Grouped' -> + A = Avp#diameter_avp{name = AvpName, + value = Rec, + type = Type}, + {[A|As], pack_avp(Name, A, Opts, Mod, Acc)}; + + V when Type /= 'Grouped' -> + A = Avp#diameter_avp{name = AvpName, + value = V, + type = Type}, + {A, pack_avp(Name, A, Opts, Mod, Acc)} + catch + throw: {?MODULE, {grouped, Error, ComponentAvps}} -> + decode_error(Name, + Error, + ComponentAvps, + Opts, + Mod, + Avp#diameter_avp{name = AvpName, + data = trim(Avp#diameter_avp.data), + type = Type}, + Acc); + + error: Reason -> + decode_error(Name, + Reason, + Opts, + Mod, + Avp#diameter_avp{name = AvpName, + data = trim(Avp#diameter_avp.data), + type = Type}, + Acc) + end. + +%% avp_decode/5 + +avp_decode(Data, AvpName, Opts, Mod, Mod) -> + Mod:avp(decode, Data, AvpName, Opts); + +avp_decode(Data, AvpName, Opts, Mod, _) -> + Mod:avp(decode, Data, AvpName, Opts, Mod). + +%% trim/1 +%% +%% Remove any extra bit that was added in diameter_codec to induce a +%% 5014 error. + +trim(#diameter_avp{data = Data} = Avp) -> + Avp#diameter_avp{data = trim(Data)}; + +trim({5014, Bin}) -> + Bin; + +trim(Avps) + when is_list(Avps) -> + lists:map(fun trim/1, Avps); + +trim(Avp) -> + Avp. + +%% decode_error/7 + +decode_error(Name, [_ | Rec], _, #{failed_avp := true} = Opts, Mod, Avp, Acc) -> + decode_AVP(Name, Avp#diameter_avp{value = Rec}, Opts, Mod, Acc); + +decode_error(Name, _, _, #{failed_avp := true} = Opts, Mod, Avp, Acc) -> + decode_AVP(Name, Avp, Opts, Mod, Acc); + +decode_error(_, [Error | _], ComponentAvps, _, _, Avp, Acc) -> + decode_error(Error, Avp, Acc, ComponentAvps); + +decode_error(_, Error, ComponentAvps, _, _, Avp, Acc) -> + decode_error(Error, Avp, Acc, ComponentAvps). + +%% decode_error/5 + +decode_error(Name, _Reason, #{failed_avp := true} = Opts, Mod, Avp, Acc) -> + decode_AVP(Name, Avp, Opts, Mod, Acc); + +decode_error(Name, Reason, Opts, Mod, Avp, {Rec, Failed}) -> + Stack = diameter_lib:get_stacktrace(), + diameter_lib:log(decode_error, + ?MODULE, + ?LINE, + {Reason, Name, Avp#diameter_avp.name, Mod, Stack}), + {Avp, {Rec, [rc(Reason, Avp, Opts, Mod) | Failed]}}. + +%% decode_error/4 + +decode_error({RC, ErrorData}, Avp, {Rec, Failed}, ComponentAvps) -> + E = Avp#diameter_avp{data = [ErrorData]}, + {[Avp | trim(ComponentAvps)], {Rec, [{RC, E} | Failed]}}. + +%% set_strict/3 + +%% Set false as soon as we see a Grouped AVP that doesn't set the +%% M-bit, to ignore the M-bit on an encapsulated AVP. +set_strict('Grouped', false = M, #{strict_mbit := true} = Opts) -> + Opts#{strict_mbit := M}; +set_strict(_, _, Opts) -> + Opts. + +%% set_failed/2 +%% +%% Set true as soon as we see Failed-AVP. Matching on 'Failed-AVP' +%% assumes that this is the RFC AVP. Strictly, this doesn't need to be +%% the case. + +set_failed('Failed-AVP', #{failed_avp := false} = Opts) -> + Opts#{failed_avp := true}; +set_failed(_, Opts) -> + Opts. + +%% decode_AVP/5 +%% +%% Don't know this AVP: see if it can be packed in an 'AVP' field +%% undecoded. Note that the type field is 'undefined' in this case. + +decode_AVP(Name, Avp, Opts, Mod, Acc) -> + {trim(Avp), pack_AVP(Name, Avp, Opts, Mod, Acc)}. + +%% rc/2 + +%% diameter_types will raise an error of this form to communicate +%% DIAMETER_INVALID_AVP_LENGTH (5014). A module specified to a +%% @custom_types tag in a dictionary file can also raise an error of +%% this form. +rc({'DIAMETER', 5014 = RC, _}, #diameter_avp{name = AvpName} = Avp, Opts, Mod) -> + {RC, Avp#diameter_avp{data = Mod:empty_value(AvpName, Opts)}}; + +%% 3588: +%% +%% DIAMETER_INVALID_AVP_VALUE 5004 +%% The request contained an AVP with an invalid value in its data +%% portion. A Diameter message indicating this error MUST include +%% the offending AVPs within a Failed-AVP AVP. +rc(_, Avp, _, _) -> + {5004, Avp}. + +%% pack_avp/5 + +pack_avp(Name, #diameter_avp{name = AvpName} = Avp, Opts, Mod, Acc) -> + pack_avp(Name, Mod:avp_arity(Name, AvpName), Avp, Opts, Mod, Acc). + +%% pack_avp/6 + +pack_avp(Name, 0, Avp, Opts, Mod, Acc) -> + pack_AVP(Name, Avp, Opts, Mod, Acc); + +pack_avp(_, Arity, #diameter_avp{name = AvpName} = Avp, _Opts, Mod, Acc) -> + pack(Arity, AvpName, Avp, Mod, Acc). + +%% pack_AVP/5 + +%% Length failure was induced because of a header/payload length +%% mismatch. The AVP Length is reset to match the received data if +%% this AVP is encoded in an answer message, since the length is +%% computed. +%% +%% Data is a truncated header if command_code = undefined, otherwise +%% payload bytes. The former is padded to the length of a header if +%% the AVP reaches an outgoing encode in diameter_codec. +%% +%% RFC 6733 says that an AVP returned with 5014 can contain a minimal +%% payload for the AVP's type, but in this case we don't know the +%% type. + +pack_AVP(_, #diameter_avp{data = {5014 = RC, Data}} = Avp, _, _, Acc) -> + {Rec, Failed} = Acc, + {Rec, [{RC, Avp#diameter_avp{data = Data}} | Failed]}; + +pack_AVP(Name, Avp, Opts, Mod, Acc) -> + pack_arity(Name, pack_arity(Name, Opts, Mod, Avp), Avp, Mod, Acc). + +%% pack_arity/5 + +pack_arity(_, 0, #diameter_avp{is_mandatory = M} = Avp, _, Acc) -> + {Rec, Failed} = Acc, + {Rec, [{if M -> 5001; true -> 5008 end, Avp} | Failed]}; + +pack_arity(_, Arity, Avp, Mod, Acc) -> + pack(Arity, 'AVP', Avp, Mod, Acc). + +%% Give Failed-AVP special treatment since (1) it'll contain any +%% unrecognized mandatory AVP's and (2) the RFC 3588 grammar failed to +%% allow for Failed-AVP in an answer-message. + +pack_arity(Name, + #{strict_mbit := Strict, + failed_avp := Failed}, + Mod, + #diameter_avp{is_mandatory = M, + name = AvpName}) -> + + %% Not testing just Name /= 'Failed-AVP' means we're changing the + %% packing of AVPs nested within Failed-AVP, but the point of + %% ignoring errors within Failed-AVP is to decode as much as + %% possible, and failing because a mandatory AVP couldn't be + %% packed into a dedicated field defeats that point. + + if Failed == true; + Name == 'Failed-AVP'; + Name == 'answer-message', AvpName == 'Failed-AVP'; + not M; + not Strict -> + Mod:avp_arity(Name, 'AVP'); + true -> + 0 + end. + +%% 3588: +%% +%% DIAMETER_AVP_UNSUPPORTED 5001 +%% The peer received a message that contained an AVP that is not +%% recognized or supported and was marked with the Mandatory bit. A +%% Diameter message with this error MUST contain one or more Failed- +%% AVP AVP containing the AVPs that caused the failure. +%% +%% DIAMETER_AVP_NOT_ALLOWED 5008 +%% A message was received with an AVP that MUST NOT be present. The +%% Failed-AVP AVP MUST be included and contain a copy of the +%% offending AVP. + +%% pack/5 + +pack(Arity, FieldName, Avp, Mod, {Rec, _} = Acc) -> + pack(Mod:'#get-'(FieldName, Rec), Arity, FieldName, Avp, Mod, Acc). + +%% pack/6 + +pack(undefined, 1, 'AVP' = F, Avp, Mod, {Rec, Failed}) -> %% unlikely + {Mod:'#set-'({F, Avp}, Rec), Failed}; + +pack(undefined, 1, F, #diameter_avp{value = V}, Mod, {Rec, Failed}) -> + {Mod:'#set-'({F, V}, Rec), Failed}; + +%% 3588: +%% +%% DIAMETER_AVP_OCCURS_TOO_MANY_TIMES 5009 +%% A message was received that included an AVP that appeared more +%% often than permitted in the message definition. The Failed-AVP +%% AVP MUST be included and contain a copy of the first instance of +%% the offending AVP that exceeded the maximum number of occurrences +%% + +pack(_, 1, _, Avp, _, {Rec, Failed}) -> + {Rec, [{5009, Avp} | Failed]}; + +pack(L, {_, Max}, F, Avp, Mod, {Rec, Failed}) -> + case '*' /= Max andalso has_prefix(Max+1, L) of + true -> + {Rec, [{5009, Avp} | Failed]}; + false when F == 'AVP' -> + {Mod:'#set-'({F, [Avp | L]}, Rec), Failed}; + false -> + {Mod:'#set-'({F, [Avp#diameter_avp.value | L]}, Rec), Failed} + end. + +%% --------------------------------------------------------------------------- +%% # grouped_avp/3 +%% --------------------------------------------------------------------------- + +-spec grouped_avp(decode, avp_name(), binary() | {5014, binary()}, term()) + -> {avp_record(), [avp()]}; + (encode, avp_name(), avp_record() | avp_values(), term()) + -> iolist() + | no_return(). + +%% Length error induced by diameter_codec:collect_avps/1: the AVP +%% length in the header was too short (insufficient for the extracted +%% header) or too long (past the end of the message). An empty payload +%% is sufficient according to the RFC text for 5014. +grouped_avp(decode, _Name, {5014 = RC, _Bin}, _) -> + ?THROW({grouped, {RC, []}, []}); + +grouped_avp(decode, Name, Data, Opts) -> + grouped_decode(Name, diameter_codec:collect_avps(Data), Opts); + +grouped_avp(encode, Name, Data, Opts) -> + encode_avps(Name, Data, Opts). + +%% grouped_decode/2 +%% +%% Note that Grouped is the only AVP type that doesn't just return a +%% decoded value, also returning the list of component diameter_avp +%% records. + +%% Length error in trailing component AVP. +grouped_decode(_Name, {Error, Acc}, _) -> + {5014, Avp} = Error, + ?THROW({grouped, Error, [Avp | Acc]}); + +%% 7.5. Failed-AVP AVP + +%% In the case where the offending AVP is embedded within a Grouped AVP, +%% the Failed-AVP MAY contain the grouped AVP, which in turn contains +%% the single offending AVP. The same method MAY be employed if the +%% grouped AVP itself is embedded in yet another grouped AVP and so on. +%% In this case, the Failed-AVP MAY contain the grouped AVP hierarchy up +%% to the single offending AVP. This enables the recipient to detect +%% the location of the offending AVP when embedded in a group. + +%% An error in decoding a component AVP throws the first faulty +%% component, which the catch in d/3 wraps in the Grouped AVP in +%% question. A partially decoded record is only used when ignoring +%% errors in Failed-AVP. +grouped_decode(Name, ComponentAvps, Opts) -> + {Rec, Avps, Es} = decode_avps(Name, ComponentAvps, Opts), + [] == Es orelse ?THROW({grouped, [{_,_} = hd(Es) | Rec], Avps}), + {Rec, Avps}. + +%% --------------------------------------------------------------------------- +%% # empty_group/2 +%% --------------------------------------------------------------------------- + +empty_group(Name, #{module := Mod} = Opts) -> + list_to_binary([z(F, A, Opts, Mod) || {F,A} <- Mod:avp_arity(Name)]). + +z(Name, 1, Opts, Mod) -> + z(Name, Opts, Mod); +z(_, {0,_}, _, _) -> + []; +z(Name, {Min, _}, Opts, Mod) -> + binary:copy(z(Name, Opts, Mod), Min). + +z('AVP', _, _) -> + <<0:64>>; %% minimal header +z(Name, Opts, Mod) -> + Bin = diameter_codec:pack_data(Mod:avp_header(Name), + Mod:empty_value(Name, Opts)), + Sz = iolist_size(Bin), + <<0:Sz/unit:8>>. + +%% --------------------------------------------------------------------------- +%% # empty/2 +%% --------------------------------------------------------------------------- + +empty(Name, #{module := Mod} = Opts) -> + Mod:avp(encode, zero, Name, Opts). + +%% ------------------------------------------------------------------------------ + +newrec(Mod, Name) -> + Mod:'#new-'(Mod:name2rec(Name)). diff --git a/lib/diameter/src/base/diameter_lib.erl b/lib/diameter/src/base/diameter_lib.erl index 3928769b5e..58b9a29812 100644 --- a/lib/diameter/src/base/diameter_lib.erl +++ b/lib/diameter/src/base/diameter_lib.erl @@ -37,7 +37,6 @@ ipaddr/1, spawn_opts/2, wait/1, - fold_tuple/3, fold_n/3, for_n/2, log/4]). @@ -341,36 +340,6 @@ down(MRef) receive {'DOWN', MRef, process, _, _} = T -> T end. %% --------------------------------------------------------------------------- -%% # fold_tuple/3 -%% --------------------------------------------------------------------------- - --spec fold_tuple(N, T0, T) - -> tuple() - when N :: pos_integer(), - T0 :: tuple(), - T :: tuple() - | undefined. - -%% Replace fields in T0 by those of T starting at index N, unless the -%% new value is 'undefined'. -%% -%% eg. fold_tuple(2, Hdr, #diameter_header{end_to_end_id = 42}) - -fold_tuple(_, T, undefined) -> - T; - -fold_tuple(N, T0, T1) -> - {_, T} = lists:foldl(fun(V, {I,_} = IT) -> {I+1, ft(V, IT)} end, - {N, T0}, - lists:nthtail(N-1, tuple_to_list(T1))), - T. - -ft(undefined, {_, T}) -> - T; -ft(Value, {Idx, T}) -> - setelement(Idx, T, Value). - -%% --------------------------------------------------------------------------- %% # fold_n/3 %% --------------------------------------------------------------------------- diff --git a/lib/diameter/src/base/diameter_peer_fsm.erl b/lib/diameter/src/base/diameter_peer_fsm.erl index 46d231da74..1b0dc417e5 100644 --- a/lib/diameter/src/base/diameter_peer_fsm.erl +++ b/lib/diameter/src/base/diameter_peer_fsm.erl @@ -128,7 +128,12 @@ %% outgoing DPR; boolean says whether or not %% the request was sent explicitly with %% diameter:call/4. + codec :: #{string_decode := boolean(), + strict_mbit := boolean(), + rfc := 3588 | 6733, + ordered_encode := false}, strict :: boolean(), + ack = false :: boolean(), length_errors :: exit | handle | discard, incoming_maxlen :: integer() | infinity}). @@ -159,10 +164,7 @@ %% # start/3 %% --------------------------------------------------------------------------- --spec start(T, [Opt], {[diameter:service_opt()], - [node()], - module(), - #diameter_service{}}) +-spec start(T, [Opt], {map(), [node()], module(), #diameter_service{}}) -> {reference(), pid()} when T :: {connect|accept, diameter:transport_ref()}, Opt :: diameter:transport_opt(). @@ -221,9 +223,10 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) -> erlang:monitor(process, WPid), wait(Ack, WPid), diameter_stats:reg(Ref), - diameter_codec:setopts([{common_dictionary, Dict0} | SvcOpts]), - {_,_} = Mask = proplists:get_value(sequence, SvcOpts), - Maxlen = proplists:get_value(incoming_maxlen, SvcOpts, 16#FFFFFF), + + #{sequence := Mask, incoming_maxlen := Maxlen} + = SvcOpts, + {[Cs,Ds], Rest} = proplists:split(Opts, [capabilities_cb, disconnect_cb]), putr(?CB_KEY, {Ref, [F || {_,F} <- Cs]}), putr(?DPR_KEY, [F || {_, F} <- Ds]), @@ -235,7 +238,7 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) -> Tmo = proplists:get_value(capx_timeout, Opts, ?CAPX_TIMEOUT), Strictness = proplists:get_value(capx_strictness, Opts, true), - OnLengthErr = proplists:get_value(length_errors, Opts, exit), + LengthErr = proplists:get_value(length_errors, Opts, exit), {TPid, Addrs} = start_transport(T, Rest, Svc), @@ -247,9 +250,14 @@ i({Ack, WPid, {M, Ref} = T, Opts, {SvcOpts, Nodes, Dict0, Svc}}) -> dictionary = Dict0, mode = M, service = svc(Svc, Addrs), - length_errors = OnLengthErr, + length_errors = LengthErr, strict = Strictness, - incoming_maxlen = Maxlen}. + incoming_maxlen = Maxlen, + codec = maps:with([string_decode, + strict_mbit, + rfc, + ordered_encode], + SvcOpts#{ordered_encode => false})}. %% The transport returns its local ip addresses so that different %% transports on the same service can use different local addresses. %% The local addresses are put into Host-IP-Address avps here when @@ -442,9 +450,18 @@ transition({connection_timeout = T, TPid}, transition({connection_timeout, _}, _) -> ok; +%% Requests for acknowledgements to the transport. +transition({diameter, ack}, S) -> + S#state{ack = true}; + %% Incoming message from the transport. -transition({diameter, {recv, MsgT}}, S) -> - incoming(MsgT, S); +transition({diameter, {recv, Msg}}, S) -> + incoming(recv(Msg, S), S); + +%% Handler of an incoming request is telling of its existence. +transition({handler, Pid}, _) -> + put_route(Pid), + ok; %% Timeout when still in the same state ... transition({timeout = T, PS}, #state{state = PS}) -> @@ -458,7 +475,7 @@ transition({timeout, _}, _) -> transition({send, Msg}, S) -> outgoing(Msg, S); transition({send, Msg, Route}, S) -> - put_route(Route), + route_outgoing(Route), outgoing(Msg, S); %% Request for graceful shutdown at remove_transport, stop_service of @@ -487,12 +504,13 @@ transition({'DOWN', _, process, WPid, _}, transition({'DOWN', _, process, TPid, _}, #state{transport = TPid} = S) -> - start_next(S); + start_next(S#state{ack = false}); %% Transport has died after connection timeout, or handler process has %% died. -transition({'DOWN', _, process, Pid, _}, _) -> - erase_route(Pid), +transition({'DOWN', _, process, Pid, _}, #state{transport = TPid}) -> + is_reference(erase_route(Pid)) + andalso send(TPid, false), %% answer not forthcoming ok; %% State query. @@ -502,37 +520,56 @@ transition({state, Pid}, #state{state = S, transport = TPid}) -> %% Crash on anything unexpected. -%% put_route/1 -%% +%% route_outgoing/1 + %% Map identifiers in an outgoing request to be able to lookup the %% handler process when the answer is received. - -put_route({Pid, Ref, Seqs}) -> +route_outgoing({Pid, Ref, Seqs}) -> %% request MRef = monitor(process, Pid), put(Pid, Seqs), - put(Seqs, {Pid, Ref, MRef}). + put(Seqs, {Pid, Ref, MRef}); -%% get_route/1 +%% Remove a mapping made for an incoming request. +route_outgoing(Pid) + when is_pid(Pid) -> %% answer + MRef = erase_route(Pid), + undefined == MRef orelse demonitor(MRef). -get_route(#diameter_packet{header = #diameter_header{is_request = false}} - = Pkt) -> +%% put_route/1 + +%% Monitor on a handler process for an incoming request. +put_route(Pid) -> + MRef = monitor(process, Pid), + put(Pid, MRef). + +%% get_route/2 + +%% incoming answer +get_route(_, #diameter_packet{header = #diameter_header{is_request = false}} + = Pkt) -> Seqs = diameter_codec:sequence_numbers(Pkt), case erase(Seqs) of {Pid, Ref, MRef} -> demonitor(MRef), erase(Pid), {Pid, Ref, self()}; - undefined -> + undefined -> %% request unknown false end; -get_route(_) -> - false. +%% incoming request +get_route(Ack, _) -> + Ack. %% erase_route/1 erase_route(Pid) -> - erase(erase(Pid)). + case erase(Pid) of + {_,_} = Seqs -> + erase(Seqs); + T -> + T + end. %% capx/1 @@ -560,7 +597,8 @@ send_CER(#state{state = {'Wait-Conn-Ack', Tmo}, mode = {connect, Remote}, service = #diameter_service{capabilities = LCaps}, transport = TPid, - dictionary = Dict} + dictionary = Dict, + codec = Opts} = S) -> OH = LCaps#diameter_caps.origin_host, req_send_CER(OH, Remote) @@ -570,7 +608,7 @@ send_CER(#state{state = {'Wait-Conn-Ack', Tmo}, #diameter_packet{header = #diameter_header{end_to_end_id = Eid, hop_by_hop_id = Hid}} = Pkt - = encode(CER, Dict), + = encode(CER, Opts, Dict), incr(send, Pkt, Dict), send(TPid, Pkt), ?LOG(send, 'CER'), @@ -599,41 +637,36 @@ build_CER(#state{service = #diameter_service{capabilities = LCaps}, {ok, CER} = diameter_capx:build_CER(LCaps, Dict), CER. -%% encode/2 +%% encode/3 -encode(Rec, Dict) -> +encode(Rec, Opts, Dict) -> Seq = diameter_session:sequence({_,_} = getr(?SEQUENCE_KEY)), Hdr = #diameter_header{version = ?DIAMETER_VERSION, end_to_end_id = Seq, hop_by_hop_id = Seq}, - diameter_codec:encode(Dict, #diameter_packet{header = Hdr, - msg = Rec}). + diameter_codec:encode(Dict, Opts, #diameter_packet{header = Hdr, + msg = Rec}). %% incoming/2 -incoming({Msg, NPid}, S) -> - try recv(Msg, S) of - T -> - NPid ! {diameter, discard}, - T - catch - {?MODULE, Name, Pkt} -> - incoming(Name, Pkt, NPid, S) - end; +incoming({recv = T, Name, Pkt}, #state{parent = Pid, ack = Ack} = S) -> + Pid ! {T, self(), get_route(Ack, Pkt), Name, Pkt}, + rcv(Name, Pkt, S); -incoming(Msg, S) -> - try - recv(Msg, S) - catch - {?MODULE, Name, Pkt} -> - incoming(Name, Pkt, false, S) - end. +incoming(#diameter_header{is_request = R}, #state{transport = TPid, + ack = Ack}) -> + R andalso Ack andalso send(TPid, false), + ok; -%% incoming/4 +incoming(<<_:32, 1:1, _/bits>>, #state{ack = true} = S) -> + send(S#state.transport, false), + ok; -incoming(Name, Pkt, NPid, #state{parent = Pid} = S) -> - Pid ! {recv, self(), get_route(Pkt), Name, Pkt, NPid}, - rcv(Name, Pkt, S). +incoming(<<_/bits>>, _) -> + ok; + +incoming(T, _) -> + T. %% recv/2 @@ -658,18 +691,19 @@ recv1(_, #diameter_packet{header = H, bin = Bin}, #state{incoming_maxlen = M}) when M < size(Bin) -> - invalid(false, incoming_maxlen_exceeded, {size(Bin), H}); + invalid(false, incoming_maxlen_exceeded, {size(Bin), H}), + H; %% Ignore anything but an expected CER/CEA if so configured. This is %% non-standard behaviour. -recv1(Name, _, #state{state = {'Wait-CEA', _, _}, - strict = false}) +recv1(Name, #diameter_packet{header = H}, #state{state = {'Wait-CEA', _, _}, + strict = false}) when Name /= 'CEA' -> - ok; -recv1(Name, _, #state{state = recv_CER, - strict = false}) + H; +recv1(Name, #diameter_packet{header = H}, #state{state = recv_CER, + strict = false}) when Name /= 'CER' -> - ok; + H; %% Incoming request after outgoing DPR: discard. Don't discard DPR, so %% both ends don't do so when sending simultaneously. @@ -677,13 +711,15 @@ recv1(Name, #diameter_packet{header = #diameter_header{is_request = true} = H}, #state{dpr = {_,_,_}}) when Name /= 'DPR' -> - invalid(false, recv_after_outgoing_dpr, H); + invalid(false, recv_after_outgoing_dpr, H), + H; %% Incoming request after incoming DPR: discard. recv1(_, #diameter_packet{header = #diameter_header{is_request = true} = H}, #state{dpr = true}) -> - invalid(false, recv_after_incoming_dpr, H); + invalid(false, recv_after_incoming_dpr, H), + H; %% DPA with identifier mismatch, or in response to a DPR initiated by %% the service. @@ -701,7 +737,7 @@ recv1('DPA' = N, %% Any other message with a header and no length errors: send to the %% parent. recv1(Name, Pkt, #state{}) -> - throw({?MODULE, Name, Pkt}). + {recv, Name, Pkt}. %% recv/3 @@ -720,10 +756,12 @@ recv(#diameter_header{} #diameter_packet{bin = Bin}, #state{length_errors = E}) -> T = {size(Bin), bit_size(Bin) rem 8, H}, - invalid(E, message_length_mismatch, T); + invalid(E, message_length_mismatch, T), + Bin; recv(false, #diameter_packet{bin = Bin}, #state{length_errors = E}) -> - invalid(E, truncated_header, Bin). + invalid(E, truncated_header, Bin), + Bin. %% Note that counters here only count discarded messages. invalid(E, Reason, T) -> @@ -767,26 +805,23 @@ rcv('DPA' = N, = Pkt, #state{dictionary = Dict0, transport = TPid, - dpr = {X, Hid, Eid}}) -> + dpr = {X, Hid, Eid}, + codec = Opts}) -> ?LOG(recv, N), X orelse begin %% Only count DPA in response to a DPR sent by the %% service: explicit DPR is counted in the same way %% as other explicitly sent requests. incr(recv, H, Dict0), - incr_rc(recv, diameter_codec:decode(Dict0, Pkt), Dict0) + incr_rc(recv, diameter_codec:decode(Dict0, Opts, Pkt), Dict0) end, diameter_peer:close(TPid), {stop, N}; -%% Ignore anything else, an unsolicited DPA in particular. Note that -%% dpa_timeout deals with the case in which the peer sends the wrong -%% identifiers in DPA. -rcv(N, #diameter_packet{header = H}, _) - when N == 'CER'; - N == 'CEA'; - N == 'DPR'; - N == 'DPA' -> +%% Ignore an unsolicited DPA in particular. Note that dpa_timeout +%% deals with the case in which the peer sends the wrong identifiers +%% in DPA. +rcv('DPA' = N, #diameter_packet{header = H}, _) -> ?LOG(ignored, N), %% Note that these aren't counted in the normal recv counter. diameter_stats:incr({diameter_codec:msg_id(H), recv, ignored}), @@ -839,7 +874,7 @@ outgoing(#diameter_packet{header = #diameter_header{application_id = 0, invalid(false, dpr_after_dpr, H) %% DPR sent: discard end; -%% Explict CER or DWR: discard. These are sent by us. +%% Explicit CER or DWR: discard. These are sent by us. outgoing(#diameter_packet{header = #diameter_header{application_id = 0, cmd_code = C, is_request = true} @@ -875,15 +910,21 @@ header(Bin) -> %% DWR %% Incoming CER or DPR. handle_request(Name, - #diameter_packet{header = H} = Pkt, - #state{dictionary = Dict0} = S) -> + #diameter_packet{header = H} + = Pkt, + #state{dictionary = Dict0, + codec = Opts} + = S) -> ?LOG(recv, Name), incr(recv, H, Dict0), - send_answer(Name, diameter_codec:decode(Dict0, Pkt), S). + send_answer(Name, diameter_codec:decode(Dict0, Opts, Pkt), S). %% send_answer/3 -send_answer(Type, ReqPkt, #state{transport = TPid, dictionary = Dict} = S) -> +send_answer(Type, ReqPkt, #state{transport = TPid, + dictionary = Dict, + codec = Opts} + = S) -> incr_error(recv, ReqPkt, Dict), #diameter_packet{header = H, @@ -902,7 +943,7 @@ send_answer(Type, ReqPkt, #state{transport = TPid, dictionary = Dict} = S) -> msg = Msg, transport_data = TD}, - AnsPkt = diameter_codec:encode(Dict, Pkt), + AnsPkt = diameter_codec:encode(Dict, Opts, Pkt), incr(send, AnsPkt, Dict), incr_rc(send, AnsPkt, Dict), @@ -929,8 +970,6 @@ build_answer('CER', = Pkt, #state{dictionary = Dict0} = S) -> - diameter_codec:setopts([{string_decode, false}]), - {SupportedApps, RCaps, CEA} = recv_CER(CER, S), [RC, IS] = Dict0:'#get-'(['Result-Code', 'Inband-Security-Id'], CEA), @@ -1131,18 +1170,16 @@ recv_CER(CER, #state{service = Svc, dictionary = Dict}) -> handle_CEA(#diameter_packet{header = H} = Pkt, #state{dictionary = Dict0, - service = #diameter_service{capabilities = LCaps}} + service = #diameter_service{capabilities = LCaps}, + codec = Opts} = S) -> incr(recv, H, Dict0), #diameter_packet{} = DPkt - = diameter_codec:decode(Dict0, Pkt), - - diameter_codec:setopts([{string_decode, false}]), + = diameter_codec:decode(Dict0, Opts, Pkt), RC = result_code(incr_rc(recv, DPkt, Dict0)), - {SApps, IS, RCaps} = recv_CEA(DPkt, S), #diameter_caps{origin_host = {OH, DH}} @@ -1330,8 +1367,9 @@ dpr([], [Reason | _], S) -> -record(opts, {cause, timeout}). -send_dpr(Reason, Opts, #state{dictionary = Dict, - service = #diameter_service{capabilities = Caps}} +send_dpr(Reason, DprOpts, #state{dictionary = Dict, + service = #diameter_service{capabilities = Caps}, + codec = Opts} = S) -> #opts{cause = Cause, timeout = Tmo} = lists:foldl(fun opt/2, @@ -1340,7 +1378,7 @@ send_dpr(Reason, Opts, #state{dictionary = Dict, _ -> ?REBOOT end, timeout = dpa_timeout()}, - Opts), + DprOpts), #diameter_caps{origin_host = {OH, _}, origin_realm = {OR, _}} = Caps, @@ -1348,6 +1386,7 @@ send_dpr(Reason, Opts, #state{dictionary = Dict, Pkt = encode(['DPR', {'Origin-Host', OH}, {'Origin-Realm', OR}, {'Disconnect-Cause', Cause}], + Opts, Dict), send_dpr(false, Pkt, Tmo, S). diff --git a/lib/diameter/src/base/diameter_reg.erl b/lib/diameter/src/base/diameter_reg.erl index 9027130063..4910979219 100644 --- a/lib/diameter/src/base/diameter_reg.erl +++ b/lib/diameter/src/base/diameter_reg.erl @@ -137,7 +137,7 @@ match(Pat) -> match(Pat, Pid) -> ets:match_object(?TABLE, {Pat, Pid}). - + %% =========================================================================== %% # wait(Pat) %% diff --git a/lib/diameter/src/base/diameter_service.erl b/lib/diameter/src/base/diameter_service.erl index e4f77e3a24..a976a8b998 100644 --- a/lib/diameter/src/base/diameter_service.erl +++ b/lib/diameter/src/base/diameter_service.erl @@ -88,12 +88,6 @@ %% outside of the service process. -define(STATE_TABLE, ?MODULE). -%% The default sequence mask. --define(NOMASK, {0,32}). - -%% The default restrict_connections. --define(RESTRICT, nodes). - %% Workaround for dialyzer's lack of understanding of match specs. -type match(T) :: T | '_' | '$1' | '$2'. @@ -110,21 +104,17 @@ service :: #diameter_service{}, watchdogT = ets_new(watchdogs) %% #watchdog{} at start :: ets:tid(), - peerT, %% undefined in new code, but remain for upgrade - shared_peers, %% reasons. Replaced by local/remote. - local_peers, %% local :: {ets:tid(), ets:tid(), ets:tid()}, remote :: {ets:tid(), ets:tid(), ets:tid()}, monitor = false :: false | pid(), %% process to die with - options - :: [{sequence, diameter:sequence()} %% sequence mask - | {share_peers, diameter:remotes()} %% broadcast to - | {use_shared_peers, diameter:remotes()} %% use from - | {restrict_connections, diameter:restriction()} - | {strict_mbit, boolean()} - | {string_decode, boolean()} - | {incoming_maxlen, diameter:message_length()}]}). -%% shared_peers reflects the peers broadcast from remote nodes. + options :: #{sequence := diameter:sequence(), %% sequence mask + share_peers := diameter:remotes(),%% broadcast to + use_shared_peers := diameter:remotes(),%% use from + restrict_connections := diameter:restriction(), + incoming_maxlen := diameter:message_length(), + strict_mbit := boolean(), + string_decode := boolean(), + spawn_opt := list() | {module(), atom(), list()}}}). %% Record representing an RFC 3539 watchdog process implemented by %% diameter_watchdog. @@ -284,7 +274,7 @@ whois(SvcName) -> %% --------------------------------------------------------------------------- -spec pick_peer(SvcName, AppOrAlias, Opts) - -> {{TPid, Caps, App}, Mask, SvcOpts} + -> {{{TPid, Caps}, App}, SvcOpts} | false %% no selection | {error, no_service} when SvcName :: diameter:service_name(), @@ -292,14 +282,12 @@ whois(SvcName) -> | {alias, diameter:app_alias()}, Opts :: {fun((Dict :: module()) -> [term()]), diameter:peer_filter(), - Xtra :: list()}, + Xtra :: list(), + [diameter:peer_ref()]}, TPid :: pid(), Caps :: #diameter_caps{}, App :: #diameter_app{}, - Mask :: diameter:sequence(), - SvcOpts :: [diameter:service_opt()]. -%% Extract Mask in the returned tuple so that diameter_traffic doesn't -%% need to know about the ordering of SvcOpts used here. + SvcOpts :: map(). pick_peer(SvcName, App, Opts) -> pick(lookup_state(SvcName), App, Opts). @@ -319,16 +307,16 @@ pick(#state{service = #diameter_service{applications = Apps}} pick(_, false = No, _) -> No; -pick(#state{options = [{_, Mask} | SvcOpts]} +pick(#state{options = SvcOpts} = S, #diameter_app{module = ModX, dictionary = Dict} = App0, - {DestF, Filter, Xtra}) -> + {DestF, Filter, Xtra, TPids}) -> App = App0#diameter_app{module = ModX ++ Xtra}, [_,_] = RealmAndHost = diameter_lib:eval([DestF, Dict]), - case pick_peer(App, RealmAndHost, Filter, S) of - {TPid, Caps} -> - {{TPid, Caps, App}, Mask, SvcOpts}; + case pick_peer(App, RealmAndHost, [Filter | TPids], S) of + {_TPid, _Caps} = TC -> + {{TC, App}, SvcOpts}; false = No -> No end. @@ -556,81 +544,9 @@ terminate(Reason, #state{service_name = Name, local = {PeerT, _, _}} = S) -> %% # code_change/3 %% --------------------------------------------------------------------------- -code_change(_FromVsn, #state{} = S, _Extra) -> - {ok, S}; - -%% Don't support downgrade since we won't in appup. -code_change({down = T, _}, _, _Extra) -> - {error, T}; - -%% Upgrade local/shared peers dicts populated in old code. Don't -code_change(_FromVsn, S0, _Extra) -> - {state, Id, SvcName, Svc, WT, PeerT, SDict, LDict, Monitor, Opts} - = S0, - - init_peers(LT = setelement(1, {PT, _, _} = init_peers(), PeerT), - fun({_,A}) -> A end), - init_peers(init_peers(RT = init_peers(), SDict), - fun(A) -> A end), - - S = #state{id = Id, - service_name = SvcName, - service = Svc, - watchdogT = WT, - peerT = PT, %% empty - shared_peers = SDict, - local_peers = LDict, - local = LT, - remote = RT, - monitor = Monitor, - options = Opts}, - - %% Replacing the table entry and deleting the old shared tables - %% can make outgoing requests return {error, no_connection} until - %% everyone is running new code. Don't delete the tables to avoid - %% crashing request processes. - ets:delete_all_objects(SDict), - ets:delete_all_objects(LDict), - ets:insert(?STATE_TABLE, S), +code_change(_FromVsn, S, _Extra) -> {ok, S}. -%% init_peers/2 - -%% Populate app and identity bags from a new-style #peer{} sets. -init_peers({PeerT, _, _} = T, F) - when is_function(F) -> - ets:foldl(fun(#peer{pid = P, apps = As, caps = C}, N) -> - insert_peer(P, lists:map(F, As), C, T), - N+1 - end, - 0, - PeerT); - -%% Populate #peer{} table given a shared peers dict. -init_peers({PeerT, _, _}, SDict) -> - dict:fold(fun(P, As, N) -> - ets:update_element(PeerT, P, {#peer.apps, As}), - N+1 - end, - 0, - diameter_dict:fold(fun(A, Ps, D) -> - init_peers(A, Ps, PeerT, D) - end, - dict:new(), - SDict)). - -%% init_peers/4 - -init_peers(App, TCs, PeerT, Dict) -> - lists:foldl(fun({P,C}, D) -> - ets:insert(PeerT, #peer{pid = P, - apps = [], - caps = C}), - dict:append(P, App, D) - end, - Dict, - TCs). - %% =========================================================================== %% =========================================================================== @@ -768,7 +684,7 @@ cfg_acc({SvcName, #diameter_service{applications = Apps} = Rec, Opts}, local = init_peers(), remote = init_peers(), monitor = mref(get_value(monitor, Opts)), - options = service_options(Opts)}, + options = service_options(lists:keydelete(monitor, 1, Opts))}, {S, Acc}; cfg_acc({_Ref, Type, _Opts} = T, {S, Acc}) @@ -784,24 +700,14 @@ init_peers() -> %% TPid} service_options(Opts) -> - [{sequence, proplists:get_value(sequence, Opts, ?NOMASK)}, - {share_peers, get_value(share_peers, Opts)}, - {use_shared_peers, get_value(use_shared_peers, Opts)}, - {restrict_connections, proplists:get_value(restrict_connections, - Opts, - ?RESTRICT)}, - {spawn_opt, proplists:get_value(spawn_opt, Opts, [])}, - {string_decode, proplists:get_value(string_decode, Opts, true)}, - {incoming_maxlen, proplists:get_value(incoming_maxlen, Opts, 16#FFFFFF)}, - {strict_mbit, proplists:get_value(strict_mbit, Opts, true)}]. -%% The order of options is significant since we match against the list. + maps:from_list(Opts). mref(false = No) -> No; mref(P) -> monitor(process, P). -init_shared(#state{options = [_, _, {_,T} | _], +init_shared(#state{options = #{use_shared_peers := T}, service_name = Svc}) -> notify(T, Svc, {service, self()}). @@ -899,7 +805,8 @@ start(Ref, Type, Opts, State) -> start(Ref, Type, Opts, N, #state{watchdogT = WatchdogT, local = {PeerT, _, _}, - options = SvcOpts, + options = #{string_decode := SD} + = SvcOpts0, service_name = SvcName, service = Svc0}) when Type == connect; @@ -907,21 +814,25 @@ start(Ref, Type, Opts, N, #state{watchdogT = WatchdogT, #diameter_service{applications = Apps} = Svc1 = merge_service(Opts, Svc0), - Svc = binary_caps(Svc1, proplists:get_value(string_decode, SvcOpts, true)), - RecvData = diameter_traffic:make_recvdata([SvcName, - PeerT, - Apps, - SvcOpts]), - T = {{spawn_opts([Opts, SvcOpts]), RecvData}, Opts, SvcOpts, Svc}, + Svc = binary_caps(Svc1, SD), + SvcOpts = merge_options(Opts, SvcOpts0), + RecvData = diameter_traffic:make_recvdata([SvcName, PeerT, Apps, SvcOpts]), + T = {Opts, SvcOpts, RecvData, Svc}, Rec = #watchdog{type = Type, ref = Ref, options = Opts}, + diameter_lib:fold_n(fun(_,A) -> [wd(Type, Ref, T, WatchdogT, Rec) | A] end, [], N). +merge_options(Opts, SvcOpts) -> + Keys = maps:keys(SvcOpts), + Map = maps:from_list([KV || {K,_} = KV <- Opts, lists:member(K, Keys)]), + maps:merge(SvcOpts, Map). + binary_caps(Svc, true) -> Svc; binary_caps(#diameter_service{capabilities = Caps} = Svc, false) -> @@ -936,12 +847,6 @@ wd(Type, Ref, T, WatchdogT, Rec) -> %% record so that each watchdog may get a different record. This %% record is what is passed back into application callbacks. -spawn_opts(Optss) -> - SpawnOpts = get_value(spawn_opt, Optss, []), - [T || T <- SpawnOpts, - T /= link, - T /= monitor]. - start_watchdog(Type, Ref, T) -> {_MRef, Pid} = diameter_watchdog:start({Type, Ref}, T), Pid. @@ -1154,18 +1059,6 @@ keyfind([Key | Rest], Pos, L) -> T end. -%% get_value/3 - -get_value(_, [], Def) -> - Def; -get_value(Key, [L | Rest], Def) -> - case lists:keyfind(Key, 1, L) of - {_,V} -> - V; - _ -> - get_value(Key, Rest, Def) - end. - %% find_outgoing_app/2 find_outgoing_app(Alias, Apps) -> @@ -1463,19 +1356,19 @@ send_event(#diameter_event{service = SvcName} = E) -> %% # share_peer/5 %% --------------------------------------------------------------------------- -share_peer(up, Caps, Apps, TPid, #state{options = [_, {_,T} | _], +share_peer(up, Caps, Apps, TPid, #state{options = #{share_peers := SP}, service_name = Svc}) -> - notify(T, Svc, {peer, TPid, [A || {_,A} <- Apps], Caps}); + notify(SP, Svc, {peer, TPid, [A || {_,A} <- Apps], Caps}); -share_peer(down, _Caps, _Apps, TPid, #state{options = [_, {_,T} | _], +share_peer(down, _Caps, _Apps, TPid, #state{options = #{share_peers := SP}, service_name = Svc}) -> - notify(T, Svc, {peer, TPid}). + notify(SP, Svc, {peer, TPid}). %% --------------------------------------------------------------------------- %% # share_peers/2 %% --------------------------------------------------------------------------- -share_peers(Pid, #state{options = [_, {_,SP} | _], +share_peers(Pid, #state{options = #{share_peers := SP}, local = {PeerT, AppT, _}}) -> is_remote(Pid, SP) andalso ets:foldl(fun(T, N) -> N + sp(Pid, AppT, T) end, @@ -1507,7 +1400,8 @@ is_remote(Pid, T) -> %% # remote_peer_up/4 %% --------------------------------------------------------------------------- -remote_peer_up(TPid, Aliases, Caps, #state{options = [_, _, {_,T} | _]} = S) -> +remote_peer_up(TPid, Aliases, Caps, #state{options = #{use_shared_peers := T}} + = S) -> is_remote(TPid, T) andalso rpu(TPid, Aliases, Caps, S). rpu(TPid, Aliases, Caps, #state{service = Svc, remote = RT}) -> @@ -1629,8 +1523,14 @@ pick_peer(Local, %% peers/4 -peers(Alias, RH, Filter, T) -> - filter(Alias, RH, Filter, T, true). +%% No peer options pointing at specific peers: search for them. +peers(Alias, RH, [Filter], T) -> + filter(Alias, RH, Filter, T, true); + +%% Or just lookup. +peers(_Alias, RH, [Filter | TPids], {PeerT, _AppT, _IdentT}) -> + {Ts, _} = filter(caps(PeerT, TPids), RH, Filter), + Ts. %% filter/5 %% diff --git a/lib/diameter/src/base/diameter_traffic.erl b/lib/diameter/src/base/diameter_traffic.erl index bc1ccf4feb..85378babea 100644 --- a/lib/diameter/src/base/diameter_traffic.erl +++ b/lib/diameter/src/base/diameter_traffic.erl @@ -30,7 +30,7 @@ -export([send_request/4]). %% towards diameter_watchdog --export([receive_message/6]). +-export([receive_message/5]). %% towards diameter_peer_fsm and diameter_watchdog -export([incr/4, @@ -54,8 +54,7 @@ -define(RELAY, ?DIAMETER_DICT_RELAY). -define(BASE, ?DIAMETER_DICT_COMMON). %% Note: the RFC 3588 dictionary --define(DEFAULT_TIMEOUT, 5000). %% for outgoing requests --define(DEFAULT_SPAWN_OPTS, []). +-define(DEFAULT(V, Def), if V == undefined -> Def; true -> V end). %% Table containing outgoing entries that live and die with %% peer_up/down. The name is historic, since the table used to contain @@ -65,9 +64,10 @@ %% Record diameter:call/4 options are parsed into. -record(options, - {filter = none :: diameter:peer_filter(), + {peers = [] :: [diameter:peer_ref()], + filter = none :: diameter:peer_filter(), extra = [] :: list(), - timeout = ?DEFAULT_TIMEOUT :: 0..16#FFFFFFFF, + timeout = 5000 :: 0..16#FFFFFFFF, %% for outgoing requests detach = false :: boolean()}). %% Term passed back to receive_message/6 with every incoming message. @@ -76,9 +76,9 @@ service_name :: diameter:service_name(), apps :: [#diameter_app{}], sequence :: diameter:sequence(), - codec :: [{string_decode, boolean()} - | {strict_mbit, boolean()} - | {incoming_maxlen, diameter:message_length()}]}). + codec :: #{string_decode := boolean(), + strict_mbit := boolean(), + incoming_maxlen := diameter:message_length()}}). %% Note that incoming_maxlen is currently handled in diameter_peer_fsm, %% so that any message exceeding the maximum is discarded. Retain the %% option in case we want to extend the values and semantics. @@ -88,24 +88,25 @@ {ref :: reference(), %% used to receive answer caller :: pid() | undefined, %% calling process handler :: pid(), %% request process - transport :: pid() | undefined, %% peer process - caps :: #diameter_caps{} | undefined, %% of connection + peer :: undefined | {pid(), #diameter_caps{}}, packet :: #diameter_packet{} | undefined}). %% of request %% --------------------------------------------------------------------------- -%% # make_recvdata/1 +%% make_recvdata/1 %% --------------------------------------------------------------------------- make_recvdata([SvcName, PeerT, Apps, SvcOpts | _]) -> - {_,_} = Mask = proplists:get_value(sequence, SvcOpts), - #recvdata{service_name = SvcName, - peerT = PeerT, - apps = Apps, - sequence = Mask, - codec = [T || {K,_} = T <- SvcOpts, - lists:member(K, [string_decode, - incoming_maxlen, - strict_mbit])]}. + #{sequence := {_,_} = Mask, spawn_opt := Opts} + = SvcOpts, + {Opts, #recvdata{service_name = SvcName, + peerT = PeerT, + apps = Apps, + sequence = Mask, + codec = maps:with([string_decode, + strict_mbit, + ordered_encode, + incoming_maxlen], + SvcOpts)}}. %% --------------------------------------------------------------------------- %% peer_up/1 @@ -206,42 +207,40 @@ incr_rc(Dir, Pkt, TPid, Dict0) -> incr_rc(Dir, Pkt, TPid, {Dict0, Dict0, Dict0}). %% --------------------------------------------------------------------------- -%% # receive_message/6 +%% receive_message/5 %% -%% Handle an incoming Diameter message. +%% Handle an incoming Diameter message in a watchdog process. %% --------------------------------------------------------------------------- -%% Handle an incoming Diameter message in the watchdog process. - -receive_message(TPid, Route, Pkt, false, Dict0, RecvData) -> - incoming(TPid, Route, Pkt, Dict0, RecvData); - -receive_message(TPid, Route, Pkt, NPid, Dict0, RecvData) -> - NPid ! {diameter, incoming(TPid, Route, Pkt, Dict0, RecvData)}. - -%% incoming/4 - -incoming(TPid, Route, Pkt, Dict0, RecvData) - when is_pid(TPid) -> +-spec receive_message(pid(), Route, #diameter_packet{}, module(), RecvData) + -> pid() %% request handler + | boolean() %% answer, known request or not + | discard %% request discarded by MFA + when Route :: {Handler, RequestRef, Seqs} + | Ack, + RecvData :: {[SpawnOpt], #recvdata{}}, + SpawnOpt :: term(), + Handler :: pid(), + RequestRef :: reference(), + Seqs :: {0..16#FFFFFFFF, 0..16#FFFFFFFF}, + Ack :: boolean(). + +receive_message(TPid, Route, Pkt, Dict0, RecvData) -> #diameter_packet{header = #diameter_header{is_request = R}} = Pkt, recv(R, Route, TPid, Pkt, Dict0, RecvData). %% recv/6 %% Incoming request ... -recv(true, false, TPid, Pkt, Dict0, T) -> - try - {request, spawn_request(TPid, Pkt, Dict0, T)} - catch - error: system_limit = E -> %% discard - ?LOG(error, E), - discard - end; +recv(true, Ack, TPid, Pkt, Dict0, T) + when is_boolean(Ack) -> + {Opts, RecvData} = T, + spawn_request(Ack, TPid, Pkt, Dict0, RecvData, Opts); %% ... answer to known request ... recv(false, {Pid, Ref, TPid}, _, Pkt, Dict0, _) -> Pid ! {answer, Ref, TPid, Dict0, Pkt}, - {answer, Pid}; + true; %% Note that failover could have happened prior to this message being %% received and triggering failback. That is, both a failover message @@ -256,69 +255,91 @@ recv(false, {Pid, Ref, TPid}, _, Pkt, Dict0, _) -> recv(false, false, TPid, Pkt, _, _) -> ?LOG(discarded, Pkt#diameter_packet.header), incr(TPid, {{unknown, 0}, recv, discarded}), - discard. + false. -%% spawn_request/4 +%% spawn_request/6 + +%% An MFA should return a pid() or the atom 'discard'. The latter +%% results in an acknowledgment back to the transport process when +%% appropriate, to ensure that send/recv callbacks can count +%% outstanding requests. Acknowledgement is implicit if the +%% handler process dies (in a handle_request callback for example). +spawn_request(Ack, TPid, Pkt, Dict0, RecvData, {M,F,A}) -> + ReqF = fun() -> + ack(Ack, TPid, recv_request(Ack, TPid, Pkt, Dict0, RecvData)) + end, + ack(Ack, TPid, apply(M, F, [ReqF | A])); -spawn_request(TPid, Pkt, Dict0, {Opts, RecvData}) -> - spawn_request(TPid, Pkt, Dict0, Opts, RecvData); -spawn_request(TPid, Pkt, Dict0, RecvData) -> - spawn_request(TPid, Pkt, Dict0, ?DEFAULT_SPAWN_OPTS, RecvData). +%% A spawned process acks implicitly when it dies, so there's no need +%% to handle 'discard'. +spawn_request(Ack, TPid, Pkt, Dict0, RecvData, Opts) -> + spawn_opt(fun() -> + recv_request(Ack, TPid, Pkt, Dict0, RecvData) + end, + Opts). -spawn_request(TPid, Pkt, Dict0, Opts, RecvData) -> - spawn_opt(fun() -> recv_request(TPid, Pkt, Dict0, RecvData) end, Opts). +%% ack/3 + +ack(Ack, TPid, RC) -> + RC == discard andalso Ack andalso (TPid ! {send, false}), + RC. %% --------------------------------------------------------------------------- -%% recv_request/4 +%% recv_request/5 %% --------------------------------------------------------------------------- -recv_request(TPid, +-spec recv_request(Ack :: boolean(), + TPid :: pid(), + #diameter_packet{}, + Dict0 :: module(), + #recvdata{}) + -> ok %% answer was sent + | discard %% or not + | false. %% no transport + +recv_request(Ack, + TPid, #diameter_packet{header = #diameter_header{application_id = Id}} = Pkt, Dict0, #recvdata{peerT = PeerT, - apps = Apps, - codec = Opts} + apps = Apps} = RecvData) -> - diameter_codec:setopts([{common_dictionary, Dict0} | Opts]), - send_A(recv_R(diameter_service:find_incoming_app(PeerT, TPid, Id, Apps), - TPid, - Pkt, - Dict0, - RecvData), - TPid, - Dict0, - RecvData). - -%% recv_R/5 - -recv_R({#diameter_app{id = Id, dictionary = AppDict} = App, Caps}, - TPid, - Pkt0, - Dict0, - RecvData) -> - incr(recv, Pkt0, TPid, AppDict), - Pkt = errors(Id, diameter_codec:decode(Id, AppDict, Pkt0)), - incr_error(recv, Pkt, TPid, AppDict), - {Caps, Pkt, App, recv_R(App, TPid, Dict0, Caps, RecvData, Pkt)}; -%% Note that the decode is different depending on whether or not Id is -%% ?APP_ID_RELAY. - -%% DIAMETER_APPLICATION_UNSUPPORTED 3007 -%% A request was sent for an application that is not supported. - -recv_R(#diameter_caps{} - = Caps, - _TPid, - #diameter_packet{errors = Es} - = Pkt, - _Dict0, - _RecvData) -> - {Caps, Pkt#diameter_packet{avps = collect_avps(Pkt), - errors = [3007 | Es]}}; + Ack andalso (TPid ! {handler, self()}), + case diameter_service:find_incoming_app(PeerT, TPid, Id, Apps) of + {#diameter_app{id = Aid, dictionary = AppDict} = App, Caps} -> + incr(recv, Pkt, TPid, AppDict), + DecPkt = decode(Aid, AppDict, RecvData, Pkt), + incr_error(recv, DecPkt, TPid, AppDict), + send_A(recv_R(App, TPid, Dict0, Caps, RecvData, DecPkt), + TPid, + App, + Dict0, + RecvData, + DecPkt, + Caps); + #diameter_caps{} = Caps -> + %% DIAMETER_APPLICATION_UNSUPPORTED 3007 + %% A request was sent for an application that is not + %% supported. + RC = 3007, + Es = Pkt#diameter_packet.errors, + DecPkt = Pkt#diameter_packet{avps = collect_avps(Pkt), + errors = [RC | Es]}, + send_answer(answer_message(RC, Dict0, Caps, DecPkt), + TPid, + Dict0, + Dict0, + Dict0, + RecvData, + DecPkt, + [[]]); + false = No -> %% transport has gone down + No + end. -recv_R(false = No, _, _, _, _) -> %% transport has gone down - No. +decode(Id, Dict, #recvdata{codec = Opts}, Pkt) -> + errors(Id, diameter_codec:decode(Id, Dict, Opts, Pkt)). collect_avps(Pkt) -> case diameter_codec:collect_avps(Pkt) of @@ -328,6 +349,14 @@ collect_avps(Pkt) -> Avps end. +%% send_A/7 + +send_A([T | Fs], TPid, App, Dict0, RecvData, DecPkt, Caps) -> + send_A(T, TPid, App, Dict0, RecvData, DecPkt, Caps, Fs); + +send_A(discard = No, _, _, _, _, _, _) -> + No. + %% recv_R/6 %% Answer errors ourselves ... @@ -337,9 +366,9 @@ recv_R(#diameter_app{options = [_, {request_errors, E} | _]}, _Caps, _RecvData, #diameter_packet{errors = [RC|_]}) %% a detected 3xxx is hd - when E == answer, (Dict0 /= ?BASE orelse 3 == RC div 1000); + when E == answer, Dict0 /= ?BASE orelse 3 == RC div 1000; E == answer_3xxx, 3 == RC div 1000 -> - {{answer_message, rc(RC)}, [], []}; + [{answer_message, rc(RC)}, []]; %% ... or make a handle_request callback. Note that %% Pkt#diameter_packet.msg = undefined in the 3001 case. @@ -431,24 +460,24 @@ errors(_, Pkt) -> %% command code in this case. It will also then ignore Dict and use %% the base encoder. request_cb({reply, _Ans} = T, _App, EvalPktFs, EvalFs) -> - {T, EvalPktFs, EvalFs}; + [T, EvalPktFs | EvalFs]; %% An 3xxx result code, for which the E-bit is set in the header. request_cb({protocol_error, RC}, _App, EvalPktFs, EvalFs) when 3 == RC div 1000 -> - {{answer_message, RC}, EvalPktFs, EvalFs}; + [{answer_message, RC}, EvalPktFs | EvalFs]; request_cb({answer_message, RC} = T, _App, EvalPktFs, EvalFs) when 3 == RC div 1000; 5 == RC div 1000 -> - {T, EvalPktFs, EvalFs}; + [T, EvalPktFs | EvalFs]; %% RFC 3588 says we must reply 3001 to anything unrecognized or %% unsupported. 'noreply' is undocumented (and inappropriately named) %% backwards compatibility for this, protocol_error the documented %% alternative. request_cb(noreply, _App, EvalPktFs, EvalFs) -> - {{answer_message, 3001}, EvalPktFs, EvalFs}; + [{answer_message, 3001}, EvalPktFs | EvalFs]; %% Relay a request to another peer. This is equivalent to doing an %% explicit call/4 with the message in question except that (1) a loop @@ -469,7 +498,7 @@ request_cb({A, Opts}, #diameter_app{id = Id}, EvalPktFs, EvalFs) when A == relay, Id == ?APP_ID_RELAY; A == proxy, Id /= ?APP_ID_RELAY; A == resend -> - {{call, Opts}, EvalPktFs, EvalFs}; + [{call, Opts}, EvalPktFs | EvalFs]; request_cb(discard = No, _, _, _) -> No; @@ -483,71 +512,95 @@ request_cb({eval, RC, F}, App, EvalPktFs, Fs) -> request_cb(T, App, _, _) -> ?ERROR({invalid_return, T, handle_request, App}). -%% send_A/4 +%% send_A/8 -send_A({Caps, Pkt}, TPid, Dict0, _RecvData) -> %% unsupported application - #diameter_packet{errors = [RC|_]} = Pkt, - send_A(answer_message(RC, Caps, Dict0, Pkt), - TPid, - {Dict0, Dict0}, - Pkt, - [], - []); +send_A({reply, Ans}, TPid, App, Dict0, RecvData, Pkt, _Caps, Fs) -> + AppDict = App#diameter_app.dictionary, + MsgDict = msg_dict(AppDict, Dict0, Ans), + send_answer(Ans, + TPid, + MsgDict, + AppDict, + Dict0, + RecvData, + Pkt, + Fs); + +send_A({call, Opts}, TPid, App, Dict0, RecvData, Pkt, Caps, Fs) -> + AppDict = App#diameter_app.dictionary, + case resend(Opts, Caps, Pkt, App, Dict0, RecvData) of + #diameter_packet{bin = Bin} = Ans -> %% answer: reset hop by hop id + #diameter_packet{header = #diameter_header{hop_by_hop_id = Id}, + transport_data = TD} + = Pkt, + Reset = diameter_codec:hop_by_hop_id(Id, Bin), + MsgDict = msg_dict(AppDict, Dict0, Ans), + send_answer(Ans#diameter_packet{bin = Reset, + transport_data = TD}, + TPid, + MsgDict, + AppDict, + Dict0, + Fs); + RC -> + send_answer(answer_message(RC, Dict0, Caps, Pkt), + TPid, + Dict0, + AppDict, + Dict0, + RecvData, + Pkt, + Fs) + end; + +%% RFC 3588 only allows 3xxx errors in an answer-message. RFC 6733 +%% added the possibility of setting 5xxx. + +send_A({answer_message, RC} = T, TPid, App, Dict0, RecvData, Pkt, Caps, Fs) -> + Dict0 /= ?BASE orelse 3 == RC div 1000 + orelse ?ERROR({invalid_return, T, handle_request, App}), + send_answer(answer_message(RC, Dict0, Caps, Pkt), + TPid, + Dict0, + App#diameter_app.dictionary, + Dict0, + RecvData, + Pkt, + Fs). -send_A({Caps, Pkt, App, {T, EvalPktFs, EvalFs}}, TPid, Dict0, RecvData) -> - send_A(answer(T, Caps, Pkt, App, Dict0, RecvData), - TPid, - {App#diameter_app.dictionary, Dict0}, - Pkt, - EvalPktFs, - EvalFs); +%% send_answer/8 -send_A(_, _, _, _) -> - ok. +%% Skip the setting of Result-Code and Failed-AVP's below. This is +%% undocumented and shouldn't be relied on. +send_answer([Ans], TPid, MsgDict, AppDict, Dict0, RecvData, Pkt, Fs) + when [] == Pkt#diameter_packet.errors -> + send_answer(Ans, TPid, MsgDict, AppDict, Dict0, RecvData, Pkt, Fs); +send_answer([Ans], TPid, MsgDict, AppDict, Dict0, RecvData, Pkt0, Fs) -> + Pkt = Pkt0#diameter_packet{errors = []}, + send_answer(Ans, TPid, MsgDict, AppDict, Dict0, RecvData, Pkt, Fs); + +send_answer(Ans, TPid, MsgDict, AppDict, Dict0, RecvData, DecPkt, Fs) -> + Pkt = encode({MsgDict, AppDict}, + TPid, + RecvData#recvdata.codec, + make_answer_packet(Ans, DecPkt, MsgDict, Dict0)), + send_answer(Pkt, TPid, MsgDict, AppDict, Dict0, Fs). -%% send_A/6 +%% send_answer/6 -send_A(T, TPid, {AppDict, Dict0} = DictT0, ReqPkt, EvalPktFs, EvalFs) -> - {MsgDict, Pkt} = reply(T, TPid, DictT0, EvalPktFs, ReqPkt), +send_answer(Pkt, TPid, MsgDict, AppDict, Dict0, [EvalPktFs | EvalFs]) -> + eval_packet(Pkt, EvalPktFs), incr(send, Pkt, TPid, AppDict), incr_rc(send, Pkt, TPid, {MsgDict, AppDict, Dict0}), %% count outgoing - send(TPid, Pkt), + send(TPid, z(Pkt), _Route = self()), lists:foreach(fun diameter_lib:eval/1, EvalFs). -%% answer/6 - -answer({reply, Ans}, _Caps, _Pkt, App, Dict0, _RecvData) -> - {msg_dict(App#diameter_app.dictionary, Dict0, Ans), Ans}; - -answer({call, Opts}, Caps, Pkt, App, Dict0, RecvData) -> - #diameter_caps{origin_host = {OH,_}} - = Caps, - #diameter_packet{avps = Avps} - = Pkt, - {Code, _Flags, Vid} = Dict0:avp_header('Route-Record'), - resend(is_loop(Code, Vid, OH, Dict0, Avps), - Opts, - Caps, - Pkt, - App, - Dict0, - RecvData); - -%% RFC 3588 only allows 3xxx errors in an answer-message. RFC 6733 -%% added the possibility of setting 5xxx. -answer({answer_message, RC} = T, Caps, Pkt, App, Dict0, _RecvData) -> - Dict0 /= ?BASE orelse 3 == RC div 1000 - orelse ?ERROR({invalid_return, T, handle_request, App}), - answer_message(RC, Caps, Dict0, Pkt). - %% msg_dict/3 %% %% Return the dictionary defining the message grammar in question: the %% application dictionary or the common dictionary. -msg_dict(AppDict, Dict0, [Msg]) - when is_list(Msg); - is_tuple(Msg) -> +msg_dict(AppDict, Dict0, [Msg]) -> msg_dict(AppDict, Dict0, Msg); msg_dict(AppDict, Dict0, Msg) -> @@ -578,14 +631,10 @@ is_answer_message(Rec, Dict) -> error:_ -> false end. -%% answer_message/4 +%% resend/6 -answer_message(RC, - #diameter_caps{origin_host = {OH,_}, - origin_realm = {OR,_}}, - Dict0, - Pkt) -> - {Dict0, answer_message(OH, OR, RC, Dict0, Pkt)}. +resend(Opts, Caps, Pkt, App, Dict0, RecvData) -> + resend(is_loop(Dict0, Caps, Pkt), Opts, Caps, Pkt, App, Dict0, RecvData). %% resend/7 @@ -595,8 +644,8 @@ answer_message(RC, %% if one is available, but the peer reporting the error has %% identified a configuration problem. -resend(true, _Opts, Caps, Pkt, _App, Dict0, _RecvData) -> - answer_message(3005, Caps, Dict0, Pkt); +resend(true, _Opts, _Caps, _Pkt, _App, _Dict0, _RecvData) -> + 3005; %% 6.1.8. Relaying and Proxying Requests %% @@ -606,11 +655,9 @@ resend(true, _Opts, Caps, Pkt, _App, Dict0, _RecvData) -> resend(false, Opts, - #diameter_caps{origin_host = {_,OH}} - = Caps, + #diameter_caps{origin_host = {_,OH}}, #diameter_packet{header = Hdr0, - avps = Avps} - = Pkt, + avps = Avps}, App, Dict0, #recvdata{service_name = SvcName, @@ -619,7 +666,12 @@ resend(false, Seq = diameter_session:sequence(Mask), Hdr = Hdr0#diameter_header{hop_by_hop_id = Seq}, Msg = [Hdr, Route | Avps], %% reordered at encode - resend(send_request(SvcName, App, Msg, Opts), Caps, Dict0, Pkt). + case send_request(SvcName, App, Msg, Opts) of + #diameter_packet{} = Ans -> + Ans; + _ -> + 3002 %% DIAMETER_UNABLE_TO_DELIVER. + end. %% The incoming request is relayed with the addition of a %% Route-Record. Note the requirement on the return from call/4 below, %% which places a requirement on the value returned by the @@ -636,96 +688,38 @@ resend(false, %% RFC 6.3 says that a relay agent does not modify Origin-Host but %% says nothing about a proxy. Assume it should behave the same way. -%% resend/4 -%% -%% Relay a reply to a relayed request. - -%% Answer from the peer: reset the hop by hop identifier. -resend(#diameter_packet{bin = B} - = Pkt, - _Caps, - _Dict0, - #diameter_packet{header = #diameter_header{hop_by_hop_id = Id}, - transport_data = TD}) -> - Pkt#diameter_packet{bin = diameter_codec:hop_by_hop_id(Id, B), - transport_data = TD}; -%% TODO: counters +%% is_loop/3 -%% Or not: DIAMETER_UNABLE_TO_DELIVER. -resend(_, Caps, Dict0, Pkt) -> - answer_message(3002, Caps, Dict0, Pkt). +is_loop(Dict0, + #diameter_caps{origin_host = {OH,_}}, + #diameter_packet{avps = Avps}) -> + {Code, _Flags, Vid} = Dict0:avp_header('Route-Record'), + is_loop(Code, Vid, OH, Avps). -%% is_loop/5 +%% is_loop/4 %% %% Is there a Route-Record AVP with our Origin-Host? -is_loop(Code, - Vid, - Bin, - _Dict0, - [#diameter_avp{code = Code, vendor_id = Vid, data = Bin} | _]) -> +is_loop(Code, Vid, Bin, [#diameter_avp{code = Code, + vendor_id = Vid, + data = Bin} + | _]) -> true; -is_loop(_, _, _, _, []) -> +is_loop(_, _, _, []) -> false; -is_loop(Code, Vid, OH, Dict0, [_ | Avps]) +is_loop(Code, Vid, OH, [_ | Avps]) when is_binary(OH) -> - is_loop(Code, Vid, OH, Dict0, Avps); - -is_loop(Code, Vid, OH, Dict0, Avps) -> - is_loop(Code, Vid, Dict0:avp(encode, OH, 'Route-Record'), Dict0, Avps). - -%% reply/5 - -%% Local answer ... -reply({MsgDict, Ans}, TPid, {AppDict, Dict0}, Fs, ReqPkt) -> - local(Ans, TPid, {MsgDict, AppDict, Dict0}, Fs, ReqPkt); - -%% ... or relayed. -reply(#diameter_packet{} = Pkt, _TPid, {AppDict, Dict0}, Fs, _ReqPkt) -> - eval_packet(Pkt, Fs), - {msg_dict(AppDict, Dict0, Pkt), Pkt}. - -%% local/5 -%% -%% Send a locally originating reply. - -%% Skip the setting of Result-Code and Failed-AVP's below. This is -%% undocumented and shouldn't be relied on. -local([Msg], TPid, DictT, Fs, ReqPkt) - when is_list(Msg); - is_tuple(Msg) -> - local(Msg, TPid, DictT, Fs, ReqPkt#diameter_packet{errors = []}); - -local(Msg, TPid, {MsgDict, AppDict, Dict0}, Fs, ReqPkt) -> - Pkt = encode({MsgDict, AppDict}, - TPid, - reset(make_answer_packet(Msg, ReqPkt), MsgDict, Dict0), - Fs), - {MsgDict, Pkt}. - -%% reset/3 + is_loop(Code, Vid, OH, Avps); -%% Header/avps list: send as is. -reset(#diameter_packet{msg = [#diameter_header{} | _]} = Pkt, _, _) -> - Pkt; - -%% No errors to set or errors explicitly ignored. -reset(#diameter_packet{errors = Es} = Pkt, _, _) - when Es == []; - Es == false -> - Pkt; - -%% Otherwise possibly set Result-Code and/or Failed-AVP. -reset(#diameter_packet{msg = Msg, errors = Es} = Pkt, Dict, Dict0) -> - {RC, Failed} = select_error(Msg, Es, Dict0), - Pkt#diameter_packet{msg = reset(Msg, Dict, RC, Failed)}. +is_loop(Code, Vid, OH, Avps) -> + is_loop(Code, Vid, list_to_binary(OH), Avps). %% select_error/3 %% %% Extract the first appropriate RC or {RC, #diameter_avp{}} -%% pair from an errors list, and accumulate all #diameter_avp{}. +%% pair from an errors list, along with any leading #diameter_avp{}. %% %% RFC 6733: %% @@ -740,95 +734,138 @@ reset(#diameter_packet{msg = Msg, errors = Es} = Pkt, Dict, Dict0) -> %% indicated by the Result-Code AVP. For practical purposes, this %% Failed-AVP would typically refer to the first AVP processing error %% that a Diameter node encounters. +%% +%% 3xxx can only be set in an answer setting the E-bit. RFC 6733 also +%% allows 5xxx, RFC 3588 doesn't. -select_error(Msg, Es, Dict0) -> - {RC, Avps} = lists:foldl(fun(T,A) -> select(T, A, Dict0) end, - {is_answer_message(Msg, Dict0), []}, - Es), - {RC, lists:reverse(Avps)}. +select_error(E, Es, Dict0) -> + select(E, Es, Dict0, []). -%% Only integer() and {integer(), #diameter_avp{}} are the result of -%% decode. #diameter_avp{} can only be set in a reply for encode. +%% select/4 -select(#diameter_avp{} = A, {RC, As}, _) -> - {RC, [A|As]}; +select(E, [{RC, _} = T | Es], Dict0, Avps) -> + select(E, RC, T, Es, Dict0, Avps); -select(_, {RC, _} = Acc, _) - when is_integer(RC) -> - Acc; +select(E, [#diameter_avp{} = A | Es], Dict0, Avps) -> + select(E, Es, Dict0, [A | Avps]); -select({RC, #diameter_avp{} = A}, {IsAns, As} = Acc, Dict0) - when is_integer(RC) -> - case is_result(RC, IsAns, Dict0) of - true -> {RC, [A|As]}; - false -> Acc - end; +select(E, [RC | Es], Dict0, Avps) -> + select(E, RC, RC, Es, Dict0, Avps); -select(RC, {IsAns, As} = Acc, Dict0) - when is_boolean(IsAns), is_integer(RC) -> - case is_result(RC, IsAns, Dict0) of - true -> {RC, As}; - false -> Acc - end. +select(_, [], _, Avps) -> + Avps. -%% reset/4 +%% select/6 + +select(E, RC, T, _, Dict0, Avps) + when E, 3000 =< RC, RC < 4000; %% E-bit with 3xxx + E, ?BASE /= Dict0, 5000 =< RC, RC < 6000; %% E-bit with 5xxx + not E, RC < 3000 orelse 4000 =< RC -> %% no E-bit + [T | Avps]; -reset(Msg, Dict, RC, Avps) -> - FailedAVP = failed_avp(Msg, Avps, Dict), - ResultCode = rc(Msg, RC, Dict), - set(set(Msg, FailedAVP, Dict), ResultCode, Dict). +select(E, _, _, Es, Dict0, Avps) -> + select(E, Es, Dict0, Avps). %% eval_packet/2 eval_packet(Pkt, Fs) -> lists:foreach(fun(F) -> diameter_lib:eval([F,Pkt]) end, Fs). -%% make_answer_packet/2 +%% make_answer_packet/4 %% Use decode errors to set Result-Code and/or Failed-AVP unless the %% the errors field has been explicitly set. Unfortunately, the %% default value is the empty list rather than 'undefined' so use the %% atom 'false' for "set nothing". (This is historical and changing -%% the default value would require modules including diameter.hrl to -%% be recompiled.) -make_answer_packet(#diameter_packet{errors = []} - = Pkt, - #diameter_packet{errors = [_|_] = Es} - = ReqPkt) -> - make_answer_packet(Pkt#diameter_packet{errors = Es}, ReqPkt); +%% the default value would impact anyone expecting relying on the old +%% default.) -%% A reply message clears the R and T flags and retains the P flag. -%% The E flag will be set at encode. 6.2 of 3588 requires the same P -%% flag on an answer as on the request. A #diameter_packet{} returned -%% from a handle_request callback can circumvent this by setting its -%% own header values. make_answer_packet(#diameter_packet{header = Hdr, msg = Msg, errors = Es, transport_data = TD}, - #diameter_packet{header = ReqHdr}) -> - Hdr0 = ReqHdr#diameter_header{version = ?DIAMETER_VERSION, - is_request = false, - is_error = undefined, - is_retransmitted = false}, - #diameter_packet{header = fold_record(Hdr0, Hdr), - msg = Msg, - errors = Es, + #diameter_packet{header = Hdr0, + errors = Es0}, + MsgDict, + Dict0) -> + #diameter_packet{header = make_answer_header(Hdr0, Hdr), + msg = reset(Msg, Es0, Es, MsgDict, Dict0), transport_data = TD}; %% Binaries and header/avp lists are sent as-is. -make_answer_packet(Bin, #diameter_packet{transport_data = TD}) +make_answer_packet(Bin, #diameter_packet{transport_data = TD}, _, _) when is_binary(Bin) -> #diameter_packet{bin = Bin, transport_data = TD}; make_answer_packet([#diameter_header{} | _] = Msg, - #diameter_packet{transport_data = TD}) -> + #diameter_packet{transport_data = TD}, + _, + _) -> #diameter_packet{msg = Msg, transport_data = TD}; -%% Otherwise, preserve transport_data. -make_answer_packet(Msg, #diameter_packet{transport_data = TD} = Pkt) -> - make_answer_packet(#diameter_packet{msg = Msg, transport_data = TD}, Pkt). +make_answer_packet(Msg, + #diameter_packet{header = Hdr, + errors = Es, + transport_data = TD}, + MsgDict, + Dict0) -> + #diameter_packet{header = make_answer_header(Hdr, undefined), + msg = reset(Msg, [], Es, MsgDict, Dict0), + transport_data = TD}. + +%% make_answer_header/2 + +%% A reply message clears the R and T flags and retains the P flag. +%% The E flag will be set at encode. 6.2 of 3588 requires the same P +%% flag on an answer as on the request. A #diameter_packet{} returned +%% from a handle_request callback can circumvent this by setting its +%% own header values. +make_answer_header(ReqHdr, Hdr) -> + Hdr0 = ReqHdr#diameter_header{version = ?DIAMETER_VERSION, + is_request = false, + is_error = undefined, + is_retransmitted = false}, + fold_record(Hdr0, Hdr). + +%% reset/5 + +reset(Msg, [_|_] = Es0, [] = Es, MsgDict, Dict0) -> + reset(Msg, Es, Es0, MsgDict, Dict0); + +reset(Msg, _, Es, _, _) + when Es == false; + Es == [] -> + Msg; + +reset(Msg, _, Es, MsgDict, Dict0) -> + E = is_answer_message(Msg, Dict0), + reset(Msg, select_error(E, Es, Dict0), choose(E, Dict0, MsgDict)). + +%% reset/4 +%% +%% Set Result-Code and/or Failed-AVP (maybe). Only RC and {RC, AVP} +%% are the result of decode. AVP or {RC, [AVP]} can be set in an +%% answer for encode, as a convenience for injecting additional AVPs +%% into Failed-AVP; eg. 5001 = DIAMETER_AVP_UNSUPPORTED. + +reset(Msg, [], _) -> + Msg; + +reset(Msg, [{RC, As} | Avps], Dict) + when is_list(As) -> + reset(Msg, [RC | As ++ Avps], Dict); + +reset(Msg, [{RC, Avp} | Avps], Dict) -> + reset(Msg, [RC, Avp | Avps], Dict); + +reset(Msg, [#diameter_avp{} | _] = Avps, Dict) -> + set(Msg, failed_avp(Msg, Avps, Dict), Dict); + +reset(Msg, [RC | Avps], Dict) -> + set(Msg, rc(Msg, RC, Dict) ++ failed_avp(Msg, Avps, Dict), Dict). + +%% set/3 %% Reply as name and tuple list ... set([_|_] = Ans, Avps, _) -> @@ -842,11 +879,7 @@ set(Rec, Avps, Dict) -> %% %% Turn the result code into a list if its optional and only set it if %% the arity is 1 or {0,1}. In other cases (which probably shouldn't -%% exist in practise) we can't know what's appropriate. - -rc(_, B, _) - when is_boolean(B) -> - []; +%% exist in practice) we can't know what's appropriate. rc([MsgName | _], RC, Dict) -> K = 'Result-Code', @@ -864,8 +897,8 @@ rc(Rec, RC, Dict) -> failed_avp(_, [] = No, _) -> No; -failed_avp(Rec, Avps, Dict) -> - [failed(Rec, [{'AVP', Avps}], Dict)]. +failed_avp(Msg, [_|_] = Avps, Dict) -> + [failed(Msg, [{'AVP', Avps}], Dict)]. %% Reply as name and tuple list ... failed([MsgName | Values], FailedAvp, Dict) -> @@ -962,22 +995,26 @@ failed(Rec, FailedAvp, Dict) -> %% Error-Message AVP is not intended to be useful in real-time, and %% SHOULD NOT be expected to be parsed by network entities. -%% answer_message/5 +%% answer_message/4 -answer_message(OH, OR, RC, Dict0, #diameter_packet{avps = Avps, - errors = Es}) -> +answer_message(RC, + Dict0, + #diameter_caps{origin_host = {OH,_}, + origin_realm = {OR,_}}, + #diameter_packet{avps = Avps, + errors = Es}) -> {Code, _, Vid} = Dict0:avp_header('Session-Id'), ['answer-message', {'Origin-Host', OH}, {'Origin-Realm', OR}, {'Result-Code', RC}] - ++ session_id(Code, Vid, Dict0, Avps) + ++ session_id(Code, Vid, Avps) ++ failed_avp(RC, Es). -session_id(Code, Vid, Dict0, Avps) +session_id(Code, Vid, Avps) when is_list(Avps) -> try #diameter_avp{data = Bin} = find_avp(Code, Vid, Avps), - [{'Session-Id', [Dict0:avp(decode, Bin, 'Session-Id')]}] + [{'Session-Id', [Bin]}] catch error: _ -> [] @@ -1197,8 +1234,6 @@ get_result(Dict, Msg) -> try [throw(A) || N <- ['Result-Code', 'Experimental-Result'], #diameter_avp{} = A <- [get_avp(Dict, N, Msg)]] - of - [] -> false catch #diameter_avp{} = A -> A end. @@ -1207,7 +1242,7 @@ x(T) -> exit(T). %% --------------------------------------------------------------------------- -%% # send_request/4 +%% send_request/4 %% %% Handle an outgoing Diameter request. %% --------------------------------------------------------------------------- @@ -1260,11 +1295,10 @@ answer_rc(_, _, Sent) -> %% %% In the process spawned for the outgoing request. -send_R(SvcName, AppOrAlias, Msg, Opts, Caller) -> - case pick_peer(SvcName, AppOrAlias, Msg, Opts) of - {Transport, Mask, SvcOpts} -> - diameter_codec:setopts(SvcOpts), - send_request(Transport, Mask, Msg, Opts, Caller, SvcName); +send_R(SvcName, AppOrAlias, Msg, CallOpts, Caller) -> + case pick_peer(SvcName, AppOrAlias, Msg, CallOpts) of + {{_,_} = Transport, SvcOpts} -> + send_request(Transport, SvcOpts, Msg, CallOpts, Caller, SvcName); {error, _} = No -> No end. @@ -1272,31 +1306,45 @@ send_R(SvcName, AppOrAlias, Msg, Opts, Caller) -> %% make_options/1 make_options(Options) -> - lists:foldl(fun mo/2, #options{}, Options). - -mo({timeout, T}, Rec) - when is_integer(T), 0 =< T -> - Rec#options{timeout = T}; - -mo({filter, F}, #options{filter = none} = Rec) -> - Rec#options{filter = F}; -mo({filter, F}, #options{filter = {all, Fs}} = Rec) -> - Rec#options{filter = {all, [F | Fs]}}; -mo({filter, F}, #options{filter = F0} = Rec) -> - Rec#options{filter = {all, [F0, F]}}; - -mo({extra, L}, #options{extra = X} = Rec) + make_opts(Options, [], false, [], none, 5000). + +%% Do our own recursion since this is faster than a lists:foldl/3 +%% setting elements in an #options{} accumulator. + +make_opts([], Peers, Detach, Extra, Filter, Tmo) -> + #options{peers = lists:reverse(Peers), + detach = Detach, + extra = Extra, + filter = Filter, + timeout = Tmo}; + +make_opts([{timeout, Tmo} | Rest], Peers, Detach, Extra, Filter, _) + when is_integer(Tmo), 0 =< Tmo -> + make_opts(Rest, Peers, Detach, Extra, Filter, Tmo); + +make_opts([{filter, F} | Rest], Peers, Detach, Extra, none, Tmo) -> + make_opts(Rest, Peers, Detach, Extra, F, Tmo); +make_opts([{filter, F} | Rest], Peers, Detach, Extra, {all, Fs}, Tmo) -> + make_opts(Rest, Peers, Detach, Extra, {all, [F|Fs]}, Tmo); +make_opts([{filter, F} | Rest], Peers, Detach, Extra, F0, Tmo) -> + make_opts(Rest, Peers, Detach, Extra, {all, [F0, F]}, Tmo); + +make_opts([{extra, L} | Rest], Peers, Detach, Extra, Filter, Tmo) when is_list(L) -> - Rec#options{extra = X ++ L}; + make_opts(Rest, Peers, Detach, Extra ++ L, Filter, Tmo); + +make_opts([detach | Rest], Peers, _, Extra, Filter, Tmo) -> + make_opts(Rest, Peers, true, Extra, Filter, Tmo); -mo(detach, Rec) -> - Rec#options{detach = true}; +make_opts([{peer, TPid} | Rest], Peers, Detach, Extra, Filter, Tmo) + when is_pid(TPid) -> + make_opts(Rest, [TPid | Peers], Detach, Extra, Filter, Tmo); -mo(T, _) -> +make_opts([T | _], _, _, _, _, _) -> ?ERROR({invalid_option, T}). %% --------------------------------------------------------------------------- -%% # send_request/6 +%% send_request/6 %% --------------------------------------------------------------------------- %% Send an outgoing request in its dedicated process. @@ -1309,44 +1357,51 @@ mo(T, _) -> %% The module field of the #diameter_app{} here includes any extra %% arguments passed to diameter:call/4. -send_request({TPid, Caps, App} +send_request({{TPid, _Caps} = TC, App} = Transport, - Mask, - Msg, - Opts, + #{sequence := Mask} + = SvcOpts, + Msg0, + CallOpts, Caller, SvcName) -> - Pkt = make_prepare_packet(Mask, Msg), - - send_R(cb(App, prepare_request, [Pkt, SvcName, {TPid, Caps}]), - Pkt, - Transport, - Opts, - Caller, - SvcName, - []). + Pkt = make_prepare_packet(Mask, Msg0), + + case prepare(cb(App, prepare_request, [Pkt, SvcName, TC]), []) of + [Msg | Fs] -> + ReqPkt = make_request_packet(Msg, Pkt), + EncPkt = encode(App#diameter_app.dictionary, + TPid, + SvcOpts, + ReqPkt), + eval_packet(EncPkt, Fs), + T = send_R(ReqPkt, EncPkt, Transport, CallOpts, Caller, SvcName), + Ans = recv_answer(SvcName, App, CallOpts, T), + handle_answer(SvcName, SvcOpts, App, Ans); + {discard, Reason} -> + {error, Reason}; + discard -> + {error, discarded}; + {error, Reason} -> + ?ERROR({invalid_return, Reason, prepare_request, App}) + end. -%% send_R/7 +%% prepare/2 -send_R({send, Msg}, Pkt, Transport, Opts, Caller, SvcName, Fs) -> - send_R(make_request_packet(Msg, Pkt), - Transport, - Opts, - Caller, - SvcName, - Fs); +prepare({send, Msg}, Fs) -> + [Msg | Fs]; -send_R({discard, Reason} , _, _, _, _, _, _) -> - {error, Reason}; +prepare({eval_packet, RC, F}, Fs) -> + prepare(RC, [F|Fs]); -send_R(discard, _, _, _, _, _, _) -> - {error, discarded}; +prepare({discard, _Reason} = RC, _) -> + RC; -send_R({eval_packet, RC, F}, Pkt, T, Opts, Caller, SvcName, Fs) -> - send_R(RC, Pkt, T, Opts, Caller, SvcName, [F|Fs]); +prepare(discard = RC, _) -> + RC; -send_R(E, _, {_, _, App}, _, _, _, _) -> - ?ERROR({invalid_return, E, prepare_request, App}). +prepare(Reason, _) -> + {error, Reason}. %% make_prepare_packet/2 %% @@ -1366,43 +1421,39 @@ make_prepare_packet(Mask, #diameter_packet{msg = [#diameter_header{} = Hdr make_prepare_packet(Mask, #diameter_packet{header = Hdr} = Pkt) -> Pkt#diameter_packet{header = make_prepare_header(Mask, Hdr)}; +make_prepare_packet(Mask, [#diameter_header{} = Hdr | Avps]) -> + #diameter_packet{msg = [make_prepare_header(Mask, Hdr) | Avps]}; + make_prepare_packet(Mask, Msg) -> - make_prepare_packet(Mask, #diameter_packet{msg = Msg}). + #diameter_packet{header = make_prepare_header(Mask, undefined), + msg = Msg}. %% make_prepare_header/2 make_prepare_header(Mask, undefined) -> Seq = diameter_session:sequence(Mask), - make_prepare_header(#diameter_header{end_to_end_id = Seq, - hop_by_hop_id = Seq}); - -make_prepare_header(Mask, #diameter_header{end_to_end_id = undefined, - hop_by_hop_id = undefined} - = H) -> - Seq = diameter_session:sequence(Mask), - make_prepare_header(H#diameter_header{end_to_end_id = Seq, - hop_by_hop_id = Seq}); - -make_prepare_header(Mask, #diameter_header{end_to_end_id = undefined} = H) -> - Seq = diameter_session:sequence(Mask), - make_prepare_header(H#diameter_header{end_to_end_id = Seq}); - -make_prepare_header(Mask, #diameter_header{hop_by_hop_id = undefined} = H) -> - Seq = diameter_session:sequence(Mask), - make_prepare_header(H#diameter_header{hop_by_hop_id = Seq}); - -make_prepare_header(_, Hdr) -> - make_prepare_header(Hdr). - -%% make_prepare_header/1 - -make_prepare_header(#diameter_header{version = undefined} = Hdr) -> - make_prepare_header(Hdr#diameter_header{version = ?DIAMETER_VERSION}); - -make_prepare_header(#diameter_header{} = Hdr) -> - Hdr; - -make_prepare_header(T) -> + #diameter_header{version = ?DIAMETER_VERSION, + end_to_end_id = Seq, + hop_by_hop_id = Seq}; + +make_prepare_header(Mask, #diameter_header{version = V, + end_to_end_id = EI, + hop_by_hop_id = HI} + = H) + when EI == undefined; + HI == undefined -> + Id = diameter_session:sequence(Mask), + H#diameter_header{version = ?DEFAULT(V, ?DIAMETER_VERSION), + end_to_end_id = ?DEFAULT(EI, Id), + hop_by_hop_id = ?DEFAULT(HI, Id)}; + +make_prepare_header(_, #diameter_header{version = undefined} = H) -> + H#diameter_header{version = ?DIAMETER_VERSION}; + +make_prepare_header(_, #diameter_header{} = H) -> + H; + +make_prepare_header(_, T) -> ?ERROR({invalid_header, T}). %% make_request_packet/2 @@ -1446,42 +1497,45 @@ make_retransmit_header(Hdr) -> Hdr#diameter_header{is_retransmitted = true}. %% fold_record/2 +%% +%% Replace elements in the first record by those in the second that +%% differ from undefined. -fold_record(undefined, R) -> - R; -fold_record(Rec, R) -> - diameter_lib:fold_tuple(2, Rec, R). +fold_record(Rec0, undefined) -> + Rec0; +fold_record(Rec0, Rec) -> + list_to_tuple(fold(tuple_to_list(Rec0), tuple_to_list(Rec))). + +fold([], []) -> + []; +fold([H | T0], [undefined | T]) -> + [H | fold(T0, T)]; +fold([_ | T0], [H | T]) -> + [H | fold(T0, T)]. %% send_R/6 -send_R(Pkt0, - {TPid, Caps, #diameter_app{dictionary = AppDict} = App}, - Opts, +send_R(ReqPkt, + EncPkt, + {{TPid, _Caps} = TC, #diameter_app{dictionary = AppDict}}, + #options{timeout = Timeout}, {Pid, Ref}, - SvcName, - Fs) -> - Pkt = encode(AppDict, TPid, Pkt0, Fs), - - #options{timeout = Timeout} - = Opts, - + SvcName) -> Req = #request{ref = Ref, caller = Pid, handler = self(), - transport = TPid, - caps = Caps, - packet = Pkt0}, + peer = TC, + packet = ReqPkt}, - incr(send, Pkt, TPid, AppDict), - {TRef, MRef} = zend_requezt(TPid, Pkt, Req, SvcName, Timeout), + incr(send, EncPkt, TPid, AppDict), + {TRef, MRef} = zend_requezt(TPid, EncPkt, Req, SvcName, Timeout), Pid ! Ref, %% tell caller a send has been attempted - handle_answer(SvcName, - App, - recv_A(Timeout, SvcName, App, Opts, {TRef, MRef, Req})). + {TRef, MRef, Req}. -%% recv_A/5 +%% recv_answer/4 -recv_A(Timeout, SvcName, App, Opts, {TRef, MRef, #request{ref = Ref} = Req}) -> +recv_answer(SvcName, App, CallOpts, {TRef, MRef, #request{ref = Ref} + = Req}) -> %% Matching on TRef below ensures we ignore messages that pertain %% to a previous transport prior to failover. The answer message %% includes the pid of the transport on which it was received, @@ -1492,97 +1546,90 @@ recv_A(Timeout, SvcName, App, Opts, {TRef, MRef, #request{ref = Ref} = Req}) -> {timeout = Reason, TRef, _} -> %% No timely reply {error, Req, Reason}; {'DOWN', MRef, process, _, _} when false /= MRef -> %% local peer_down - failover(SvcName, App, Req, Opts, Timeout); + failover(SvcName, App, Req, CallOpts); {failover, TRef} -> %% local or remote peer_down - failover(SvcName, App, Req, Opts, Timeout) + failover(SvcName, App, Req, CallOpts) end. -%% failover/5 +%% failover/4 -failover(SvcName, App, Req, Opts, Timeout) -> - retransmit(pick_peer(SvcName, App, Req, Opts), - Req, - Opts, - SvcName, - Timeout). +failover(SvcName, App, Req, CallOpts) -> + resend_request(pick_peer(SvcName, App, Req, CallOpts), + Req, + CallOpts, + SvcName). -%% handle_answer/3 +%% handle_answer/4 -handle_answer(SvcName, App, {error, Req, Reason}) -> - handle_error(App, Req, Reason, SvcName); +handle_answer(SvcName, _, App, {error, Req, Reason}) -> + #request{packet = Pkt, + peer = {_TPid, _Caps} = TC} + = Req, + cb(App, handle_error, [Reason, msg(Pkt), SvcName, TC]); handle_answer(SvcName, - #diameter_app{dictionary = AppDict, - id = Id} + SvcOpts, + #diameter_app{id = Id, + dictionary = AppDict, + options = [{answer_errors, AE} | _]} = App, {answer, Req, Dict0, Pkt}) -> MsgDict = msg_dict(AppDict, Dict0, Pkt), - handle_A(errors(Id, diameter_codec:decode({MsgDict, AppDict}, Pkt)), - SvcName, - MsgDict, - Dict0, - App, - Req). - -%% We don't really need to do a full decode if we're a relay and will -%% just resend with a new hop by hop identifier, but might a proxy -%% want to examine the answer? - -handle_A(Pkt, SvcName, Dict, Dict0, App, #request{transport = TPid} = Req) -> - AppDict = App#diameter_app.dictionary, - - incr(recv, Pkt, TPid, AppDict), - - try - incr_result(recv, Pkt, TPid, {Dict, AppDict, Dict0}) %% count incoming - of - _ -> answer(Pkt, SvcName, App, Req) - catch - exit: {no_result_code, _} -> - %% RFC 6733 requires one of Result-Code or - %% Experimental-Result, but the decode will have detected - %% a missing AVP. If both are optional in the dictionary - %% then this isn't a decode error: just continue on. - answer(Pkt, SvcName, App, Req); - exit: {invalid_error_bit, {_, _, _, Avp}} -> - #diameter_packet{errors = Es} - = Pkt, - E = {5004, Avp}, - answer(Pkt#diameter_packet{errors = [E|Es]}, SvcName, App, Req) - end. - -%% answer/4 - -answer(Pkt, - SvcName, - #diameter_app{module = ModX, - options = [{answer_errors, AE} | _]}, - Req) -> - a(Pkt, SvcName, ModX, AE, Req). - --spec a(_, _, _) -> no_return(). %% silence dialyzer - -a(#diameter_packet{errors = Es} - = Pkt, - SvcName, - ModX, - AE, - #request{transport = TPid, - caps = Caps, - packet = P}) - when [] == Es; - callback == AE -> - cb(ModX, handle_answer, [Pkt, msg(P), SvcName, {TPid, Caps}]); - -a(Pkt, SvcName, _, AE, _) -> - a(Pkt#diameter_packet.header, SvcName, AE). - -a(Hdr, SvcName, report) -> + DecPkt = errors(Id, diameter_codec:decode({MsgDict, AppDict}, + SvcOpts, + Pkt)), + #request{peer = {TPid, _}} + = Req, + + incr(recv, DecPkt, TPid, AppDict), + + AnsPkt = try + incr_result(recv, DecPkt, TPid, {MsgDict, AppDict, Dict0}) + of + _ -> DecPkt + catch + exit: {no_result_code, _} -> + %% RFC 6733 requires one of Result-Code or + %% Experimental-Result, but the decode will have + %% detected a missing AVP. If both are optional in + %% the dictionary then this isn't a decode error: + %% just continue on. + DecPkt; + exit: {invalid_error_bit, {_, _, _, Avp}} -> + #diameter_packet{errors = Es} + = DecPkt, + E = {5004, Avp}, + DecPkt#diameter_packet{errors = [E|Es]} + end, + + handle_answer(AnsPkt, SvcName, App, AE, Req). + +%% handle_answer/5 + +handle_answer(#diameter_packet{errors = Es} + = Pkt, + SvcName, + App, + AE, + #request{peer = {_TPid, _Caps} = TC, + packet = P}) + when callback == AE; + [] == Es -> + cb(App, handle_answer, [Pkt, msg(P), SvcName, TC]); + +handle_answer(#diameter_packet{header = H}, SvcName, _, AE, _) -> + handle_error(H, SvcName, AE). + +%% handle_error/3 + +-spec handle_error(_, _, _) -> no_return(). %% silence dialyzer + +handle_error(Hdr, SvcName, report) -> MFA = {?MODULE, handle_answer, [SvcName, Hdr]}, diameter_lib:warning_report(errors, MFA), - a(Hdr, SvcName, discard); + handle_error(Hdr, SvcName, discard); -a(Hdr, SvcName, discard) -> +handle_error(Hdr, SvcName, discard) -> x({answer_errors, {SvcName, Hdr}}). %% Note that we don't check that the application id in the answer's @@ -1593,16 +1640,38 @@ a(Hdr, SvcName, discard) -> %% timer value is ignored. This means that an answer could be accepted %% from a peer after timeout in the case of failover. -%% retransmit/5 +%% resend_request/4 -retransmit({{_,_,App} = Transport, _, _}, Req, Opts, SvcName, Timeout) -> - try retransmit(Transport, Req, SvcName, Timeout) of - T -> recv_A(Timeout, SvcName, App, Opts, T) - catch - ?FAILURE(Reason) -> {error, Req, Reason} +resend_request({{{TPid, _Caps} = TC, App}, SvcOpts}, + Req0, + #options{timeout = Timeout} + = CallOpts, + SvcName) -> + case + undefined == get(TPid) + andalso prepare_retransmit(TC, App, Req0, SvcName) + of + [ReqPkt | Fs] -> + AppDict = App#diameter_app.dictionary, + EncPkt = encode(AppDict, TPid, SvcOpts, ReqPkt), + eval_packet(EncPkt, Fs), + Req = Req0#request{peer = TC, + packet = ReqPkt}, + ?LOG(retransmission, EncPkt#diameter_packet.header), + incr(TPid, {msg_id(EncPkt, AppDict), send, retransmission}), + {TRef, MRef} = zend_requezt(TPid, EncPkt, Req, SvcName, Timeout), + recv_answer(SvcName, App, CallOpts, {TRef, MRef, Req}); + false -> + {error, Req0, timeout}; + {discard, Reason} -> + {error, Req0, Reason}; + discard -> + {error, Req0, discarded}; + {error, T} -> + ?ERROR({invalid_return, T, prepare_retransmit, App}) end; -retransmit(_, Req, _, _, _) -> %% no alternate peer +resend_request(_, Req, _, _) -> %% no alternate peer {error, Req, failover}. %% pick_peer/4 @@ -1612,8 +1681,8 @@ retransmit(_, Req, _, _, _) -> %% no alternate peer pick_peer(SvcName, App, #request{packet = #diameter_packet{msg = Msg}}, - Opts) -> - pick_peer(SvcName, App, Msg, Opts#options{extra = []}); + CallOpts) -> + pick_peer(SvcName, App, Msg, CallOpts#options{extra = []}); pick_peer(_, _, undefined, _) -> {error, no_connection}; @@ -1621,28 +1690,14 @@ pick_peer(_, _, undefined, _) -> pick_peer(SvcName, AppOrAlias, Msg, - #options{filter = Filter, extra = Xtra}) -> - pick(diameter_service:pick_peer(SvcName, - AppOrAlias, - {fun(D) -> get_destination(D, Msg) end, - Filter, - Xtra})). - -pick(false) -> - {error, no_connection}; - -pick(T) -> - T. - -%% handle_error/4 - -handle_error(App, - #request{packet = Pkt, - transport = TPid, - caps = Caps}, - Reason, - SvcName) -> - cb(App, handle_error, [Reason, msg(Pkt), SvcName, {TPid, Caps}]). + #options{peers = TPids, filter = Filter, extra = Xtra}) -> + X = {fun(D) -> get_destination(D, Msg) end, Filter, Xtra, TPids}, + case diameter_service:pick_peer(SvcName, AppOrAlias, X) of + false -> + {error, no_connection}; + T -> + T + end. msg(#diameter_packet{msg = undefined, bin = Bin}) -> Bin; @@ -1651,27 +1706,20 @@ msg(#diameter_packet{msg = Msg}) -> %% encode/4 -encode(Dict, TPid, Pkt, Fs) -> - P = encode(Dict, TPid, Pkt), - eval_packet(P, Fs), - P. - -%% encode/2 - %% Note that prepare_request can return a diameter_packet containing a %% header or transport_data. Even allow the returned record to contain %% an encoded binary. This isn't the usual case and doesn't properly %% support retransmission but is useful for test. -encode(Dict, TPid, Pkt) +encode(Dict, TPid, Opts, Pkt) when is_atom(Dict) -> - encode({Dict, Dict}, TPid, Pkt); + encode({Dict, Dict}, TPid, Opts, Pkt); %% A message to be encoded. -encode(DictT, TPid, #diameter_packet{bin = undefined} = Pkt) -> +encode(DictT, TPid, Opts, #diameter_packet{bin = undefined} = Pkt) -> {Dict, AppDict} = DictT, try - diameter_codec:encode(Dict, Pkt) + diameter_codec:encode(Dict, Opts, Pkt) catch exit: {diameter_codec, encode, T} = Reason -> incr_error(send, T, TPid, AppDict), @@ -1679,7 +1727,7 @@ encode(DictT, TPid, #diameter_packet{bin = undefined} = Pkt) -> end; %% An encoded binary: just send. -encode(_, _, #diameter_packet{} = Pkt) -> +encode(_, _, _, #diameter_packet{} = Pkt) -> Pkt. %% zend_requezt/5 @@ -1745,81 +1793,26 @@ recv(TPid, Pid, TRef, {LocalTRef, MRef}) -> exit({timeout, LocalTRef, TPid} = T) end. -%% send/2 - -send(Pid, Pkt) -> - Pid ! {send, Pkt}. - %% send/3 send(Pid, Pkt, Route) -> Pid ! {send, Pkt, Route}. -%% retransmit/4 +%% prepare_retransmit/4 -retransmit({TPid, Caps, App} - = Transport, - #request{packet = Pkt0} - = Req, - SvcName, - Timeout) -> - undefined == get(TPid) %% Don't failover to a peer we've - orelse ?THROW(timeout), %% already sent to. +prepare_retransmit({_TPid, _Caps} = TC, App, Req, SvcName) -> + Pkt = make_retransmit_packet(Req#request.packet), - Pkt = make_retransmit_packet(Pkt0), + case prepare(cb(App, prepare_retransmit, [Pkt, SvcName, TC]), []) of + [Msg | Fs] -> + [make_request_packet(Msg, Pkt) | Fs]; + No -> + No + end. - retransmit(cb(App, prepare_retransmit, [Pkt, SvcName, {TPid, Caps}]), - Transport, - Req#request{packet = Pkt}, - SvcName, - Timeout, - []). %% When sending a binary, it's up to prepare_retransmit to modify it %% accordingly. -retransmit({send, Msg}, - Transport, - #request{packet = Pkt} - = Req, - SvcName, - Timeout, - Fs) -> - resend_request(make_request_packet(Msg, Pkt), - Transport, - Req, - SvcName, - Timeout, - Fs); - -retransmit({discard, Reason}, _, _, _, _, _) -> - ?THROW(Reason); - -retransmit(discard, _, _, _, _, _) -> - ?THROW(discarded); - -retransmit({eval_packet, RC, F}, Transport, Req, SvcName, Timeout, Fs) -> - retransmit(RC, Transport, Req, SvcName, Timeout, [F|Fs]); - -retransmit(T, {_, _, App}, _, _, _, _) -> - ?ERROR({invalid_return, T, prepare_retransmit, App}). - -resend_request(Pkt0, - {TPid, Caps, #diameter_app{dictionary = AppDict}}, - Req0, - SvcName, - Tmo, - Fs) -> - Pkt = encode(AppDict, TPid, Pkt0, Fs), - - Req = Req0#request{transport = TPid, - packet = Pkt0, - caps = Caps}, - - ?LOG(retransmission, Pkt#diameter_packet.header), - incr(TPid, {msg_id(Pkt, AppDict), send, retransmission}), - {TRef, MRef} = zend_requezt(TPid, Pkt, Req, SvcName, Tmo), - {TRef, MRef, Req}. - %% peer_monitor/2 peer_monitor(TPid, TRef) -> @@ -1919,7 +1912,7 @@ ungroup(Avp) -> avp_decode(Dict, Name, #diameter_avp{value = undefined, data = Bin} = Avp) -> - try Dict:avp(decode, Bin, Name) of + try Dict:avp(decode, Bin, Name, decode_opts(Dict)) of V -> Avp#diameter_avp{value = V} catch @@ -1930,8 +1923,6 @@ avp_decode(_, _, #diameter_avp{} = Avp) -> Avp. cb(#diameter_app{module = [_|_] = M}, F, A) -> - eval(M, F, A); -cb([_|_] = M, F, A) -> eval(M, F, A). eval([M|X], F, A) -> @@ -1939,3 +1930,10 @@ eval([M|X], F, A) -> choose(true, X, _) -> X; choose(false, _, X) -> X. + +%% Decode options sufficient for AVP extraction. +decode_opts(Dict) -> + #{string_decode => false, + strict_mbit => false, + failed_avp => false, + dictionary => Dict}. diff --git a/lib/diameter/src/base/diameter_types.erl b/lib/diameter/src/base/diameter_types.erl index 6ecf385239..86b674dd48 100644 --- a/lib/diameter/src/base/diameter_types.erl +++ b/lib/diameter/src/base/diameter_types.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2015. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -26,32 +26,16 @@ %% %% Basic types. --export(['OctetString'/2, - 'Integer32'/2, - 'Integer64'/2, - 'Unsigned32'/2, - 'Unsigned64'/2, - 'Float32'/2, - 'Float64'/2]). - -%% Derived types. --export(['Address'/2, - 'Time'/2, - 'UTF8String'/2, - 'DiameterIdentity'/2, - 'DiameterURI'/2, - 'IPFilterRule'/2, - 'QoSFilterRule'/2]). - -%% Functions taking the AVP name in question as second parameter. -export(['OctetString'/3, 'Integer32'/3, 'Integer64'/3, 'Unsigned32'/3, 'Unsigned64'/3, 'Float32'/3, - 'Float64'/3, - 'Address'/3, + 'Float64'/3]). + +%% Derived types. +-export(['Address'/3, 'Time'/3, 'UTF8String'/3, 'DiameterIdentity'/3, @@ -89,81 +73,80 @@ %% AVP Data Format is needed, a new version of this RFC must be created. %% -------------------- -'OctetString'(decode, Bin) +'OctetString'(decode, Bin, #{string_decode := true}) when is_binary(Bin) -> - case diameter_codec:getopt(string_decode) of - true -> - binary_to_list(Bin); - false -> - Bin - end; - -'OctetString'(decode, B) -> + binary_to_list(Bin); + +'OctetString'(decode, Bin, _) + when is_binary(Bin) -> + Bin; + +'OctetString'(decode, B, _) -> ?INVALID_LENGTH(B); -'OctetString'(encode = M, zero) -> - 'OctetString'(M, []); +'OctetString'(encode, zero, _) -> + <<>>; -'OctetString'(encode, Str) -> +'OctetString'(encode, Str, _) -> iolist_to_binary(Str). %% -------------------- -'Integer32'(decode, <<X:32/signed>>) -> +'Integer32'(decode, <<X:32/signed>>, _) -> X; -'Integer32'(decode, B) -> +'Integer32'(decode, B, _) -> ?INVALID_LENGTH(B); -'Integer32'(encode = M, zero) -> - 'Integer32'(M, 0); +'Integer32'(encode, zero, _) -> + <<0:32/signed>>; -'Integer32'(encode, I) +'Integer32'(encode, I, _) when ?SINT(32,I) -> <<I:32/signed>>. %% -------------------- -'Integer64'(decode, <<X:64/signed>>) -> +'Integer64'(decode, <<X:64/signed>>, _) -> X; -'Integer64'(decode, B) -> +'Integer64'(decode, B, _) -> ?INVALID_LENGTH(B); -'Integer64'(encode = M, zero) -> - 'Integer64'(M, 0); +'Integer64'(encode, zero, _) -> + <<0:64/signed>>; -'Integer64'(encode, I) +'Integer64'(encode, I, _) when ?SINT(64,I) -> <<I:64/signed>>. %% -------------------- -'Unsigned32'(decode, <<X:32>>) -> +'Unsigned32'(decode, <<X:32>>, _) -> X; -'Unsigned32'(decode, B) -> +'Unsigned32'(decode, B, _) -> ?INVALID_LENGTH(B); -'Unsigned32'(encode = M, zero) -> - 'Unsigned32'(M, 0); +'Unsigned32'(encode, zero, _) -> + <<0:32>>; -'Unsigned32'(encode, I) +'Unsigned32'(encode, I, _) when ?UINT(32,I) -> <<I:32>>. %% -------------------- -'Unsigned64'(decode, <<X:64>>) -> +'Unsigned64'(decode, <<X:64>>, _) -> X; -'Unsigned64'(decode, B) -> +'Unsigned64'(decode, B, _) -> ?INVALID_LENGTH(B); -'Unsigned64'(encode = M, zero) -> - 'Unsigned64'(M, 0); +'Unsigned64'(encode, zero, _) -> + <<0:64>>; -'Unsigned64'(encode, I) +'Unsigned64'(encode, I, _) when ?UINT(64,I) -> <<I:64>>. @@ -184,25 +167,25 @@ %% arithmetic is performed on the decoded value. Better to be explicit %% that precision has been lost. -'Float32'(decode, <<S:1, 255:8, _:23>>) -> +'Float32'(decode, <<S:1, 255:8, _:23>>, _) -> choose(S, infinity, '-infinity'); -'Float32'(decode, <<X:32/float>>) -> +'Float32'(decode, <<X:32/float>>, _) -> X; -'Float32'(decode, B) -> +'Float32'(decode, B, _) -> ?INVALID_LENGTH(B); -'Float32'(encode = M, zero) -> - 'Float32'(M, 0.0); +'Float32'(encode, zero, _) -> + <<0.0:32/float>>; -'Float32'(encode, infinity) -> +'Float32'(encode, infinity, _) -> <<0:1, 255:8, 0:23>>; -'Float32'(encode, '-infinity') -> +'Float32'(encode, '-infinity', _) -> <<1:1, 255:8, 0:23>>; -'Float32'(encode, X) +'Float32'(encode, X, _) when is_float(X) -> <<X:32/float>>. %% Note that this could also encode infinity/-infinity for large @@ -222,25 +205,25 @@ %% The 64 bit format is entirely analogous to the 32 bit format. -'Float64'(decode, <<S:1, 2047:11, _:52>>) -> +'Float64'(decode, <<S:1, 2047:11, _:52>>, _) -> choose(S, infinity, '-infinity'); -'Float64'(decode, <<X:64/float>>) -> +'Float64'(decode, <<X:64/float>>, _) -> X; -'Float64'(decode, B) -> +'Float64'(decode, B, _) -> ?INVALID_LENGTH(B); -'Float64'(encode, infinity) -> +'Float64'(encode, infinity, _) -> <<0:1, 2047:11, 0:52>>; -'Float64'(encode, '-infinity') -> +'Float64'(encode, '-infinity', _) -> <<1:1, 2047:11, 0:52>>; -'Float64'(encode = M, zero) -> - 'Float64'(M, 0.0); +'Float64'(encode, zero, _) -> + <<0.0:64/float>>; -'Float64'(encode, X) +'Float64'(encode, X, _) when is_float(X) -> <<X:64/float>>. @@ -256,18 +239,18 @@ %% format. %% -------------------- -'Address'(encode, zero) -> +'Address'(encode, zero, _) -> <<0:48>>; -'Address'(decode, <<A:16, B/binary>>) +'Address'(decode, <<A:16, B/binary>>, _) when 1 == A, 4 == size(B); 2 == A, 16 == size(B) -> list_to_tuple([N || <<N:A/unit:8>> <= B]); -'Address'(decode, B) -> +'Address'(decode, B, _) -> ?INVALID_LENGTH(B); -'Address'(encode, T) -> +'Address'(encode, T, _) -> Ns = tuple_to_list(diameter_lib:ipaddr(T)), %% length 4 or 8 A = length(Ns) div 4, %% 1 or 2 B = << <<N:A/unit:8>> || N <- Ns >>, @@ -278,36 +261,38 @@ %% A DiameterIdentity is a FQDN as definined in RFC 1035, which is at %% least one character. -'DiameterIdentity'(encode = M, zero) -> - 'OctetString'(M, [0]); +'DiameterIdentity'(encode, zero, _) -> + <<0>>; -'DiameterIdentity'(encode = M, X) -> - <<_,_/binary>> = 'OctetString'(M, X); +'DiameterIdentity'(encode = M, X, Opts) -> + <<_,_/binary>> = 'OctetString'(M, X, Opts); -'DiameterIdentity'(decode = M, <<_,_/binary>> = X) -> - 'OctetString'(M, X); +'DiameterIdentity'(decode = M, <<_,_/binary>> = X, Opts) -> + 'OctetString'(M, X, Opts); -'DiameterIdentity'(decode, X) -> +'DiameterIdentity'(decode, X, _) -> ?INVALID_LENGTH(X). %% -------------------- -'DiameterURI'(decode, Bin) +'DiameterURI'(decode, Bin, Opts) when is_binary(Bin) -> - scan_uri(Bin); + scan_uri(Bin, Opts); -'DiameterURI'(decode, B) -> +'DiameterURI'(decode, B, _) -> ?INVALID_LENGTH(B); %% The minimal DiameterURI is "aaa://x", 7 characters. -'DiameterURI'(encode = M, zero) -> - 'OctetString'(M, lists:duplicate(0,7)); - -'DiameterURI'(encode, #diameter_uri{type = Type, - fqdn = DN, - port = PN, - transport = T, - protocol = P}) +'DiameterURI'(encode, zero, _) -> + <<0:7/unit:8>>; + +'DiameterURI'(encode, + #diameter_uri{type = Type, + fqdn = DN, + port = PN, + transport = T, + protocol = P}, + _) when (Type == 'aaa' orelse Type == 'aaas'), is_integer(PN), 0 =< PN, @@ -324,48 +309,47 @@ %% defaults, so it's best to be explicit. Interpret defaults on decode %% since there's no choice. -'DiameterURI'(encode, Str) -> +'DiameterURI'(encode, Str, Opts) -> Bin = iolist_to_binary(Str), - #diameter_uri{} = scan_uri(Bin), %% assert + #diameter_uri{} = scan_uri(Bin, Opts), %% assert Bin. %% -------------------- %% This minimal rule is "deny in 0 from 0.0.0.0 to 0.0.0.0", 33 characters. -'IPFilterRule'(encode = M, zero) -> - 'OctetString'(M, lists:duplicate(0,33)); +'IPFilterRule'(encode, zero, _) -> + <<0:33/unit:8>>; -'IPFilterRule'(M, X) -> - 'OctetString'(M, X). +'IPFilterRule'(M, X, Opts) -> + 'OctetString'(M, X, Opts). %% -------------------- %% This minimal rule is the same as for an IPFilterRule. -'QoSFilterRule'(encode = M, zero = X) -> - 'IPFilterRule'(M, X); +'QoSFilterRule'(encode, zero, _) -> + <<0:33/unit:8>>; -'QoSFilterRule'(M, X) -> - 'OctetString'(M, X). +'QoSFilterRule'(M, X, Opts) -> + 'OctetString'(M, X, Opts). %% -------------------- -'UTF8String'(decode, Bin) +'UTF8String'(decode, Bin, #{string_decode := true}) when is_binary(Bin) -> - case diameter_codec:getopt(string_decode) of - true -> - %% assert list return - tl([0|_] = unicode:characters_to_list([0, Bin])); - false -> - <<_/binary>> = unicode:characters_to_binary(Bin) - end; - -'UTF8String'(decode, B) -> + %% assert list return + tl([0|_] = unicode:characters_to_list([0, Bin])); + +'UTF8String'(decode, Bin, _) + when is_binary(Bin) -> + <<_/binary>> = unicode:characters_to_binary(Bin); + +'UTF8String'(decode, B, _) -> ?INVALID_LENGTH(B); -'UTF8String'(encode = M, zero) -> - 'UTF8String'(M, []); +'UTF8String'(encode, zero, _) -> + <<>>; -'UTF8String'(encode, S) -> +'UTF8String'(encode, S, _) -> <<_/binary>> = unicode:characters_to_binary(S). %% assert binary return %% -------------------- @@ -414,67 +398,23 @@ -define(TIME_MIN, {{1968,1,20},{3,14,8}}). %% TIME_1900 + 1 bsl 31 -define(TIME_MAX, {{2104,2,26},{9,42,24}}). %% TIME_2036 + 1 bsl 31 -'Time'(decode, <<Time:32>>) -> +'Time'(decode, <<Time:32>>, _) -> Offset = msb(1 == Time bsr 31), calendar:gregorian_seconds_to_datetime(Time + Offset); -'Time'(decode, B) -> +'Time'(decode, B, _) -> ?INVALID_LENGTH(B); -'Time'(encode, {{_Y,_M,_D},{_HH,_MM,_SS}} = Datetime) +'Time'(encode, {{_Y,_M,_D},{_HH,_MM,_SS}} = Datetime, _) when ?TIME_MIN =< Datetime, Datetime < ?TIME_MAX -> S = calendar:datetime_to_gregorian_seconds(Datetime), T = S - msb(S < ?TIME_2036), 0 = T bsr 32, %% sanity check <<T:32>>; -'Time'(encode, zero) -> +'Time'(encode, zero, _) -> <<0:32>>. -%% ------------------------------------------------------------------------- - -'OctetString'(M, _, Data) -> - 'OctetString'(M, Data). - -'Integer32'(M, _, Data) -> - 'Integer32'(M, Data). - -'Integer64'(M, _, Data) -> - 'Integer64'(M, Data). - -'Unsigned32'(M, _, Data) -> - 'Unsigned32'(M, Data). - -'Unsigned64'(M, _, Data) -> - 'Unsigned64'(M, Data). - -'Float32'(M, _, Data) -> - 'Float32'(M, Data). - -'Float64'(M, _, Data) -> - 'Float64'(M, Data). - -'Address'(M, _, Data) -> - 'Address'(M, Data). - -'Time'(M, _, Data) -> - 'Time'(M, Data). - -'UTF8String'(M, _, Data) -> - 'UTF8String'(M, Data). - -'DiameterIdentity'(M, _, Data) -> - 'DiameterIdentity'(M, Data). - -'DiameterURI'(M, _, Data) -> - 'DiameterURI'(M, Data). - -'IPFilterRule'(M, _, Data) -> - 'IPFilterRule'(M, Data). - -'QoSFilterRule'(M, _, Data) -> - 'QoSFilterRule'(M, Data). - %% =========================================================================== %% =========================================================================== @@ -564,7 +504,7 @@ msb(false) -> ?TIME_2036. %% %% aaa-protocol = ( "diameter" / "radius" / "tacacs+" ) -scan_uri(Bin) -> +scan_uri(Bin, Opts) -> RE = "^(aaas?)://" "([-a-zA-Z0-9.]{1,255})" "(:0{0,5}([0-9]{1,5}))?" @@ -583,28 +523,30 @@ scan_uri(Bin) -> RE, [{capture, [1,2,4,6,8], binary}]), Type = to_atom(A), - {PN0, T0} = defaults(diameter_codec:getopt(rfc), Type), - PortNr = to_int(PN, PN0), - 0 = PortNr bsr 16, %% assert #diameter_uri{type = Type, - fqdn = 'OctetString'(decode, DN), - port = PortNr, - transport = to_atom(T, T0), + fqdn = 'OctetString'(decode, DN, Opts), + port = portnr(PN, Type, Opts), + transport = transport(T, Opts), protocol = to_atom(P, diameter)}. %% Choose defaults based on the RFC, since 6733 has changed them. -defaults(3588, _) -> - {3868, sctp}; -defaults(6733, aaa) -> - {3868, tcp}; -defaults(6733, aaas) -> - {5658, tcp}. - -to_int(<<>>, N) -> - N; -to_int(B, _) -> + +portnr(<<>>, aaa, #{rfc := 6733}) -> + 3868; +portnr(<<>>, aaas, #{rfc := 6733}) -> + 5868; +portnr(<<>>, _, #{rfc := 3588}) -> + 3868; +portnr(B, _, _) -> binary_to_integer(B). +transport(<<>>, #{rfc := 6733}) -> + tcp; +transport(<<>>, #{rfc := 3588}) -> + sctp; +transport(B, _) -> + to_atom(B). + to_atom(<<>>, A) -> A; to_atom(B, _) -> diff --git a/lib/diameter/src/base/diameter_watchdog.erl b/lib/diameter/src/base/diameter_watchdog.erl index f28b8f2910..a63425d92a 100644 --- a/lib/diameter/src/base/diameter_watchdog.erl +++ b/lib/diameter/src/base/diameter_watchdog.erl @@ -50,10 +50,6 @@ -define(IS_NATURAL(N), (is_integer(N) andalso 0 =< N)). --record(config, - {suspect = 1 :: non_neg_integer(), %% OKAY -> SUSPECT - okay = 3 :: non_neg_integer()}). %% REOPEN -> OKAY - -record(watchdog, {%% PCB - Peer Control Block; see RFC 3539, Appendix A status = initial :: initial | okay | suspect | down | reopen, @@ -70,12 +66,19 @@ | integer() %% monotonic time | undefined, dictionary :: module(), %% common dictionary - receive_data :: term(), - %% term passed into diameter_service with incoming message - sequence :: diameter:sequence(), %% mask - restrict :: {diameter:restriction(), boolean()}, - shutdown = false :: boolean(), - config :: #config{}}). + receive_data :: term(), %% term passed with incoming message + config :: #{sequence := diameter:sequence(), %% mask + restrict_connections := diameter:restriction(), + restrict := boolean(), + suspect := non_neg_integer(), %% OKAY -> SUSPECT + okay := non_neg_integer()}, %% REOPEN -> OKAY + codec :: #{string_decode := false, + strict_mbit := boolean(), + failed_avp := false, + rfc := 3588 | 6733, + ordered_encode := false, + incoming_maxlen := diameter:message_length()}, + shutdown = false :: boolean()}). %% --------------------------------------------------------------------------- %% start/2 @@ -85,12 +88,12 @@ %% reason. %% --------------------------------------------------------------------------- --spec start(Type, {RecvData, [Opt], SvcOpts, #diameter_service{}}) +-spec start(Type, {[Opt], SvcOpts, RecvData, #diameter_service{}}) -> {reference(), pid()} when Type :: {connect|accept, diameter:transport_ref()}, - RecvData :: term(), Opt :: diameter:transport_opt(), - SvcOpts :: [diameter:service_opt()]. + SvcOpts :: map(), + RecvData :: term(). start({_,_} = Type, T) -> Ack = make_ref(), @@ -117,22 +120,28 @@ init(T) -> proc_lib:init_ack({ok, self()}), gen_server:enter_loop(?MODULE, [], i(T)). -i({Ack, T, Pid, {RecvData, - Opts, - SvcOpts, +i({Ack, T, Pid, {Opts, + #{restrict_connections := Restrict} + = SvcOpts0, + RecvData, #diameter_service{applications = Apps, capabilities = Caps} = Svc}}) -> monitor(process, Pid), wait(Ack, Pid), + + Dict0 = common_dictionary(Apps), + SvcOpts = SvcOpts0#{rfc => rfc(Dict0)}, putr(restart, {T, Opts, Svc, SvcOpts}), %% save seeing it in trace putr(dwr, dwr(Caps)), %% - {_,_} = Mask = proplists:get_value(sequence, SvcOpts), - Restrict = proplists:get_value(restrict_connections, SvcOpts), Nodes = restrict_nodes(Restrict), - Dict0 = common_dictionary(Apps), - diameter_codec:setopts([{common_dictionary, Dict0}, - {string_decode, false}]), + CodecKeys = [string_decode, + strict_mbit, + incoming_maxlen, + spawn_opt, + rfc, + ordered_encode], + #watchdog{parent = Pid, transport = start(T, Opts, SvcOpts, Nodes, Dict0, Svc), tw = proplists:get_value(watchdog_timer, @@ -140,9 +149,14 @@ i({Ack, T, Pid, {RecvData, ?DEFAULT_TW_INIT), receive_data = RecvData, dictionary = Dict0, - sequence = Mask, - restrict = {Restrict, lists:member(node(), Nodes)}, - config = config(Opts)}. + config = + maps:without(CodecKeys, + config(SvcOpts#{restrict => restrict(Nodes), + suspect => 1, + okay => 3}, + Opts)), + codec = maps:with(CodecKeys, SvcOpts#{string_decode := false, + ordered_encode => false})}. wait(Ref, Pid) -> receive @@ -152,22 +166,31 @@ wait(Ref, Pid) -> exit({shutdown, D}) end. -%% config/1 +%% Regard anything but the generated RFC 3588 dictionary as modern. +%% This affects the interpretation of defaults during the decode +%% of values of type DiameterURI, this having changed from RFC 3588. +%% (So much for backwards compatibility.) +rfc(?BASE) -> + 3588; +rfc(_) -> + 6733. + +%% config/2 %% %% Could also configure counts for SUSPECT to DOWN and REOPEN to DOWN, %% but don't. -config(Opts) -> +config(Map, Opts) -> Config = proplists:get_value(watchdog_config, Opts, []), - lists:foldl(fun config/2, #config{}, Config). + lists:foldl(fun cfg/2, Map, Config). -config({suspect, N}, Rec) +cfg({suspect, N}, Map) when ?IS_NATURAL(N) -> - Rec#config{suspect = N}; + Map#{suspect := N}; -config({okay, N}, Rec) +cfg({okay, N}, Map) when ?IS_NATURAL(N) -> - Rec#config{okay = N}. + Map#{okay := N}. %% start/6 @@ -283,7 +306,7 @@ event(Msg, ?LOG(transition, {From, To}). data(Msg, TPid, reopen, okay) -> - {recv, TPid, false, 'DWA', _Pkt, _NPid} = Msg, %% assert + {recv, TPid, _, 'DWA', _Pkt} = Msg, %% assert {TPid, T} = eraser(open), [T]; @@ -302,6 +325,8 @@ tpid(_, Pid) tpid(Pid, _) -> Pid. +%% send/2 + send(Pid, T) -> Pid ! T. @@ -375,8 +400,8 @@ transition({accepted = T, TPid}, #watchdog{transport = TPid, transition({open, TPid, Hosts, _} = Open, #watchdog{transport = TPid, status = initial, - restrict = {_,R}, - config = #config{suspect = OS}} + config = #{restrict := R, + suspect := OS}} = S) -> case okay(role(), Hosts, R) of okay -> @@ -394,8 +419,8 @@ transition({open, TPid, Hosts, _} = Open, transition({open = Key, TPid, _Hosts, T}, #watchdog{transport = TPid, status = down, - config = #config{suspect = OS, - okay = RO}} + config = #{suspect := OS, + okay := RO}} = S) -> case RO of 0 -> %% non-standard: skip REOPEN @@ -428,7 +453,7 @@ transition({'DOWN', _, process, TPid, _Reason}, transition({'DOWN', _, process, TPid, _Reason} = D, #watchdog{transport = TPid, status = T, - restrict = {_,R}} + config = #{restrict := R}} = S0) -> S = S0#watchdog{pending = false, transport = undefined}, @@ -447,14 +472,15 @@ transition({'DOWN', _, process, TPid, _Reason} = D, end; %% Incoming message. -transition({recv, TPid, Route, Name, Pkt, NPid}, +transition({recv, TPid, Route, Name, Pkt}, #watchdog{transport = TPid} = S) -> - try - incoming(Name, Pkt, NPid, S) - catch + try incoming(Route, Name, Pkt, S) of #watchdog{dictionary = Dict0, receive_data = T} = NS -> - diameter_traffic:receive_message(TPid, Route, Pkt, NPid, Dict0, T), + diameter_traffic:receive_message(TPid, Route, Pkt, Dict0, T), + NS + catch + #watchdog{} = NS -> NS end; @@ -483,9 +509,9 @@ getr(Key) -> eraser(Key) -> erase({?MODULE, Key}). -%% encode/3 +%% encode/4 -encode(dwr = M, Dict0, Mask) -> +encode(dwr = M, Dict0, Opts, Mask) -> Msg = getr(M), Seq = diameter_session:sequence(Mask), Hdr = #diameter_header{version = ?DIAMETER_VERSION, @@ -493,10 +519,10 @@ encode(dwr = M, Dict0, Mask) -> hop_by_hop_id = Seq}, Pkt = #diameter_packet{header = Hdr, msg = Msg}, - diameter_codec:encode(Dict0, Pkt); + diameter_codec:encode(Dict0, Opts, Pkt); -encode(dwa, Dict0, #diameter_packet{header = H, transport_data = TD} - = ReqPkt) -> +encode(dwa, Dict0, Opts, #diameter_packet{header = H, transport_data = TD} + = ReqPkt) -> AnsPkt = #diameter_packet{header = H#diameter_header{is_request = false, is_error = undefined, @@ -504,7 +530,7 @@ encode(dwa, Dict0, #diameter_packet{header = H, transport_data = TD} msg = dwa(ReqPkt), transport_data = TD}, - diameter_codec:encode(Dict0, AnsPkt). + diameter_codec:encode(Dict0, Opts, AnsPkt). %% okay/3 @@ -574,9 +600,10 @@ tw({M,F,A}) -> send_watchdog(#watchdog{pending = false, transport = TPid, dictionary = Dict0, - sequence = Mask} + config = #{sequence := Mask}, + codec = Opts} = S) -> - #diameter_packet{bin = Bin} = EPkt = encode(dwr, Dict0, Mask), + #diameter_packet{bin = Bin} = EPkt = encode(dwr, Dict0, Opts, Mask), diameter_traffic:incr(send, EPkt, TPid, Dict0), send(TPid, {send, Bin}), ?LOG(send, 'DWR'), @@ -586,41 +613,30 @@ send_watchdog(#watchdog{pending = false, %% incoming/4 -incoming(Name, Pkt, false, S) -> - recv(Name, Pkt, S); - -incoming(Name, Pkt, NPid, S) -> - try - recv(Name, Pkt, S) - after - NPid ! {diameter, discard} - end. - -%% recv/3 - -recv(Name, Pkt, S) -> - try rcv(Name, Pkt, rcv(Name, S)) of - #watchdog{} = NS -> - throw(NS) +incoming(Route, Name, Pkt, S) -> + try rcv(Name, S) of + NS -> rcv(Name, Pkt, NS) catch - #watchdog{} = NS -> %% throwaway - NS + #watchdog{transport = TPid} = NS when Route -> %% incoming request + send(TPid, {send, false}), %% requiring ack + throw(NS) end. %% rcv/3 rcv('DWR', Pkt, #watchdog{transport = TPid, - dictionary = Dict0} + dictionary = Dict0, + codec = Opts} = S) -> ?LOG(recv, 'DWR'), - DPkt = diameter_codec:decode(Dict0, Pkt), + DPkt = diameter_codec:decode(Dict0, Opts, Pkt), diameter_traffic:incr(recv, DPkt, TPid, Dict0), diameter_traffic:incr_error(recv, DPkt, TPid, Dict0), #diameter_packet{header = H, transport_data = T, bin = Bin} = EPkt - = encode(dwa, Dict0, Pkt), + = encode(dwa, Dict0, Opts, Pkt), diameter_traffic:incr(send, EPkt, TPid, Dict0), diameter_traffic:incr_rc(send, EPkt, TPid, Dict0), @@ -632,12 +648,13 @@ rcv('DWR', Pkt, #watchdog{transport = TPid, throw(S); rcv('DWA', Pkt, #watchdog{transport = TPid, - dictionary = Dict0} + dictionary = Dict0, + codec = Opts} = S) -> ?LOG(recv, 'DWA'), diameter_traffic:incr(recv, Pkt, TPid, Dict0), diameter_traffic:incr_rc(recv, - diameter_codec:decode(Dict0, Pkt), + diameter_codec:decode(Dict0, Opts, Pkt), TPid, Dict0), throw(S); @@ -699,12 +716,12 @@ rcv(_, #watchdog{status = okay} = S) -> %% SUSPECT Receive non-DWA Failback() %% SetWatchdog() OKAY -rcv('DWA', #watchdog{status = suspect, config = #config{suspect = OS}} = S) -> +rcv('DWA', #watchdog{status = suspect, config = #{suspect := OS}} = S) -> set_watchdog(S#watchdog{status = okay, num_dwa = OS, pending = false}); -rcv(_, #watchdog{status = suspect, config = #config{suspect = OS}} = S) -> +rcv(_, #watchdog{status = suspect, config = #{suspect := OS}} = S) -> set_watchdog(S#watchdog{status = okay, num_dwa = OS}); @@ -714,8 +731,8 @@ rcv(_, #watchdog{status = suspect, config = #config{suspect = OS}} = S) -> rcv('DWA', #watchdog{status = reopen, num_dwa = N, - config = #config{suspect = OS, - okay = RO}} + config = #{suspect := OS, + okay := RO}} = S) when N+1 == RO -> S#watchdog{status = okay, @@ -846,18 +863,19 @@ restart(S) -> %% reconnect has won race with timeout restart({{connect, _} = T, Opts, Svc, SvcOpts}, #watchdog{parent = Pid, - restrict = {R,_}, + config = #{restrict_connections := R} + = M, dictionary = Dict0} = S) -> send(Pid, {reconnect, self()}), Nodes = restrict_nodes(R), S#watchdog{transport = start(T, Opts, SvcOpts, Nodes, Dict0, Svc), - restrict = {R, lists:member(node(), Nodes)}}; + config = M#{restrict => restrict(Nodes)}}; %% No restriction on the number of connections to the same peer: just %% die. Note that a state machine never enters state REOPEN in this %% case. -restart({{accept, _}, _, _, _}, #watchdog{restrict = {_, false}}) -> +restart({{accept, _}, _, _, _}, #watchdog{config = #{restrict := false}}) -> stop; %% Otherwise hang around until told to die, either by the service or @@ -901,3 +919,8 @@ restrict_nodes(Nodes) restrict_nodes(F) -> diameter_lib:eval(F). + +%% restrict/1 + +restrict(Nodes) -> + lists:member(node(), Nodes). diff --git a/lib/diameter/src/compiler/diameter_codegen.erl b/lib/diameter/src/compiler/diameter_codegen.erl index 928ae37e7f..f56e4a5249 100644 --- a/lib/diameter/src/compiler/diameter_codegen.erl +++ b/lib/diameter/src/compiler/diameter_codegen.erl @@ -150,20 +150,21 @@ erl_forms(Mod, ParseD) -> {id, 0}, {vendor_id, 0}, {vendor_name, 0}, - {decode_avps, 2}, %% in diameter_gen.hrl - {encode_avps, 2}, %% + {decode_avps, 3}, %% in diameter_gen.hrl + {encode_avps, 3}, %% + {grouped_avp, 4}, %% {msg_name, 2}, {msg_header, 1}, {rec2msg, 1}, {msg2rec, 1}, {name2rec, 1}, {avp_name, 2}, + {avp_arity, 1}, {avp_arity, 2}, {avp_header, 1}, - {avp, 3}, - {grouped_avp, 3}, + {avp, 4}, {enumerated_avp, 3}, - {empty_value, 1}, + {empty_value, 2}, {dict, 0}]}, %% diameter.hrl is included for #diameter_avp {?attribute, include_lib, "diameter/include/diameter.hrl"}, @@ -178,7 +179,8 @@ erl_forms(Mod, ParseD) -> f_msg2rec(ParseD), f_name2rec(ParseD), f_avp_name(ParseD), - f_avp_arity(ParseD), + f_avp_arity_1(ParseD), + f_avp_arity_2(ParseD), f_avp_header(ParseD), f_avp(ParseD), f_enumerated_avp(ParseD), @@ -418,10 +420,32 @@ vendor_id_map(ParseD) -> get_value(grouped, ParseD)). %%% ------------------------------------------------------------------------ +%%% # avp_arity/1 +%%% ------------------------------------------------------------------------ + +f_avp_arity_1(ParseD) -> + {?function, avp_arity, 1, avp_arities(ParseD) ++ [?BADARG(1)]}. + +avp_arities(ParseD) -> + Msgs = get_value(messages, ParseD), + Groups = get_value(grouped, ParseD) + ++ lists:flatmap(fun avps/1, get_value(import_groups, ParseD)), + lists:map(fun c_avp_arities/1, Msgs ++ Groups). + +c_avp_arities({N,_,_,_,As}) -> + c_avp_arities(N,As); +c_avp_arities({N,_,_,As}) -> + c_avp_arities(N,As). + +c_avp_arities(Name, Avps) -> + Arities = [{?A(N), A} || T <- Avps, {N,A} <- [avp_info(T)]], + {?clause, [?Atom(Name)], [], [?TERM(Arities)]}. + +%%% ------------------------------------------------------------------------ %%% # avp_arity/2 %%% ------------------------------------------------------------------------ -f_avp_arity(ParseD) -> +f_avp_arity_2(ParseD) -> {?function, avp_arity, 2, avp_arity(ParseD)}. avp_arity(ParseD) -> @@ -452,7 +476,7 @@ c_arity(Name, Avp) -> %%% ------------------------------------------------------------------------ f_avp(ParseD) -> - {?function, avp, 3, avp(ParseD) ++ [?BADARG(3)]}. + {?function, avp, 4, avp(ParseD) ++ [?BADARG(4)]}. avp(ParseD) -> Native = get_value(avp_types, ParseD), @@ -491,19 +515,25 @@ avp(Native, Imported, Custom, Enums) -> not_in(List, X) -> not lists:member(X, List). -c_base_avp({AvpName, T}) -> - {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName)], +c_base_avp({AvpName, "Enumerated"}) -> + {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName), ?VAR('_')], [], - [b_base_avp(AvpName, T)]}. + [?CALL(enumerated_avp, [?VAR('T'), ?Atom(AvpName), ?VAR('Data')])]}; -b_base_avp(AvpName, "Enumerated") -> - ?CALL(enumerated_avp, [?VAR('T'), ?Atom(AvpName), ?VAR('Data')]); - -b_base_avp(AvpName, "Grouped") -> - ?CALL(grouped_avp, [?VAR('T'), ?Atom(AvpName), ?VAR('Data')]); +c_base_avp({AvpName, "Grouped"}) -> + {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName), ?VAR('Opts')], + [], + [?CALL(grouped_avp, [?VAR('T'), + ?Atom(AvpName), + ?VAR('Data'), + ?VAR('Opts')])]}; -b_base_avp(_, Type) -> - ?APPLY(diameter_types, ?A(Type), [?VAR('T'), ?VAR('Data')]). +c_base_avp({AvpName, Type}) -> + {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName), ?VAR('Opts')], + [], + [?APPLY(diameter_types, ?A(Type), [?VAR('T'), + ?VAR('Data'), + ?VAR('Opts')])]}. cs_imported_avp({Mod, Avps}, Enums, CustomNames) -> lists:map(fun(A) -> imported_avp(Mod, A, Enums) end, @@ -525,11 +555,13 @@ imported_avp(Mod, {AvpName, _, _, _}, _) -> c_imported_avp(Mod, AvpName). c_imported_avp(Mod, AvpName) -> - {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName)], + {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName), ?VAR('Opts')], [], - [?APPLY(Mod, avp, [?VAR('T'), - ?VAR('Data'), - ?Atom(AvpName)])]}. + [?CALL(avp, [?VAR('T'), + ?VAR('Data'), + ?Atom(AvpName), + ?VAR('Opts'), + ?ATOM(Mod)])]}. cs_custom_avp({Mod, Key, Avps}, Dict) -> lists:map(fun(N) -> c_custom_avp(Mod, Key, N, orddict:fetch(N, Dict)) end, @@ -537,9 +569,12 @@ cs_custom_avp({Mod, Key, Avps}, Dict) -> c_custom_avp(Mod, Key, AvpName, Type) -> {F,A} = custom(Key, AvpName, Type), - {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName)], + {?clause, [?VAR('T'), ?VAR('Data'), ?Atom(AvpName), ?VAR('Opts')], [], - [?APPLY(?A(Mod), ?A(F), [?VAR('T'), ?Atom(A), ?VAR('Data')])]}. + [?APPLY(?A(Mod), ?A(F), [?VAR('T'), + ?Atom(A), + ?VAR('Data'), + ?VAR('Opts')])]}. custom(custom_types, AvpName, Type) -> {AvpName, Type}; @@ -568,7 +603,11 @@ enumerated_avp(Mod, Es, Enums) -> Es). cs_enumerated_avp(true, Mod, Name) -> - [c_imported_avp(Mod, Name)]; + [{?clause, [?VAR('T'), ?Atom(Name), ?VAR('Data')], + [], + [?APPLY(Mod, enumerated_avp, [?VAR('T'), + ?Atom(Name), + ?VAR('Data')])]}]; cs_enumerated_avp(false, _, _) -> []. @@ -682,7 +721,7 @@ v(false, _, _, _) -> %%% ------------------------------------------------------------------------ f_empty_value(ParseD) -> - {?function, empty_value, 1, empty_value(ParseD)}. + {?function, empty_value, 2, empty_value(ParseD)}. empty_value(ParseD) -> Imported = lists:flatmap(fun avps/1, get_value(import_enums, ParseD)), @@ -692,15 +731,17 @@ empty_value(ParseD) -> not lists:keymember(N, 1, Imported)] ++ Imported, lists:map(fun c_empty_value/1, Groups ++ Enums) - ++ [{?clause, [?VAR('Name')], [], [?CALL(empty, [?VAR('Name')])]}]. + ++ [{?clause, [?VAR('Name'), ?VAR('Opts')], + [], + [?CALL(empty, [?VAR('Name'), ?VAR('Opts')])]}]. c_empty_value({Name, _, _, _}) -> - {?clause, [?Atom(Name)], + {?clause, [?Atom(Name), ?VAR('Opts')], [], - [?CALL(empty_group, [?Atom(Name)])]}; + [?CALL(empty_group, [?Atom(Name), ?VAR('Opts')])]}; c_empty_value({Name, _}) -> - {?clause, [?Atom(Name)], + {?clause, [?Atom(Name), ?VAR('_')], [], [?TERM(<<0:32>>)]}. diff --git a/lib/diameter/src/diameter.app.src b/lib/diameter/src/diameter.app.src index d380ebbd92..9a6e47006b 100644 --- a/lib/diameter/src/diameter.app.src +++ b/lib/diameter/src/diameter.app.src @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -28,10 +28,10 @@ ]}, {registered, [%REGISTERED%]}, {applications, [ - {stdlib, "2.0"}, {kernel, "3.0"}%, {erts, "6.0"} - %% {syntax-tools, "1.6.14"} - %% {runtime-tools, "1.8.14"} - %, {ssl, "5.3.4"} + {stdlib, "2.4"}, {kernel, "3.2"}%, {erts, "6.4"} + %% {syntax-tools, "1.6,18"} + %% {runtime-tools, "1.8.16"} + %, {ssl, "6.0"} ]}, {env, []}, {mod, {diameter_app, []}}, diff --git a/lib/diameter/src/diameter.appup.src b/lib/diameter/src/diameter.appup.src index eb5a5a44f3..07d0389bfd 100644 --- a/lib/diameter/src/diameter.appup.src +++ b/lib/diameter/src/diameter.appup.src @@ -51,7 +51,8 @@ {"1.11.1", [{restart_application, diameter}]}, %% 18.2 {"1.11.2", [{restart_application, diameter}]}, %% 18.3 {"1.12", [{restart_application, diameter}]}, %% 19.0 - {"1.12.1", [{restart_application, diameter}]} %% 19.1 + {"1.12.1", [{restart_application, diameter}]}, %% 19.1 + {"1.12.2", [{restart_application, diameter}]} %% 19.3 ], [ {"0.9", [{restart_application, diameter}]}, @@ -84,6 +85,7 @@ {"1.11.1", [{restart_application, diameter}]}, {"1.11.2", [{restart_application, diameter}]}, {"1.12", [{restart_application, diameter}]}, - {"1.12.1", [{restart_application, diameter}]} + {"1.12.1", [{restart_application, diameter}]}, + {"1.12.2", [{restart_application, diameter}]} ] }. diff --git a/lib/diameter/src/modules.mk b/lib/diameter/src/modules.mk index 4e4ce60ddf..bb3b234d20 100644 --- a/lib/diameter/src/modules.mk +++ b/lib/diameter/src/modules.mk @@ -1,7 +1,7 @@ # %CopyrightBegin% # -# Copyright Ericsson AB 2010-2016. All Rights Reserved. +# Copyright Ericsson AB 2010-2017. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ RT_MODULES = \ base/diameter_config \ base/diameter_config_sup \ base/diameter_codec \ - base/diameter_dict \ + base/diameter_gen \ base/diameter_lib \ base/diameter_misc_sup \ base/diameter_peer \ diff --git a/lib/diameter/src/transport/diameter_sctp.erl b/lib/diameter/src/transport/diameter_sctp.erl index 76aacabcb8..6a9f1f940b 100644 --- a/lib/diameter/src/transport/diameter_sctp.erl +++ b/lib/diameter/src/transport/diameter_sctp.erl @@ -52,21 +52,20 @@ %% Keys into process dictionary. -define(INFO_KEY, info). -define(REF_KEY, ref). +-define(TRANSPORT_KEY, transport). -define(ERROR(T), erlang:error({T, ?MODULE, ?LINE})). %% The default port for a listener. -define(DEFAULT_PORT, 3868). %% RFC 3588, ch 2.1 -%% Remote addresses to accept connections from. --define(DEFAULT_ACCEPT, []). %% any - %% How long to wait for a transport process to attach after %% association establishment. -define(ACCEPT_TIMEOUT, 5000). -type connect_option() :: {raddr, inet:ip_address()} | {rport, inet:port_number()} + | option() | term(). %% gen_sctp:open_option(). -type match() :: inet:ip_address() @@ -74,8 +73,14 @@ | [match()]. -type listen_option() :: {accept, match()} + | option() | term(). %% gen_sctp:open_option(). +-type option() :: {sender, boolean()} + | sender + | {packet, boolean() | raw} + | {message_cb, false | diameter:evaluable()}. + -type uint() :: non_neg_integer(). %% Accepting/connecting transport process state. @@ -87,20 +92,35 @@ %% {RAs, RP, Errors} | connect, socket :: gen_sctp:sctp_socket() | undefined, - assoc_id :: gen_sctp:assoc_id(), %% association identifier + active = false :: boolean(), %% is socket active? + recv = true :: boolean(), %% should it be active? + assoc_id :: gen_sctp:assoc_id() %% association identifier + | undefined + | true, peer :: {[inet:ip_address()], uint()} %% {RAs, RP} | undefined, streams :: {uint(), uint()} %% {InStream, OutStream} counts | undefined, - os = 0 :: uint()}). %% next output stream + os = 0 :: uint(), %% next output stream + packet = true :: boolean() %% legacy transport_data? + | raw, + message_cb = false :: false | diameter:evaluable(), + send = false :: pid() | boolean()}). %% sending process + +%% Monitor process state. +-record(monitor, + {transport :: pid(), + ack = false :: boolean(), + socket :: gen_sctp:sctp_socket(), + assoc_id :: gen_sctp:assoc_id()}). %% next output stream %% Listener process state. -record(listener, {ref :: reference(), socket :: gen_sctp:sctp_socket(), - service = false :: false | pid(), %% service process + service :: pid(), %% service process pending = {0, queue:new()}, - accept :: [match()]}). + opts :: [[match()] | boolean() | diameter:evaluable()]}). %% Field pending implements two queues: the first of transport-to-be %% processes to which an association has been assigned but for which %% diameter hasn't yet spawned a transport process, a short-lived @@ -132,11 +152,11 @@ start(T, Svc, Opts) when is_list(Opts) -> #diameter_service{capabilities = Caps, - pid = SPid} + pid = Pid} = Svc, diameter_sctp_sup:start(), %% start supervisors on demand Addrs = Caps#diameter_caps.host_ip_address, - s(T, Addrs, SPid, lists:map(fun ip/1, Opts)). + s(T, Addrs, Pid, lists:map(fun ip/1, Opts)). ip({ifaddr, A}) -> {ip, A}; @@ -147,9 +167,9 @@ ip(T) -> %% when there is not yet an association to assign it, or at comm_up on %% a new association in which case the call retrieves a transport from %% the pending queue. -s({accept, Ref} = A, Addrs, SPid, Opts) -> - {ok, LPid, LAs} = listener(Ref, {Opts, Addrs}), - try gen_server:call(LPid, {A, self(), SPid}, infinity) of +s({accept, Ref} = A, Addrs, SvcPid, Opts) -> + {ok, LPid, LAs} = listener(Ref, {Opts, SvcPid, Addrs}), + try gen_server:call(LPid, {A, self()}, infinity) of {ok, TPid} -> {ok, TPid, LAs}; No -> @@ -162,7 +182,7 @@ s({accept, Ref} = A, Addrs, SPid, Opts) -> %% gen_sctp in order to be able to accept a new association only %% *after* an accepting transport has been spawned. -s({connect = C, Ref}, Addrs, _SPid, Opts) -> +s({connect = C, Ref}, Addrs, _SvcPid, Opts) -> diameter_sctp_sup:start_child({C, self(), Opts, Addrs, Ref}). %% start_link/1 @@ -216,22 +236,39 @@ init(T) -> %% i/1 +i(#monitor{transport = TPid} = S) -> + monitor(process, TPid), + putr(?TRANSPORT_KEY, TPid), + proc_lib:init_ack({ok, self()}), + S; + %% A process owning a listening socket. -i({listen, Ref, {Opts, Addrs}}) -> +i({listen, Ref, {Opts, SvcPid, Addrs}}) -> + monitor(process, SvcPid), [_] = diameter_config:subscribe(Ref, transport), %% assert existence - {[Matches], Rest} = proplists:split(Opts, [accept]), + {Split, Rest} + = proplists:split(Opts, [accept, packet, sender, message_cb]), + OwnOpts = lists:append(Split), {LAs, Sock} = AS = open(Addrs, Rest, ?DEFAULT_PORT), ok = gen_sctp:listen(Sock, true), true = diameter_reg:add_new({?MODULE, listener, {Ref, AS}}), proc_lib:init_ack({ok, self(), LAs}), #listener{ref = Ref, + service = SvcPid, socket = Sock, - accept = [[M] || {accept, M} <- Matches]}; + opts = [[[M] || {accept, M} <- OwnOpts], + proplists:get_value(packet, OwnOpts, true) + | [proplists:get_value(K, OwnOpts, false) + || K <- [sender, message_cb]]]}; %% A connecting transport. i({connect, Pid, Opts, Addrs, Ref}) -> - {[As, Ps], Rest} = proplists:split(Opts, [raddr, rport]), - RAs = [diameter_lib:ipaddr(A) || {raddr, A} <- As], + {[Ps | Split], Rest} + = proplists:split(Opts, [rport, raddr, packet, sender, message_cb]), + OwnOpts = lists:append(Split), + CB = proplists:get_value(message_cb, OwnOpts, false), + false == CB orelse (Pid ! {diameter, ack}), + RAs = [diameter_lib:ipaddr(A) || {raddr, A} <- OwnOpts], [RP] = [P || {rport, P} <- Ps] ++ [P || P <- [?DEFAULT_PORT], [] == Ps], {LAs, Sock} = open(Addrs, Rest, 0), putr(?REF_KEY, Ref), @@ -239,7 +276,10 @@ i({connect, Pid, Opts, Addrs, Ref}) -> monitor(process, Pid), #transport{parent = Pid, mode = {connect, connect(Sock, RAs, RP, [])}, - socket = Sock}; + socket = Sock, + message_cb = CB, + packet = proplists:get_value(packet, OwnOpts, true), + send = proplists:get_value(sender, OwnOpts, false)}; %% An accepting transport spawned by diameter, not yet owning an %% association. @@ -273,11 +313,16 @@ i({K, Ref}, #transport{mode = {accept, _}} = S) -> receive {Ref, Pid} when K == parent -> %% transport process started S#transport{parent = Pid}; - {K, T, Matches} when K == peeloff -> %% association + {K, T, Opts} when K == peeloff -> %% association {sctp, Sock, _RA, _RP, _Data} = T, + [Matches, Packet, Sender, CB] = Opts, ok = accept_peer(Sock, Matches), demonitor(Ref, [flush]), - t(T, S#transport{socket = Sock}); + false == CB orelse (S#transport.parent ! {diameter, ack}), + t(T, S#transport{socket = Sock, + message_cb = CB, + packet = Packet, + send = Sender}); accept_timeout = T -> x(T); {'DOWN', _, process, _, _} = T -> @@ -374,13 +419,9 @@ handle_call({{accept, Ref}, Pid}, _, #listener{ref = Ref} = S) -> {TPid, NewS} = accept(Ref, Pid, S), {reply, {ok, TPid}, NewS}; -handle_call({{accept, _} = T, Pid, SPid}, From, #listener{service = P} = S) -> - handle_call({T, Pid}, From, if not is_pid(P), is_pid(SPid) -> - monitor(process, SPid), - S#listener{service = SPid}; - true -> - S - end); +%% Transport is telling us of parent death. +handle_call({stop, _Pid} = Reason, _From, #monitor{} = S) -> + {stop, {shutdown, Reason}, ok, S}; handle_call(_, _, State) -> {reply, nok, State}. @@ -400,7 +441,11 @@ handle_info(T, #transport{} = S) -> {noreply, #transport{} = t(T,S)}; handle_info(T, #listener{} = S) -> - {noreply, #listener{} = l(T,S)}. + {noreply, #listener{} = l(T,S)}; + +handle_info(T, #monitor{} = S) -> + m(T,S), + {noreply, S}. %% Prior to the possibility of setting pool_size on in transport %% configuration, a new accepting transport was only started following @@ -422,6 +467,9 @@ code_change(_, State, _) -> %% # terminate/2 %% --------------------------------------------------------------------------- +terminate(_, #monitor{}) -> + ok; + terminate(_, #transport{assoc_id = undefined}) -> ok; @@ -445,11 +493,11 @@ getr(Key) -> %% Incoming message from SCTP. l({sctp, Sock, _RA, _RP, Data} = T, #listener{socket = Sock, - accept = Matches} + opts = Opts} = S) -> Id = assoc_id(Data), {TPid, NewS} = accept(S), - TPid ! {peeloff, setelement(2, T, peeloff(Sock, Id, TPid)), Matches}, + TPid ! {peeloff, setelement(2, T, peeloff(Sock, Id, TPid)), Opts}, setopts(Sock), NewS; @@ -503,12 +551,21 @@ t(T,S) -> %% Incoming message. transition({sctp, Sock, _RA, _RP, Data}, #transport{socket = Sock} = S) -> - setopts(Sock), - recv(Data, S); + setopts(S, recv(Data, S#transport{active = false})); %% Outgoing message. transition({diameter, {send, Msg}}, S) -> - send(Msg, S); + message(send, Msg, S); + +%% Monitor has sent an outgoing message. +transition(Msg, S) + when is_record(Msg, diameter_packet); + is_binary(Msg) -> + message(ack, Msg, S); + +%% Deferred actions from a message_cb. +transition({actions, Dir, Acts}, S) -> + actions(Acts, Dir, S); %% Request to close the transport connection. transition({diameter, {close, Pid}}, #transport{parent = Pid}) -> @@ -522,8 +579,18 @@ transition({diameter, {close, Pid}}, #transport{parent = Pid}) -> transition({diameter, {tls, _Ref, _Type, _Bool}}, _) -> stop; -%% Parent process has died. -transition({'DOWN', _, process, Pid, _}, #transport{parent = Pid}) -> +%% Parent process has died: call the monitor to not close the socket +%% during an ongoing send, but don't let it take forever. +transition({'DOWN', _, process, Pid, _}, #transport{parent = Pid, + send = MPid}) -> + is_boolean(MPid) + orelse ok == (catch gen_server:call(MPid, {stop, Pid})) + orelse exit(MPid, kill), + stop; + +%% Monitor process has died. +transition({'DOWN', _, process, MPid, _}, #transport{send = MPid}) + when is_pid(MPid) -> stop; %% Timeout after transport process has been started. @@ -536,6 +603,18 @@ transition({resolve_port, Pid}, #transport{socket = Sock}) Pid ! inet:port(Sock), ok. +%% m/2 + +m({Msg, StreamId}, #monitor{socket = Sock, + transport = TPid, + assoc_id = AId, + ack = B}) -> + send(Sock, AId, StreamId, Msg), + B andalso (TPid ! Msg); + +m({'DOWN', _, process, TPid, _} = T, #monitor{transport = TPid}) -> + x(T). + %% Crash on anything unexpected. ok({ok, T}) -> @@ -578,33 +657,52 @@ q(Ref, Pid, #listener{pending = {_,Q}}) -> %% send/2 +%% Start monitor process on first send. +send(Msg, #transport{send = true, + socket = Sock, + assoc_id = AId, + message_cb = CB} + = S) -> + {ok, MPid} = diameter_sctp_sup:start_child(#monitor{transport = self(), + socket = Sock, + assoc_id = AId, + ack = false /= CB}), + monitor(process, MPid), + send(Msg, S#transport{send = MPid}); + %% Outbound Diameter message on a specified stream ... -send(#diameter_packet{bin = Bin, transport_data = {outstream, SId}}, +send(#diameter_packet{transport_data = {outstream, SId}} + = Msg, #transport{streams = {_, OS}} = S) -> - send(SId rem OS, Bin, S), - S; + send(SId rem OS, Msg, S); %% ... or not: rotate through all streams. -send(#diameter_packet{bin = Bin}, S) -> - send(Bin, S); -send(Bin, #transport{streams = {_, OS}, +send(Msg, #transport{streams = {_, OS}, os = N} - = S) - when is_binary(Bin) -> - send(N, Bin, S), - S#transport{os = (N + 1) rem OS}. + = S) -> + send(N, Msg, S#transport{os = (N + 1) rem OS}). %% send/3 -send(StreamId, Bin, #transport{socket = Sock, - assoc_id = AId}) -> - send(Sock, AId, StreamId, Bin). +send(StreamId, Msg, #transport{send = false, + socket = Sock, + assoc_id = AId} + = S) -> + send(Sock, AId, StreamId, Msg), + message(ack, Msg, S); + +send(StreamId, Msg, #transport{send = MPid} = S) -> + MPid ! {Msg, StreamId}, + S. %% send/4 -send(Sock, AssocId, Stream, Bin) -> - case gen_sctp:send(Sock, AssocId, Stream, Bin) of +send(Sock, AssocId, StreamId, #diameter_packet{bin = Bin}) -> + send(Sock, AssocId, StreamId, Bin); + +send(Sock, AssocId, StreamId, Bin) -> + case gen_sctp:send(Sock, AssocId, StreamId, Bin) of ok -> ok; {error, Reason} -> @@ -624,7 +722,9 @@ recv({_, #sctp_assoc_change{state = comm_up, = S) -> Ref = getr(?REF_KEY), publish(T, Ref, Id, Sock), - up(S#transport{assoc_id = Id, + %% Deal with different association id after peeloff on Solaris by + %% taking the id from the first reception. + up(S#transport{assoc_id = T == accept orelse Id, streams = {IS, OS}}); %% ... or not: try the next address. @@ -639,17 +739,19 @@ recv({_, #sctp_assoc_change{} = E}, recv({_, #sctp_assoc_change{}}, _) -> stop; +%% First inbound on an accepting transport. +recv({[#sctp_sndrcvinfo{assoc_id = Id}], _Bin} + = T, + #transport{assoc_id = true} + = S) -> + recv(T, S#transport{assoc_id = Id}); + %% Inbound Diameter message. -recv({[#sctp_sndrcvinfo{stream = Id}], Bin}, #transport{parent = Pid}) +recv({[#sctp_sndrcvinfo{}], Bin} = Msg, S) when is_binary(Bin) -> - diameter_peer:recv(Pid, #diameter_packet{transport_data = {stream, Id}, - bin = Bin}), - ok; + message(recv, Msg, S); -recv({_, #sctp_shutdown_event{assoc_id = A}}, - #transport{assoc_id = Id}) - when A == Id; - A == 0 -> +recv({_, #sctp_shutdown_event{}}, _) -> stop; %% Note that diameter_sctp(3) documents that sctp_events cannot be @@ -765,6 +867,23 @@ connect(Sock, [Addr | AT] = As, Port, Reasons) -> connect(Sock, AT, Port, [{Addr, E} | Reasons]) end. +%% setopts/2 + +setopts(_, #transport{socket = Sock, + active = A, + recv = B} + = S) + when B, not A -> + setopts(Sock), + S#transport{active = true}; + +setopts(_, #transport{} = S) -> + S; + +setopts(#transport{socket = Sock}, T) -> + setopts(Sock), + T. + %% setopts/1 setopts(Sock) -> @@ -772,3 +891,83 @@ setopts(Sock) -> ok -> ok; X -> x({setopts, Sock, X}) %% possibly on peer disconnect end. + +%% A message_cb is invoked whenever a message is sent or received, or +%% to provide acknowledgement of a completed send or discarded +%% request. See diameter_tcp for semantics, the only difference being +%% that a recv callback can get a diameter_packet record as Msg +%% depending on how/if option packet has been specified. + +%% message/3 + +message(send, false = M, S) -> + message(ack, M, S); + +message(ack, _, #transport{message_cb = false} = S) -> + S; + +message(Dir, Msg, S) -> + setopts(S, actions(cb(S, Dir, Msg), Dir, S)). + +%% actions/3 + +actions([], _, S) -> + S; + +actions([B | As], Dir, S) + when is_boolean(B) -> + actions(As, Dir, S#transport{recv = B}); + +actions([Dir | As], _, S) + when Dir == send; + Dir == recv -> + actions(As, Dir, S); + +actions([Msg | As], send = Dir, S) + when is_record(Msg, diameter_packet); + is_binary(Msg) -> + actions(As, Dir, send(Msg, S)); + +actions([Msg | As], recv = Dir, #transport{parent = Pid} = S) + when is_record(Msg, diameter_packet); + is_binary(Msg) -> + diameter_peer:recv(Pid, Msg), + actions(As, Dir, S); + +actions([{defer, Tmo, Acts} | As], Dir, S) -> + erlang:send_after(Tmo, self(), {actions, Dir, Acts}), + actions(As, Dir, S); + +actions(CB, _, S) -> + S#transport{message_cb = CB}. + +%% cb/3 + +cb(#transport{message_cb = false, packet = P}, recv, Msg) -> + [pkt(P, true, Msg)]; + +cb(#transport{message_cb = CB, packet = P}, recv = D, Msg) -> + cb(CB, D, pkt(P, false, Msg)); + +cb(#transport{message_cb = CB}, Dir, Msg) -> + cb(CB, Dir, Msg); + +cb(false, send, Msg) -> + [Msg]; + +cb(CB, Dir, Msg) -> + diameter_lib:eval([CB, Dir, Msg]). + +%% pkt/3 + +pkt(false, _, {_Info, Bin}) -> + Bin; + +pkt(true, _, {[#sctp_sndrcvinfo{stream = Id}], Bin}) -> + #diameter_packet{bin = Bin, transport_data = {stream, Id}}; + +pkt(raw, true, {[Info], Bin}) -> + #diameter_packet{bin = Bin, transport_data = Info}; + +pkt(raw, false, {[_], _} = Msg) -> + Msg. diff --git a/lib/diameter/src/transport/diameter_sctp_sup.erl b/lib/diameter/src/transport/diameter_sctp_sup.erl index 36050aaf28..e8e26ec7c5 100644 --- a/lib/diameter/src/transport/diameter_sctp_sup.erl +++ b/lib/diameter/src/transport/diameter_sctp_sup.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -49,6 +49,7 @@ start() -> start_child(T) -> SupRef = case element(1,T) of + monitor -> ?TRANSPORT_SUP; connect -> ?TRANSPORT_SUP; accept -> ?TRANSPORT_SUP; listen -> ?LISTENER_SUP diff --git a/lib/diameter/src/transport/diameter_tcp.erl b/lib/diameter/src/transport/diameter_tcp.erl index 44abc5c3b4..a2f393d5d4 100644 --- a/lib/diameter/src/transport/diameter_tcp.erl +++ b/lib/diameter/src/transport/diameter_tcp.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -19,7 +19,6 @@ %% -module(diameter_tcp). --dialyzer({no_fail_call, throttle/2}). -behaviour(gen_server). @@ -53,6 +52,7 @@ %% Keys into process dictionary. -define(INFO_KEY, info). -define(REF_KEY, ref). +-define(TRANSPORT_KEY, transport). -define(ERROR(T), erlang:error({T, ?MODULE, ?LINE})). @@ -68,16 +68,23 @@ %% The same gen_server implementation supports three different kinds %% of processes: an actual transport process, one that will club it to %% death should the parent die before a connection is established, and -%% a process owning the listening port. +%% a process owning the listening port. The monitor process +%% historically died after connection establishment, but can now live +%% on as the sender of outgoing messages, so that a blocking send +%% doesn't prevent messages from being received. %% Listener process state. -record(listener, {socket :: inet:socket(), + module :: module(), service = false :: false | pid()}). %% service process %% Monitor process state. -record(monitor, - {parent :: pid(), - transport = self() :: pid()}). + {parent :: reference() | false | pid(), + transport = self() :: pid(), + ack = false :: boolean(), + socket :: inet:socket() | ssl:sslsocket() | undefined, + module :: module() | undefined}). -type length() :: 0..16#FFFFFF. %% message length from Diameter header -type size() :: non_neg_integer(). %% accumulated binary size @@ -97,25 +104,30 @@ -type listen_option() :: {accept, match()} | {ssl_options, true | [ssl:listen_option()]} + | option() | ssl:listen_option() | gen_tcp:listen_option(). -type option() :: {port, non_neg_integer()} - | {fragment_timer, 0..16#FFFFFFFF} - | {throttle_cb, diameter:evaluable()}. + | {sender, boolean()} + | sender + | {message_cb, false | diameter:evaluable()} + | {fragment_timer, 0..16#FFFFFFFF}. %% Accepting/connecting transport process state. -record(transport, {socket :: inet:socket() | ssl:sslsocket(), %% accept/connect socket + active = false :: boolean(), %% is socket active? + recv = true :: boolean(), %% should it be active? parent :: pid(), %% of process that started us module :: module(), %% gen_tcp-like module - frag = <<>> :: frag(), %% message fragment ssl :: [term()] | boolean(), %% ssl options, ssl or not + frag = <<>> :: frag(), %% message fragment timeout :: infinity | 0..16#FFFFFFFF, %% fragment timeout tref = false :: false | reference(), %% fragment timer reference flush = false :: boolean(), %% flush fragment at timeout? - throttle_cb :: false | diameter:evaluable(), %% ask to receive - throttled :: boolean() | binary()}). %% stopped receiving? + message_cb :: false | diameter:evaluable(), + send :: pid() | false}). %% sending process %% The usual transport using gen_tcp can be replaced by anything %% sufficiently gen_tcp-like by passing a 'module' option as the first @@ -137,13 +149,13 @@ start({T, Ref}, Svc, Opts) -> #diameter_service{capabilities = Caps, - pid = SPid} + pid = SvcPid} = Svc, diameter_tcp_sup:start(), %% start tcp supervisors on demand {Mod, Rest} = split(Opts), Addrs = Caps#diameter_caps.host_ip_address, - Arg = {T, Ref, Mod, self(), Rest, Addrs, SPid}, + Arg = {T, Ref, Mod, self(), Rest, Addrs, SvcPid}, diameter_tcp_sup:start_child(Arg). split([{module, M} | Opts]) -> @@ -197,57 +209,53 @@ init(T) -> %% i/1 %% A transport process. -i({T, Ref, Mod, Pid, Opts, Addrs, SPid}) +i({T, Ref, Mod, Pid, Opts, Addrs, SvcPid}) when T == accept; T == connect -> monitor(process, Pid), %% Since accept/connect might block indefinitely, spawn a process - %% that does nothing but kill us with the parent until call - %% returns. - {ok, MPid} = diameter_tcp_sup:start_child(#monitor{parent = Pid}), + %% that kills us with the parent until call returns, and then + %% sends outgoing messages. {[SO|TO], Rest} = proplists:split(Opts, [ssl_options, - fragment_timer, - throttle_cb]), + sender, + message_cb, + fragment_timer]), SslOpts = ssl_opts(SO), OwnOpts = lists:append(TO), Tmo = proplists:get_value(fragment_timer, OwnOpts, ?DEFAULT_FRAGMENT_TIMEOUT), + [CB, Sender] = [proplists:get_value(K, OwnOpts, false) + || K <- [message_cb, sender]], ?IS_TIMEOUT(Tmo) orelse ?ERROR({fragment_timer, Tmo}), - Throttle = proplists:get_value(throttle_cb, OwnOpts, false), - Sock = init(T, Ref, Mod, Pid, SslOpts, Rest, Addrs, SPid), - MPid ! {stop, self()}, %% tell the monitor to die + {ok, MPid} = diameter_tcp_sup:start_child(#monitor{parent = Pid}), + Sock = init(T, Ref, Mod, Pid, SslOpts, Rest, Addrs, SvcPid), M = if SslOpts -> ssl; true -> Mod end, + Sender andalso monitor(process, MPid), + false == CB orelse (Pid ! {diameter, ack}), + MPid ! {start, self(), Sender andalso {Sock, M}, false /= CB}, putr(?REF_KEY, Ref), - throttle(#transport{parent = Pid, - module = M, - socket = Sock, - ssl = SslOpts, - timeout = Tmo, - throttle_cb = Throttle, - throttled = false /= Throttle}); + setopts(#transport{parent = Pid, + module = M, + socket = Sock, + ssl = SslOpts, + message_cb = CB, + timeout = Tmo, + send = Sender andalso MPid}); %% Put the reference in the process dictionary since we now use it %% advertise the ssl socket after TLS upgrade. -i({T, _Ref, _Mod, _Pid, _Opts, _Addrs} = Arg) %% from old code - when T == accept; - T == connect -> - i(erlang:append_element(Arg, _SPid = false)); - %% A monitor process to kill the transport if the parent dies. i(#monitor{parent = Pid, transport = TPid} = S) -> + putr(?TRANSPORT_KEY, TPid), proc_lib:init_ack({ok, self()}), - monitor(process, Pid), monitor(process, TPid), - S; + S#monitor{parent = monitor(process, Pid)}; %% In principle a link between the transport and killer processes %% could do the same thing: have the accepting/connecting process be %% killed when the killer process dies as a consequence of parent %% death. However, a link can be unlinked and this is exactly what -%% gen_tcp seems to so. Links should be left to supervisors. - -i({listen = L, Ref, _APid, T}) -> %% from old code - i({L, Ref, T}); +%% gen_tcp seems to do. Links should be left to supervisors. i({listen, Ref, {Mod, Opts, Addrs}}) -> [_] = diameter_config:subscribe(Ref, transport), %% assert existence @@ -258,7 +266,8 @@ i({listen, Ref, {Mod, Opts, Addrs}}) -> LAddr = laddr(LAddrOpt, Mod, LSock), true = diameter_reg:add_new({?MODULE, listener, {Ref, {LAddr, LSock}}}), proc_lib:init_ack({ok, self(), {LAddr, LSock}}), - #listener{socket = LSock}. + #listener{socket = LSock, + module = Mod}. laddr([], Mod, Sock) -> {ok, {Addr, _Port}} = sockname(Mod, Sock), @@ -279,19 +288,19 @@ ssl_opts(T) -> %% init/8 %% Establish a TLS connection before capabilities exchange ... -init(Type, Ref, Mod, Pid, true, Opts, Addrs, SPid) -> - init(Type, Ref, ssl, Pid, [{cb_info, ?TCP_CB(Mod)} | Opts], Addrs, SPid); +init(Type, Ref, Mod, Pid, true, Opts, Addrs, SvcPid) -> + init(Type, Ref, ssl, Pid, [{cb_info, ?TCP_CB(Mod)} | Opts], Addrs, SvcPid); %% ... or not. -init(Type, Ref, Mod, Pid, _, Opts, Addrs, SPid) -> - init(Type, Ref, Mod, Pid, Opts, Addrs, SPid). +init(Type, Ref, Mod, Pid, _, Opts, Addrs, SvcPid) -> + init(Type, Ref, Mod, Pid, Opts, Addrs, SvcPid). %% init/7 -init(accept = T, Ref, Mod, Pid, Opts, Addrs, SPid) -> +init(accept = T, Ref, Mod, Pid, Opts, Addrs, SvcPid) -> {[Matches], Rest} = proplists:split(Opts, [accept]), {ok, LPid, {LAddr, LSock}} = listener(Ref, {Mod, Rest, Addrs}), - ok = gen_server:call(LPid, {accept, SPid}, infinity), + ok = gen_server:call(LPid, {accept, SvcPid}, infinity), proc_lib:init_ack({ok, self(), [LAddr]}), Sock = ok(accept(Mod, LSock)), ok = accept_peer(Mod, Sock, accept(Matches)), @@ -299,7 +308,7 @@ init(accept = T, Ref, Mod, Pid, Opts, Addrs, SPid) -> diameter_peer:up(Pid), Sock; -init(connect = T, Ref, Mod, Pid, Opts, Addrs, _SPid) -> +init(connect = T, Ref, Mod, Pid, Opts, Addrs, _SvcPid) -> {[LA, RA, RP], Rest} = proplists:split(Opts, [ip, raddr, rport]), LAddrOpt = get_addr(LA, Addrs), RAddr = get_addr(RA), @@ -451,14 +460,18 @@ portnr(Sock) -> %% # handle_call/3 %% --------------------------------------------------------------------------- -handle_call({accept, SPid}, _From, #listener{service = P} = S) -> - {reply, ok, if not is_pid(P), is_pid(SPid) -> - monitor(process, SPid), - S#listener{service = SPid}; +handle_call({accept, SvcPid}, _From, #listener{service = P} = S) -> + {reply, ok, if not is_pid(P), is_pid(SvcPid) -> + monitor(process, SvcPid), + S#listener{service = SvcPid}; true -> S end}; - + +%% Transport is telling us of parent death. +handle_call({stop, _Pid} = Reason, _From, #monitor{} = S) -> + {stop, {shutdown, Reason}, ok, S}; + handle_call(_, _, State) -> {reply, nok, State}. @@ -480,8 +493,7 @@ handle_info(T, #listener{} = S) -> {noreply, #listener{} = l(T,S)}; handle_info(T, #monitor{} = S) -> - m(T,S), - x(T). + {noreply, #monitor{} = m(T,S)}. %% --------------------------------------------------------------------------- %% # code_change/3 @@ -497,6 +509,7 @@ code_change(_, State, _) -> terminate(_, _) -> ok. + %% --------------------------------------------------------------------------- putr(Key, Val) -> @@ -509,18 +522,47 @@ getr(Key) -> %% %% Transition monitor state. +%% Outgoing message. +m(Msg, S) + when is_record(Msg, diameter_packet); + is_binary(Msg) -> + send(Msg, S), + S; + +%% Transport has established a connection. Stop monitoring on the +%% parent so as not to die before a send from the transport. +m({start, TPid, T, Ack} = M, #monitor{transport = TPid} = S) -> + case T of + {Sock, Mod} -> + demonitor(S#monitor.parent, [flush]), + S#monitor{parent = false, + socket = Sock, + module = Mod, + ack = Ack}; + false -> %% monitor not sending + x(M) + end; + %% Transport is telling us to die. -m({stop, TPid}, #monitor{transport = TPid}) -> - ok; +m({stop, TPid} = T, #monitor{transport = TPid}) -> + x(T); -%% Transport has died. -m({'DOWN', _, process, TPid, _}, #monitor{transport = TPid}) -> - ok; +%% Transport is telling us to die. +m({stop, TPid} = T, #monitor{transport = TPid}) -> + x(T); -%% Transport parent has died. -m({'DOWN', _, process, Pid, _}, #monitor{parent = Pid, - transport = TPid}) -> - exit(TPid, {shutdown, parent}). +%% Transport is telling us that TLS has been negotiated after +%% capabilities exchange. +m({tls, SSock}, S) -> + S#monitor{socket = SSock, + module = ssl}; + +%% Transport or parent has died. +m({'DOWN', M, process, P, _} = T, #monitor{parent = MRef, + transport = TPid}) + when M == MRef; + P == TPid -> + x(T). %% l/2 %% @@ -528,18 +570,16 @@ m({'DOWN', _, process, Pid, _}, #monitor{parent = Pid, %% Service process has died. l({'DOWN', _, process, Pid, _} = T, #listener{service = Pid, - socket = Sock}) -> - gen_tcp:close(Sock), + socket = Sock, + module = M}) -> + M:close(Sock), x(T); %% Transport has been removed. -l({transport, remove, _} = T, #listener{socket = Sock}) -> - gen_tcp:close(Sock), - x(T); - -%% Possibly death of an accepting process monitored in old code. -l(_, S) -> - S. +l({transport, remove, _} = T, #listener{socket = Sock, + module = M}) -> + M:close(Sock), + x(T). %% t/2 %% @@ -557,21 +597,13 @@ t(T,S) -> %% transition/2 -%% Incoming message. +%% Incoming packets. transition({P, Sock, Bin}, #transport{socket = Sock, - ssl = B, - throttled = T} + ssl = B} = S) when P == ssl, true == B; P == tcp -> - false = T, %% assert - recv(Bin, S); - -%% Make a new throttling callback after a timeout. -transition(throttle, #transport{throttled = false}) -> - ok; -transition(throttle, S) -> - throttle(S); + recv(Bin, S#transport{active = false}); %% Capabilties exchange has decided on whether or not to run over TLS. transition({diameter, {tls, Ref, Type, B}}, #transport{parent = Pid} @@ -581,7 +613,7 @@ transition({diameter, {tls, Ref, Type, B}}, #transport{parent = Pid} = NS = tls_handshake(Type, B, S), Pid ! {diameter, {tls, Ref}}, - throttle(NS#transport{ssl = B}); + NS#transport{ssl = B}; transition({C, Sock}, #transport{socket = Sock, ssl = B}) @@ -597,8 +629,18 @@ transition({E, Sock, _Reason} = T, #transport{socket = Sock, ?ERROR({T,S}); %% Outgoing message. -transition({diameter, {send, Bin}}, S) -> - send(Bin, S); +transition({diameter, {send, Msg}}, #transport{} = S) -> + message(send, Msg, S); + +%% Monitor has sent an outgoing message. +transition(Msg, S) + when is_record(Msg, diameter_packet); + is_binary(Msg) -> + message(ack, Msg, S); + +%% Deferred actions from a message_cb. +transition({actions, Dir, Acts}, S) -> + actions(Acts, Dir, S); %% Request to close the transport connection. transition({diameter, {close, Pid}}, #transport{parent = Pid, @@ -618,8 +660,18 @@ transition({resolve_port, Pid}, #transport{socket = Sock, Pid ! portnr(M, Sock), ok; -%% Parent process has died. -transition({'DOWN', _, process, Pid, _}, #transport{parent = Pid}) -> +%% Parent process has died: call the monitor to not close the socket +%% during an ongoing send, but don't let it take forever. +transition({'DOWN', _, process, Pid, _}, #transport{parent = Pid, + send = MPid}) -> + false == MPid + orelse (ok == gen_server:call(MPid, {stop, self()}, 1000)) + orelse exit(MPid, {shutdown, parent}), + stop; + +%% Monitor process has died. +transition({'DOWN', _, process, MPid, _}, #transport{send = MPid}) + when is_pid(MPid) -> stop. %% Crash on anything unexpected. @@ -643,11 +695,13 @@ tls_handshake(_, true, #transport{ssl = false}) -> %% Capabilities exchange negotiated TLS: upgrade the connection. tls_handshake(Type, true, #transport{socket = Sock, module = M, - ssl = Opts} + ssl = Opts, + send = MPid} = S) -> {ok, SSock} = tls(Type, Sock, [{cb_info, ?TCP_CB(M)} | Opts]), Ref = getr(?REF_KEY), true = diameter_reg:add_new({?MODULE, Type, {Ref, SSock}}), + false == MPid orelse (MPid ! {tls, SSock}), %% tell the sender process S#transport{socket = SSock, module = ssl}; @@ -666,24 +720,15 @@ tls(accept, Sock, Opts) -> %% using Nagle. %% Receive packets until a full message is received, -recv(Bin, #transport{frag = Head, throttled = false} = S) -> +recv(Bin, #transport{frag = Head} = S) -> case rcv(Head, Bin) of - {Msg, B} -> - throttle(S#transport{frag = B, throttled = Msg}); - Frag -> - setopts(S), - start_fragment_timer(S#transport{frag = Frag, - flush = false}) + {Msg, B} -> %% have a complete message ... + message(recv, Msg, S#transport{frag = B}); + Frag -> %% read more on the socket + start_fragment_timer(setopts(S#transport{frag = Frag, + flush = false})) end. -%% recv/1 - -recv(#transport{throttled = false} = S) -> - recv(<<>>, S); - -recv(#transport{} = S) -> - S. - %% rcv/2 %% No previous fragment. @@ -743,13 +788,16 @@ recv1(Len, Bin) -> <<Msg:Len/binary, Rest/binary>> = Bin, {Msg, Rest}. -%% bin/1-2 +%% bin/2 bin(Head, Acc) -> list_to_binary([Head | lists:reverse(Acc)]). +%% bin/1 + bin({_, _, Head, Acc}) -> bin(Head, Acc); + bin(Bin) when is_binary(Bin) -> Bin. @@ -768,9 +816,7 @@ bin(Bin) %% also eventually lead to watchdog failover. %% No fragment to flush or not receiving messages. -flush(#transport{frag = Frag, throttled = B} = S) - when Frag == <<>>; - B /= false -> +flush(#transport{frag = <<>>} = S) -> S; %% Messages have been received since last timer expiry. @@ -778,9 +824,8 @@ flush(#transport{flush = false} = S) -> start_fragment_timer(S#transport{flush = true}); %% No messages since last expiry. -flush(#transport{frag = Frag, parent = Pid} = S) -> - diameter_peer:recv(Pid, bin(Frag)), - S#transport{frag = <<>>}. +flush(#transport{frag = Frag} = S) -> + message(recv, bin(Frag), S#transport{frag = <<>>}). %% start_fragment_timer/1 %% @@ -813,9 +858,27 @@ connect(Mod, Host, Port, Opts) -> %% send/2 -send(Bin, #transport{socket = Sock, - module = M}) -> - case send(M, Sock, Bin) of +send(Msg, #monitor{socket = Sock, module = M, transport = TPid, ack = B}) -> + send1(M, Sock, Msg), + B andalso (TPid ! Msg); + +send(Msg, #transport{socket = Sock, module = M, send = false} = S) -> + send1(M, Sock, Msg), + message(ack, Msg, S); + +%% Send from the monitor process to avoid deadlock if both the +%% receiver and the peer were to block in send. +send(Msg, #transport{send = Pid} = S) -> + Pid ! Msg, + S. + +%% send1/3 + +send1(Mod, Sock, #diameter_packet{bin = Bin}) -> + send1(Mod, Sock, Bin); + +send1(Mod, Sock, Bin) -> + case send(Mod, Sock, Bin) of ok -> ok; {error, Reason} -> @@ -842,120 +905,19 @@ setopts(M, Sock, Opts) -> %% setopts/1 -setopts(#transport{socket = Sock, module = M}) -> - setopts(M, Sock). - -%% setopts/2 - -setopts(M, Sock) -> +setopts(#transport{socket = Sock, + active = A, + recv = B, + module = M} + = S) + when B, not A -> case setopts(M, Sock, [{active, once}]) of - ok -> ok; - X -> x({setopts, M, Sock, X}) %% possibly on peer disconnect - end. - -%% throttle/1 - -%% Still collecting packets for a complete message: keep receiving. -throttle(#transport{throttled = false} = S) -> - recv(S); - -%% Decide whether to receive another, or whether to accept a message -%% that's been received. -throttle(#transport{throttle_cb = F, throttled = T} = S) -> - Res = cb(F, T), - - try throttle(Res, S) of - #transport{ssl = SB} = NS when is_boolean(SB) -> - throttle(defrag(NS)); - #transport{throttled = Msg} = NS when is_binary(Msg) -> - %% Initial incoming message when we might need to upgrade - %% to TLS: wait for reception of a tls tuple. - defrag(NS) - catch - #transport{} = NS -> - recv(NS) - end. - -%% cb/2 - -cb(false, _) -> - ok; - -cb(F, B) -> - diameter_lib:eval([F, true /= B andalso B]). - -%% throttle/2 - -%% Callback says to receive another message. -throttle(ok, #transport{throttled = true} = S) -> - throw(S#transport{throttled = false}); - -%% Callback says to accept a received message. -throttle(ok, #transport{parent = Pid, throttled = Msg} = S) - when is_binary(Msg) -> - diameter_peer:recv(Pid, Msg), - S; - -throttle({ok = T, F}, S) -> - throttle(T, S#transport{throttle_cb = F}); - -%% Callback says to accept a received message and acknowledged the -%% returned pid with a {request, Pid} message if a request pid is -%% spawned, a discard message otherwise. The latter does not mean that -%% the message was necessarily discarded: it could have been an -%% answer. -throttle(NPid, #transport{parent = Pid, throttled = Msg} = S) - when is_pid(NPid), is_binary(Msg) -> - diameter_peer:recv(Pid, {Msg, NPid}), - S; - -throttle({NPid, F}, #transport{throttled = Msg} = S) - when is_pid(NPid), is_binary(Msg) -> - throttle(NPid, S#transport{throttle_cb = F}); - -%% Callback to accept a received message says to discard it. -throttle(discard, #transport{throttled = Msg} = S) - when is_binary(Msg) -> - S; - -throttle({discard = T, F}, #transport{throttled = Msg} = S) - when is_binary(Msg) -> - throttle(T, S#transport{throttle_cb = F}); - -%% Callback to accept a received message says to answer it with the -%% supplied binary. -throttle(Bin, #transport{throttled = Msg} = S) - when is_binary(Bin), is_binary(Msg) -> - send(Bin, S), - S; - -throttle({Bin, F}, #transport{throttled = Msg} = S) - when is_binary(Bin), is_binary(Msg) -> - throttle(Bin, S#transport{throttle_cb = F}); - -%% Callback says to ask again in the specified number of milliseconds. -throttle({timeout, Tmo}, S) -> - erlang:send_after(Tmo, self(), throttle), - throw(S); - -throttle({timeout = T, Tmo, F}, S) -> - throttle({T, Tmo}, S#transport{throttle_cb = F}); - -throttle(T, #transport{throttle_cb = F}) -> - ?ERROR({invalid_return, T, F}). - -%% defrag/1 -%% -%% Try to extract another message from packets already read before -%% another throttling callback. + ok -> S#transport{active = true}; + X -> x({setopts, Sock, M, X}) %% possibly on peer disconnect + end; -defrag(#transport{frag = Head} = S) -> - case rcv(Head, <<>>) of - {Msg, B} -> - S#transport{throttled = Msg, frag = B}; - _ -> - S#transport{throttled = true} - end. +setopts(S) -> + S. %% portnr/2 @@ -990,3 +952,80 @@ getstat(gen_tcp, Sock) -> getstat(M, Sock) -> M:getstat(Sock). %% Note that ssl:getstat/1 doesn't yet exist in R15B01. + +%% A message_cb is invoked whenever a message is sent or received, or +%% to provide acknowledgement of a completed send or discarded +%% request. Ignoring possible extra arguments, calls are of the +%% following form. +%% +%% cb(recv, Msg) Receive a message into diameter? +%% cb(send, Msg) Send a message on the socket? +%% cb(ack, Msg) Acknowledgement of a completed send. +%% cb(ack, false) Acknowledgement of a discarded request. +%% +%% Msg will be binary() in a recv callback, but can be a +%% diameter_packet record in a send/ack callback if a recv/send +%% callback returns a record. Callbacks return a list of the following +%% form. +%% +%% [boolean() | send | recv | binary() | #diameter_packet{}] +%% +%% The atoms are meaningless by themselves, but say whether subsequent +%% messages are to be sent or received. A boolean says whether or not +%% to continue reading on the socket. Messages can be received even +%% after false is returned if these arrived in the same packet. A +%% leading recv or send is implicit on the corresponding callbacks. A +%% new callback can be returned as the tail of a returned list: any +%% value not of the aforementioned list type is interpreted as a +%% callback. + +%% message/3 + +message(send, false = M, S) -> + message(ack, M, S); + +message(ack, _, #transport{message_cb = false} = S) -> + S; + +message(Dir, Msg, #transport{message_cb = CB} = S) -> + recv(<<>>, actions(cb(CB, Dir, Msg), Dir, S)). + +%% actions/3 + +actions([], _, S) -> + S; + +actions([B | As], Dir, S) + when is_boolean(B) -> + actions(As, Dir, S#transport{recv = B}); + +actions([Dir | As], _, S) + when Dir == send; + Dir == recv -> + actions(As, Dir, S); + +actions([Msg | As], send = Dir, S) + when is_binary(Msg); + is_record(Msg, diameter_packet) -> + actions(As, Dir, send(Msg, S)); + +actions([Msg | As], recv = Dir, #transport{parent = Pid} = S) + when is_binary(Msg); + is_record(Msg, diameter_packet) -> + diameter_peer:recv(Pid, Msg), + actions(As, Dir, S); + +actions([{defer, Tmo, Acts} | As], Dir, S) -> + erlang:send_after(Tmo, self(), {actions, Dir, Acts}), + actions(As, Dir, S); + +actions(CB, _, S) -> + S#transport{message_cb = CB}. + +%% cb/3 + +cb(false, _, Msg) -> + [Msg]; + +cb(CB, Dir, Msg) -> + diameter_lib:eval([CB, Dir, Msg]). diff --git a/lib/diameter/test/diameter_capx_SUITE.erl b/lib/diameter/test/diameter_capx_SUITE.erl index fdeff96a58..51b6c1d7f2 100644 --- a/lib/diameter/test/diameter_capx_SUITE.erl +++ b/lib/diameter/test/diameter_capx_SUITE.erl @@ -433,7 +433,7 @@ server_reject(Config, F, RC) -> ?fail({LRef, OH}) end. -%% cliient_closed/4 +%% client_closed/4 client_closed(Config, Host, F, RC) -> true = diameter:subscribe(?CLIENT), diff --git a/lib/diameter/test/diameter_codec_SUITE.erl b/lib/diameter/test/diameter_codec_SUITE.erl index 558ba3b848..9f08f49f9f 100644 --- a/lib/diameter/test/diameter_codec_SUITE.erl +++ b/lib/diameter/test/diameter_codec_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2015. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -31,6 +31,8 @@ -export([suite/0, all/0, groups/0, + init_per_suite/1, + end_per_suite/1, init_per_group/2, end_per_group/2, init_per_testcase/2, @@ -63,6 +65,12 @@ groups() -> grouped_error, failed_error]}]. +init_per_suite(Config) -> + Config. + +end_per_suite(_Config) -> + ok. + init_per_group(recode, Config) -> ok = diameter:start(), Config. @@ -277,7 +285,14 @@ recode(Msg) -> recode(Msg, diameter_gen_base_rfc6733). recode(#diameter_packet{} = Pkt, Dict) -> - diameter_codec:decode(Dict, diameter_codec:encode(Dict, Pkt)); + diameter_codec:decode(Dict, opts(Dict), diameter_codec:encode(Dict, Pkt)); recode(Msg, Dict) -> recode(#diameter_packet{msg = Msg}, Dict). + +opts(Mod) -> + #{dictionary => Mod, + string_decode => false, + strict_mbit => true, + rfc => 6733, + failed_avp => false}. diff --git a/lib/diameter/test/diameter_codec_SUITE_data/diameter_test_unknown.erl b/lib/diameter/test/diameter_codec_SUITE_data/diameter_test_unknown.erl index 50cc6e7eef..700910878c 100644 --- a/lib/diameter/test/diameter_codec_SUITE_data/diameter_test_unknown.erl +++ b/lib/diameter/test/diameter_codec_SUITE_data/diameter_test_unknown.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -59,7 +59,7 @@ enc(M, #diameter_packet{msg = Vs} = P) -> P#diameter_packet{msg = [M|Vs]}). run(M, Pkt) -> - dec(M, diameter_codec:decode(diameter_test_recv, Pkt)). + dec(M, diameter_codec:decode(diameter_test_recv, opts(M), Pkt)). %% Note that the recv dictionary defines neither XXX nor YYY. dec('AR', #diameter_packet @@ -75,3 +75,10 @@ dec('BR', #diameter_packet errors = [{5001, ?MANDATORY_XXX}, {5008, ?NOT_MANDATORY_YYY}]}) -> ok. + +opts(Mod) -> + #{dictionary => Mod, + string_decode => true, + strict_mbit => true, + rfc => 6733, + failed_avp => false}. diff --git a/lib/diameter/test/diameter_codec_test.erl b/lib/diameter/test/diameter_codec_test.erl index 869797f11f..b548f85cb8 100644 --- a/lib/diameter/test/diameter_codec_test.erl +++ b/lib/diameter/test/diameter_codec_test.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -94,7 +94,7 @@ base(T) -> %% Ensure that 'zero' values encode only zeros. base(zero = T, F) -> - B = diameter_types:F(encode, T), + B = diameter_types:F(encode, T, opts()), B = z(B); %% Ensure that we can decode what we encode and vice-versa, and that @@ -106,7 +106,7 @@ base(decode, F) -> [] = run([[fun base_invalid/2, F, V] || V <- Is]). base_decode(F, Eq, Value) -> - d(fun(X,V) -> diameter_types:F(X,V) end, Eq, Value). + d(fun(X,V) -> diameter_types:F(X, V, opts()) end, Eq, Value). base_invalid(F, Value) -> try @@ -171,7 +171,7 @@ gen(M, avp_types, {Name, Code, Type, _Flags}) -> V = undefined /= VendorId, V = 0 /= Flags band 2#10000000, {Name, Type} = M:avp_name(Code, VendorId), - B = M:empty_value(Name), + B = M:empty_value(Name, #{module => M}), B = z(B), [] = avp_decode(M, Type, Name); @@ -207,10 +207,22 @@ avp_decode(Mod, Name, Type, Eq, Value) -> d(fun(X,V) -> avp(Mod, X, V, Name, Type) end, Eq, Value). avp(Mod, decode = X, V, Name, 'Grouped') -> - {Rec, _} = Mod:avp(X, V, Name), + {Rec, _} = Mod:avp(X, V, Name, opts(Mod)), Rec; -avp(Mod, X, V, Name, _) -> - Mod:avp(X, V, Name). +avp(Mod, decode = X, V, Name, _) -> + Mod:avp(X, V, Name, opts(Mod)); +avp(Mod, encode = X, V, Name, _) -> + iolist_to_binary(Mod:avp(X, V, Name, opts(Mod))). + +opts(Mod) -> + (opts())#{module => Mod, + dictionary => Mod}. + +opts() -> + #{string_decode => true, + strict_mbit => true, + rfc => 6733, + failed_avp => false}. %% v/1 @@ -257,8 +269,8 @@ arity(M, Name, AvpName, Rec) -> enum(M, Name, {_,E}) -> B = <<E:32>>, - B = M:avp(encode, E, Name), - E = M:avp(decode, B, Name). + B = M:avp(encode, E, Name, opts(M)), + E = M:avp(decode, B, Name, opts(M)). retag(import_avps) -> avp_types; retag(import_groups) -> grouped; @@ -280,7 +292,8 @@ d(F, Eq, V) -> end. z(B) -> - << <<0>> || <<_>> <= B >>. + Sz = size(B), + <<0:Sz/unit:8>>. %% values/1 %% diff --git a/lib/diameter/test/diameter_compiler_SUITE.erl b/lib/diameter/test/diameter_compiler_SUITE.erl index 7a9ac65ae3..73fe1ef6e0 100644 --- a/lib/diameter/test/diameter_compiler_SUITE.erl +++ b/lib/diameter/test/diameter_compiler_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -39,7 +39,7 @@ -export([dict/0]). %% fake dictionary module %% dictionary callbacks for flatten2/1 --export(['A1'/3, 'Unsigned32'/3]). +-export(['A1'/4, 'Unsigned32'/4]). -define(base, "base_rfc3588.dia"). -define(util, diameter_util). @@ -552,13 +552,13 @@ flatten2(_Config) -> T <- [encode, decode], M <- [M2, M3], Ref <- [make_ref()], - RC <- [M:avp(T, Ref, A)], + RC <- [M:avp(T, Ref, A, #{module => M})], RC /= {T, Ref}]. -'A1'(T, 'Unsigned32', Ref) -> +'A1'(T, 'Unsigned32', Ref, _Opts) -> {T, Ref}. -'Unsigned32'(T, 'A3', Ref) -> +'Unsigned32'(T, 'A3', Ref, _Opts) -> {T, Ref}. load_forms(Forms) -> diff --git a/lib/diameter/test/diameter_dict_SUITE.erl b/lib/diameter/test/diameter_dict_SUITE.erl deleted file mode 100644 index 4c1349f4eb..0000000000 --- a/lib/diameter/test/diameter_dict_SUITE.erl +++ /dev/null @@ -1,145 +0,0 @@ -%% -%% %CopyrightBegin% -%% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. -%% -%% Licensed under the Apache License, Version 2.0 (the "License"); -%% you may not use this file except in compliance with the License. -%% You may obtain a copy of the License at -%% -%% http://www.apache.org/licenses/LICENSE-2.0 -%% -%% Unless required by applicable law or agreed to in writing, software -%% distributed under the License is distributed on an "AS IS" BASIS, -%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -%% See the License for the specific language governing permissions and -%% limitations under the License. -%% -%% %CopyrightEnd% -%% - -%% -%% Tests of the dict-like diameter_dict. -%% - --module(diameter_dict_SUITE). - --export([suite/0, - all/0, - groups/0]). - -%% testcases --export([append/1, - fetch/1, - fetch_keys/1, - filter/1, - find/1, - fold/1, - is_key/1, - map/1, - merge/1, - update/1, - update_counter/1]). - --include("diameter_ct.hrl"). - --define(dict, diameter_dict). --define(util, diameter_util). - -%% =========================================================================== - -suite() -> - [{timetrap, {seconds, 60}}]. - -all() -> - [{group, all}, - {group, all, [parallel]}]. - -groups() -> - [{all, [], tc()}]. - -tc() -> - [append, - fetch, - fetch_keys, - filter, - find, - fold, - is_key, - map, - merge, - update, - update_counter]. - -%% =========================================================================== - --define(KV100, [{N,[N]} || N <- lists:seq(1,100)]). - -append(_) -> - D = ?dict:append(k, v, ?dict:new()), - [{k,[v,v]}] = ?dict:to_list(?dict:append(k, v, D)). - -fetch(_) -> - D = ?dict:from_list(?KV100), - [50] = ?dict:fetch(50, D), - Ref = make_ref(), - Ref = try ?dict:fetch(Ref, D) catch _:_ -> Ref end. - -fetch_keys(_) -> - L = ?KV100, - D = ?dict:from_list(L), - L = [{N,[N]} || N <- lists:sort(?dict:fetch_keys(D))]. - -filter(_) -> - L = ?KV100, - F = fun(K,[_]) -> 0 == K rem 2 end, - D = ?dict:filter(F, ?dict:from_list(L)), - true = [T || {K,V} = T <- L, F(K,V)] == lists:sort(?dict:to_list(D)). - -find(_) -> - D = ?dict:from_list(?KV100), - {ok, [50]} = ?dict:find(50, D), - error = ?dict:find(make_ref(), D). - -fold(_) -> - L = ?KV100, - S = lists:sum([N || {N,_} <- L]), - S = ?dict:fold(fun(K,[_],A) -> K + A end, 0, ?dict:from_list(L)). - -is_key(_) -> - L = ?KV100, - D = ?dict:from_list(L), - true = lists:all(fun({N,_}) -> ?dict:is_key(N,D) end, L), - false = ?dict:is_key(make_ref(), D). - -map(_) -> - L = ?KV100, - F = fun(_,V) -> [N] = V, N*2 end, - D = ?dict:map(F, ?dict:from_list(L)), - M = [{K, F(K,V)} || {K,V} <- L], - M = lists:sort(?dict:to_list(D)). - -merge(_) -> - L = ?KV100, - F = fun(_,V1,V2) -> V1 ++ V2 end, - D = ?dict:merge(F, ?dict:from_list(L), ?dict:from_list(L)), - M = [{K, F(K,V,V)} || {K,V} <- L], - M = lists:sort(?dict:to_list(D)). - -update(_) -> - L = ?KV100, - F = fun([V]) -> 2*V end, - D = ?dict:update(50, F, ?dict:from_list(L)), - 100 = ?dict:fetch(50, D), - Ref = make_ref(), - Ref = try ?dict:update(Ref, F, D) catch _:_ -> Ref end, - [Ref] = ?dict:fetch(Ref, ?dict:update(Ref, - fun(_,_) -> ?ERROR(i_think_not) end, - [Ref], - D)). - -update_counter(_) -> - L = [{N,2*N} || {N,_} <- ?KV100], - D = ?dict:update_counter(50, 20, ?dict:from_list(L)), - 120 = ?dict:fetch(50,D), - 2 = ?dict:fetch(1,D). diff --git a/lib/diameter/test/diameter_dpr_SUITE.erl b/lib/diameter/test/diameter_dpr_SUITE.erl index 55702fbf78..779b919d3c 100644 --- a/lib/diameter/test/diameter_dpr_SUITE.erl +++ b/lib/diameter/test/diameter_dpr_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2012-2015. All Rights Reserved. +%% Copyright Ericsson AB 2012-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -27,6 +27,8 @@ -export([suite/0, all/0, groups/0, + init_per_suite/1, + end_per_suite/1, init_per_group/2, end_per_group/2]). @@ -56,16 +58,12 @@ %% Config for diameter:start_service/2. -define(SERVICE(Host), - [{'Origin-Host', Host}, + [{'Origin-Host', Host ++ ".erlang.org"}, {'Origin-Realm', "erlang.org"}, {'Host-IP-Address', [?ADDR]}, {'Vendor-Id', hd(Host)}, %% match this in disconnect/5 {'Product-Name', "OTP/diameter"}, - {'Acct-Application-Id', [0]}, - {restrict_connections, false}, - {application, [{dictionary, diameter_gen_base_rfc6733}, - {alias, common}, - {module, #diameter_callback{_ = false}}]}]). + {restrict_connections, false}]). %% Disconnect reasons that diameter passes as the first argument of a %% function configured as disconnect_cb. @@ -89,13 +87,19 @@ suite() -> [{timetrap, {seconds, 60}}]. all() -> - [start, send_dpr, stop | [{group, R} || R <- ?REASONS]]. + [{group, R} || R <- [client, server, uncommon | ?REASONS]]. %% The group determines how transports are terminated: by remove_transport, %% stop_service or application stop. groups() -> - Ts = tc(), - [{R, [], Ts} || R <- ?REASONS]. + [{R, [], [start, send_dpr, stop]} || R <- [client, server, uncommon]] + ++ [{R, [], Ts} || Ts <- [tc()], R <- ?REASONS]. + +init_per_suite(Config) -> %% not need, but a useful place to enable trace + Config. + +end_per_suite(_Config) -> + ok. init_per_group(Name, Config) -> [{group, Name} | Config]. @@ -107,29 +111,86 @@ tc() -> [start, connect, remove_transport, stop_service, check, stop]. %% =========================================================================== -%% start/stop testcases -start(_Config) -> - ok = diameter:start(), - ok = diameter:start_service(?SERVER, ?SERVICE(?SERVER)), - ok = diameter:start_service(?CLIENT, ?SERVICE(?CLIENT)). +%% start/1 -send_dpr(_Config) -> +start(Config) + when is_list(Config) -> + Grp = group(Config), + ok = diameter:start(), + ok = diameter:start_service(?SERVER, service(?SERVER, Grp)), + ok = diameter:start_service(?CLIENT, service(?CLIENT, Grp)). + +service(?SERVER = Svc, _) -> + ?SERVICE(Svc) + ++ [{'Acct-Application-Id', [0,3]}, + {application, [{dictionary, diameter_gen_base_rfc6733}, + {alias, common}, + {module, #diameter_callback{_ = false}}]}, + {application, [{dictionary, diameter_gen_acct_rfc6733}, + {alias, acct}, + {module, #diameter_callback{_ = false}}]}]; + +%% Client that receives a server DPR despite no explicit support for +%% Diameter common messages. +service(?CLIENT = Svc, server) -> + ?SERVICE(Svc) + ++ [{'Acct-Application-Id', [3]}, + {application, [{dictionary, diameter_gen_acct_rfc6733}, + {alias, acct}, + {module, #diameter_callback{_ = false}}]}]; + +%% Client that sends DPR despite advertised only the accounting +%% application. The dictionary is required for encode. +service(?CLIENT = Svc, uncommon) -> + ?SERVICE(Svc) + ++ [{'Acct-Application-Id', [3]}, + {application, [{dictionary, diameter_gen_base_rfc6733}, + {alias, common}, + {module, #diameter_callback{_ = false}}]}, + {application, [{dictionary, diameter_gen_acct_rfc6733}, + {alias, acct}, + {module, #diameter_callback{_ = false}}]}]; + +service(?CLIENT = Svc, _) -> + ?SERVICE(Svc) + ++ [{'Auth-Application-Id', [0]}, + {application, [{dictionary, diameter_gen_base_rfc6733}, + {alias, common}, + {module, #diameter_callback{_ = false}}]}]. + +%% send_dpr/1 + +send_dpr(Config) -> LRef = ?util:listen(?SERVER, tcp), Ref = ?util:connect(?CLIENT, tcp, LRef, [{dpa_timeout, 10000}]), + Svc = sender(group(Config)), + [Info] = diameter:service_info(Svc, connections), + {_, {TPid, _}} = lists:keyfind(peer, 1, Info), #diameter_base_DPA{'Result-Code' = 2001} - = diameter:call(?CLIENT, + = diameter:call(Svc, common, - ['DPR', {'Origin-Host', "CLIENT.erlang.org"}, - {'Origin-Realm', "erlang.org"}, - {'Disconnect-Cause', 0}]), - ok = receive %% endure the transport dies on DPA + ['DPR', {'Origin-Host', Svc ++ ".erlang.org"}, + {'Origin-Realm', "erlang.org"}, + {'Disconnect-Cause', 0}], + [{peer, TPid}]), + ok = receive %% ensure the transport dies on DPA #diameter_event{service = ?CLIENT, info = {down, Ref, _, _}} -> ok after 5000 -> erlang:process_info(self(), messages) end. +%% sender/1 + +sender(server) -> + ?SERVER; + +sender(_) -> + ?CLIENT. + +%% connect/1 + connect(Config) -> Pid = spawn(fun init/0), %% process for disconnect_cb to bang Grp = group(Config), @@ -138,16 +199,22 @@ connect(Config) -> || RCs <- ?RETURNS], ?util:write_priv(Config, config, [Pid | Refs]). +%% remove_transport/1 + %% Remove all the client transports only in the transport group. remove_transport(Config) -> transport == group(Config) andalso (ok = diameter:remove_transport(?CLIENT, true)). +%% stop_service/1 + %% Stop the service only in the service group. stop_service(Config) -> service == group(Config) andalso (ok = diameter:stop_service(?CLIENT)). +%% check/1 + %% Check for callbacks before diameter:stop/0, not the other way around %% for the timing reason explained below. check(Config) -> @@ -157,9 +224,13 @@ check(Config) -> Dict = receive {Pid, D} -> D end, %% get it check(Refs, ?RETURNS, Grp, Dict). %% check for callbacks +%% stop/1 + stop(_Config) -> ok = diameter:stop(). +%% =========================================================================== + %% Whether or not there are callbacks after diameter:stop() depends on %% timing as long as the server runs on the same node: a server %% transport could close the connection before the client has chance diff --git a/lib/diameter/test/diameter_examples_SUITE.erl b/lib/diameter/test/diameter_examples_SUITE.erl index e4ed2b227d..fad54d62b2 100644 --- a/lib/diameter/test/diameter_examples_SUITE.erl +++ b/lib/diameter/test/diameter_examples_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2013-2015. All Rights Reserved. +%% Copyright Ericsson AB 2013-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. diff --git a/lib/diameter/test/diameter_gen_sctp_SUITE.erl b/lib/diameter/test/diameter_gen_sctp_SUITE.erl index 79db39ca45..ccee6baec1 100644 --- a/lib/diameter/test/diameter_gen_sctp_SUITE.erl +++ b/lib/diameter/test/diameter_gen_sctp_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -33,8 +33,8 @@ end_per_suite/1]). %% testcases --export([send_not_from_controlling_process/1, - send_from_multiple_clients/1, send_from_multiple_clients/0, +-export([send_one_from_many/1, send_one_from_many/0, + send_many_from_one/1, send_many_from_one/0, receive_what_was_sent/1]). -include_lib("kernel/include/inet_sctp.hrl"). @@ -45,16 +45,24 @@ %% Open sockets on the loopback address. -define(ADDR, {127,0,0,1}). -%% Snooze, nap, siesta. --define(SLEEP(T), receive after T -> ok end). - %% An indescribably long number of milliseconds after which everthing %% that should have happened has. -define(FOREVER, 2000). +%% How many milliseconds to tolerate between the fastest and slowest +%% turnaround times. +-define(VARIANCE, 100). + %% The first byte in each message we send as a simple guard against %% not receiving what was sent. --define(MAGIC, 42). +-define(MAGIC, 0). + +%% Requested number of inbound/outbound streams. +-define(STREAMS, 5). + +%% Success for send_multiple. Match in each testcase rather than in +%% send_multiple itself for a better failure in common_test. +-define(OK, {_, true, _, [true, true], [], _}). %% =========================================================================== @@ -62,8 +70,8 @@ suite() -> [{timetrap, {seconds, 10}}]. all() -> - [send_not_from_controlling_process, - send_from_multiple_clients, + [send_one_from_many, + send_many_from_one, receive_what_was_sent]. init_per_suite(Config) -> @@ -81,130 +89,37 @@ end_per_suite(_Config) -> %% =========================================================================== -%% send_not_from_controlling_process/1 -%% -%% This testcase failing shows gen_sctp:send/4 hanging when called -%% outside the controlling process of the socket in question. - -send_not_from_controlling_process(_) -> - Pids = send_not_from_controlling_process(), - ?SLEEP(?FOREVER), - try - [] = [{P,I} || P <- Pids, I <- [process_info(P)], I /= undefined] - after - lists:foreach(fun(P) -> exit(P, kill) end, Pids) - end. - -%% send_not_from_controlling_process/0 -%% -%% Returns the pids of three spawned processes: a listening process, a -%% connecting process and a sending process. -%% -%% The expected behaviour is that all three processes exit: -%% -%% - The listening process exits upon receiving an SCTP message -%% sent by the sending process. -%% - The connecting process exits upon listening process exit. -%% - The sending process exits upon gen_sctp:send/4 return. -%% -%% The observed behaviour is that all three processes remain alive -%% indefinitely: -%% -%% - The listening process never receives the SCTP message sent -%% by the sending process. -%% - The connecting process has an inet_reply message in its mailbox -%% as a consequence of the call to gen_sctp:send/4 call from the -%% sending process. -%% - The call to gen_sctp:send/4 in the sending process doesn't return, -%% hanging in prim_inet:getopts/2. - -send_not_from_controlling_process() -> - FPid = self(), - {L, MRef} = spawn_monitor(fun() -> listen(FPid) end), - receive - {?MODULE, C, S} -> - demonitor(MRef, [flush]), - [L,C,S]; - {'DOWN', MRef, process, _, _} = T -> - error(T) - end. - -%% listen/1 - -listen(FPid) -> - {ok, Sock} = open(), - ok = gen_sctp:listen(Sock, true), - {ok, PortNr} = inet:port(Sock), - LPid = self(), - spawn(fun() -> connect1(PortNr, FPid, LPid) end), %% connecting process - Id = assoc(Sock), - recv(Sock, Id). - -%% connect1/3 - -connect1(PortNr, FPid, LPid) -> - {ok, Sock} = open(), - ok = gen_sctp:connect_init(Sock, ?ADDR, PortNr, []), - Id = assoc(Sock), - FPid ! {?MODULE, - self(), - spawn(fun() -> send(Sock, Id) end)}, %% sending process - MRef = monitor(process, LPid), - down(MRef). %% Waits with this as current_function. - -%% down/1 - -down(MRef) -> - receive {'DOWN', MRef, process, _, Reason} -> Reason end. - -%% send/2 - -send(Sock, Id) -> - ok = gen_sctp:send(Sock, Id, 0, <<0:32>>). - -%% =========================================================================== - -%% send_from_multiple_clients/0 +%% send_one_from_many/0 %% %% Demonstrates sluggish delivery of messages. -send_from_multiple_clients() -> - [{timetrap, {seconds, 60}}]. +send_one_from_many() -> + [{timetrap, {seconds, 30}}]. -send_from_multiple_clients(_) -> - {S, Rs} = T = send_from_multiple_clients(8, 1024), - Max = ?FOREVER*1000, - {false, [], _} = {Max < S, - Rs -- [OI || {O,_} = OI <- Rs, is_integer(O)], - T}. +send_one_from_many(_) -> + ?OK = send_multiple(128, 1, 1024). -%% send_from_multiple_clients/2 +%% send_one_from_many/2 %% %% Opens a listening socket and then spawns a specified number of -%% processes, each of which connects to the listening socket. Each -%% connecting process then sends a message, whose size in bytes is -%% passed as an argument, the listening process sends a reply -%% containing the time at which the message was received, and the -%% connecting process then exits upon reception of this reply. +%% processes, each of which connects, sends a message, receives a +%% reply, and exits. %% %% Returns the elapsed time for all connecting process to exit -%% together with a list of exit reasons for the connecting processes. -%% In the successful case a connecting process exits with the -%% outbound/inbound transit times for the sent/received message as -%% reason. +%% together with a list of exit reasons. In the successful case a +%% connecting process exits with the outbound/inbound transit times +%% for the sent/received message as reason. %% %% The observed behaviour is that some outbound messages (that is, %% from a connecting process to the listening process) can take an %% unexpectedly long time to complete their journey. The more -%% connecting processes, the longer the possible delay it seems. +%% connecting processes, the longer it can take it seems. %% -%% eg. (With F = fun send_from_multiple_clients/2.) -%% -%% 5> F(2, 1024). +%% eg. 5> send_one_from_many(2, 1024). %% {875,[{128,116},{113,139}]} -%% 6> F(4, 1024). +%% 6> send_one_from_many(4, 1024). %% {2995290,[{2994022,250},{2994071,80},{200,130},{211,113}]} -%% 7> F(8, 1024). +%% 7> send_one_from_many(8, 1024). %% {8997461,[{8996161,116}, %% {2996471,86}, %% {2996278,116}, @@ -213,7 +128,7 @@ send_from_multiple_clients(_) -> %% {213,159}, %% {373,173}, %% {376,118}]} -%% 8> F(8, 1024). +%% 8> send_one_from_many(8, 1024). %% {21001891,[{20999968,128}, %% {8997891,172}, %% {8997927,91}, @@ -223,120 +138,279 @@ send_from_multiple_clients(_) -> %% {117,98}, %% {149,125}]} %% -%% This turns out to have been due to SCTP resends as a consequence of -%% the listener having an insufficient recbuf. Increasing the size -%% solves the problem. -%% - -send_from_multiple_clients(N, Sz) - when is_integer(N), 0 < N, is_integer(Sz), 0 < Sz -> - timer:tc(fun listen/2, [N, <<?MAGIC, 0:Sz/unit:8>>]). -%% listen/2 - -listen(N, Bin) -> +send_multiple(Clients, Msgs, Sz) + when is_integer(Clients), 0 < Clients, + is_integer(Msgs), 0 < Msgs, + is_integer(Sz), 0 < Sz -> + T0 = diameter_lib:now(), + {S, Res} = timer:tc(fun listen/3, [Clients, Msgs, Sz]), + report(T0, Res), + Ts = lists:append(Res), + Outgoing = [DT || {_,{_,_,DT},{_,_,_},_} <- Ts], + Incoming = [DT || {_,{_,_,_},{_,_,DT},_} <- Ts], + Diffs = [lists:max(L) - lists:min(L) || L <- [Outgoing, Incoming]], + {S, + S < ?FOREVER*1000, + Diffs, + [D < V || V <- [?VARIANCE*1000], D <- Diffs], + [T || T <- Ts, [] == [T || {_,{_,_,_},{_,_,_},_} <- [T]]], + Res}. + +%% listen/3 + +listen(Clients, Msgs, Sz) -> {ok, Sock} = open(), ok = gen_sctp:listen(Sock, true), {ok, PortNr} = inet:port(Sock), %% Spawn a middleman that in turn spawns N connecting processes, %% collects a list of exit reasons and then exits with the list as - %% reason. loop/3 returns when we receive this list from the + %% reason. accept/2 returns when we receive this list from the %% middleman's 'DOWN'. Self = self(), - Fun = fun() -> exit(connect2(Self, PortNr, Bin)) end, - {_, MRef} = spawn_monitor(fun() -> exit(fold(N, Fun)) end), - loop(Sock, MRef, Bin). + Fun = fun() -> exit(client(Self, PortNr, Msgs, Sz)) end, %% start clients + {_, MRef} = spawn_monitor(fun() -> exit(clients(Clients, Fun)) end), + accept_loop(Sock, MRef). -%% fold/2 +%% fclients/2 %% %% Spawn N processes and collect their exit reasons in a list. -fold(N, Fun) -> +clients(N, Fun) -> start(N, Fun), acc(N, []). +%% start/2 + start(0, _) -> ok; + start(N, Fun) -> spawn_monitor(Fun), start(N-1, Fun). +%% acc/2 + acc(0, Acc) -> Acc; + acc(N, Acc) -> receive {'DOWN', _MRef, process, _, RC} -> acc(N-1, [RC | Acc]) end. -%% loop/3 +%% accept_loop/2 -loop(Sock, MRef, Bin) -> +accept_loop(Sock, MRef) -> + ok = inet:setopts(Sock, [{active, once}]), receive - ?SCTP(Sock, {[#sctp_sndrcvinfo{assoc_id = Id}], B}) - when is_binary(B) -> - Sz = size(Bin), - {Sz, Bin} = {size(B), B}, %% assert - ok = send(Sock, Id, mark(Bin)), - loop(Sock, MRef, Bin); + ?SCTP(Sock, {_, #sctp_assoc_change{state = comm_up, + outbound_streams = OS, + assoc_id = Id}}) -> + Self = self(), + TPid = spawn(fun() -> assoc(monitor(process, Self), Id, OS) end), + NewSock = peeloff(Sock, Id, TPid), + TPid ! {peeloff, NewSock}, + accept_loop(Sock, MRef); ?SCTP(Sock, _) -> - loop(Sock, MRef, Bin); + accept_loop(Sock, MRef); {'DOWN', MRef, process, _, Reason} -> - Reason + Reason; + T -> + error(T) end. -%% connect2/3 +%% assoc/3 +%% +%% Server process that answers incoming messages as long as the parent +%% lives. + +assoc(MRef, _Id, OS) + when is_reference(MRef) -> + {peeloff, Sock} = receive T -> T end, + recv_loop(Sock, false, sender(Sock, false, OS), MRef). + +%% recv_loop/4 + +recv_loop(Sock, Id, Pid, MRef) -> + ok = inet:setopts(Sock, [{active, once}]), + recv(Sock, Id, Pid, MRef, receive T -> T end). + +%% recv/5 + +%% Association id can change on a peeloff socket on some versions of +%% Solaris. +recv(Sock, + false, + Pid, + MRef, + ?SCTP(Sock, {[#sctp_sndrcvinfo{assoc_id = Id}], _}) + = T) -> + Pid ! {assoc_id, Id}, + recv(Sock, Id, Pid, MRef, T); + +recv(Sock, Id, Pid, MRef, ?SCTP(Sock, {[#sctp_sndrcvinfo{assoc_id = I}], B})) + when is_binary(B) -> + T2 = diameter_lib:now(), + Id = I, %% assert + <<?MAGIC, Bin/binary>> = B, %% assert + {[_,_,_,Sz] = L, Bytes} = unmark(Bin), + Sz = size(Bin) - Bytes, %% assert + <<_:Bytes/binary, Body:Sz/binary>> = Bin, + send(Pid, [T2|L], Body), %% answer + recv_loop(Sock, Id, Pid, MRef); + +recv(Sock, Id, Pid, MRef, ?SCTP(Sock, _)) -> + recv_loop(Sock, Id, Pid, MRef); + +recv(_, _, _, MRef, {'DOWN', MRef, process, _, Reason}) -> + Reason; + +recv(_, _, _, _, T) -> + error(T). -connect2(Pid, PortNr, Bin) -> - monitor(process, Pid), +%% send/3 - {ok, Sock} = open(), - ok = gen_sctp:connect_init(Sock, ?ADDR, PortNr, []), - Id = assoc(Sock), +send(Pid, Header, Body) -> + Pid ! {send, Header, Body}. - %% T1 = time before send - %% T2 = time after listening process received our message - %% T3 = time after reply is received +%% sender/3 +%% +%% Start a process that sends, so as not to block the controlling process. - T1 = diameter_lib:now(), - ok = send(Sock, Id, Bin), - T2 = unmark(recv(Sock, Id)), - T3 = diameter_lib:now(), - {diameter_lib:micro_diff(T2, T1), %% Outbound - diameter_lib:micro_diff(T3, T2)}. %% Inbound +sender(Sock, Id, OS) -> + Pid = self(), + spawn(fun() -> send_loop(Sock, Id, OS, 1, monitor(process, Pid)) end). -%% recv/2 +%% send_loop/5 -recv(Sock, Id) -> +send_loop(Sock, Id, OS, N, MRef) -> receive - ?SCTP(Sock, {[#sctp_sndrcvinfo{assoc_id = I}], Bin}) - when is_binary(Bin) -> - Id = I, %% assert - Bin; - ?SCTP(S, _) -> - Sock = S, %% assert - recv(Sock, Id); + {assoc_id, I} -> + send_loop(Sock, I, OS, N, MRef); + {send, L, Body} -> + Stream = N rem OS, + ok = send(Sock, Id, Stream, mark(Body, [N, Stream | L])), + send_loop(Sock, Id, OS, N+1, MRef); + {'DOWN', MRef, process, _, _} = T -> + T; T -> - exit(T) + error(T) end. -%% send/3 +%% peeloff/3 + +peeloff(LSock, Id, TPid) -> + {ok, Sock} = gen_sctp:peeloff(LSock, Id), + ok = gen_sctp:controlling_process(Sock, TPid), + Sock. + +%% client/4 + +client(Pid, PortNr, Msgs, Sz) -> + monitor(process, Pid), + {ok, Sock} = open(), + ok = gen_sctp:connect_init(Sock, ?ADDR, PortNr, []), + recv_loop(Sock, Msgs, Sz). -send(Sock, Id, Bin) -> - gen_sctp:send(Sock, Id, 0, Bin). +%% recv_loop/3 -%% mark/1 +recv_loop(_, 0, T) -> + [_,_|Acc] = T, + Acc; + +recv_loop(Sock, Msgs, T) -> + ok = inet:setopts(Sock, [{active, once}]), + {I, NewT} = recv(Sock, Msgs, T, receive X -> X end), + recv_loop(Sock, Msgs - I, NewT). + +%% recv/4 + +recv(Sock, Msgs, Sz, ?SCTP(Sock, {_, #sctp_assoc_change{} = A})) -> + #sctp_assoc_change{state = comm_up, %% assert + assoc_id = Id, + outbound_streams = OS} + = A, + true = is_integer(Sz), %% assert + send_n(Msgs, sender(Sock, Id, OS), Sz), + {0, [Id, OS]}; + +recv(Sock, _, T, ?SCTP(Sock, {[#sctp_sndrcvinfo{assoc_id = Id}], Bin})) -> + T4 = diameter_lib:now(), + [Id, OS | Acc] = T, + {1, [Id, OS, stat(T4, Bin) | Acc]}; + +recv(Sock, _, T, ?SCTP(Sock, _)) -> + {0, [_,_|_] = T}; + +recv(_, _, _, T) -> + error(T). + +%% send_n/3 +%% +%% Send messages to the server from dedicated processes. + +send_n(0, _, _) -> + ok; -mark(Bin) -> - Info = term_to_binary(diameter_lib:now()), +send_n(N, Pid, Sz) -> + M = rand:uniform(255), + send(Pid, [Sz], binary:copy(<<M>>, Sz)), + send_n(N-1, Pid, Sz). + +%% send/4 + +send(Sock, Id, Stream, Bin) -> + case gen_sctp:send(Sock, Id, Stream, <<?MAGIC, Bin/binary>>) of + {error, eagain} -> + send(Sock, Id, Stream, Bin); + RC -> + RC + end. + +%% stat/2 + +stat(T4, <<?MAGIC, Bin/binary>>) -> + %% T1 = time at send + %% T2 = time at reception by server + %% T3 = time at reception by server's sender + %% T4 = time at reception of answer + + {[T3,NI,SI,T2,T1,NO,SO,Sz], Bytes} = unmark(Bin), + + Sz = size(Bin) - Bytes, %% assert + + {T1, + {NO, SO, diameter_lib:micro_diff(T2, T1)}, %% Outbound + {NI, SI, diameter_lib:micro_diff(T4, T3)}, %% Inbound + T4}. + +%% mark/2 + +mark(Bin, T) -> + Info = term_to_binary([diameter_lib:now() | T]), <<Info/binary, Bin/binary>>. %% unmark/1 unmark(Bin) -> - binary_to_term(Bin). + T = binary_to_term(Bin), + {T, size(term_to_binary(T))}. + +%% =========================================================================== + +%% send_many_from_one/0 +%% +%% Demonstrates sluggish delivery of messages. + +send_many_from_one() -> + [{timetrap, {seconds, 30}}]. + +send_many_from_one(_) -> + ?OK = send_multiple(1, 128, 1024). %% =========================================================================== @@ -345,7 +419,7 @@ unmark(Bin) -> %% Demonstrates reception of a message that differs from that sent. receive_what_was_sent(_Config) -> - send_from_multiple_clients(1, 1024*32). %% fails + ?OK = send_multiple(1, 1, 1024*32). %% =========================================================================== @@ -357,16 +431,23 @@ open() -> %% open/1 open(Opts) -> - gen_sctp:open([{ip, ?ADDR}, {port, 0}, {active, true}, binary, + gen_sctp:open([{ip, ?ADDR}, {port, 0}, {active, false}, binary, + {sctp_initmsg, #sctp_initmsg{num_ostreams = ?STREAMS, + max_instreams = ?STREAMS}}, {recbuf, 1 bsl 16}, {sndbuf, 1 bsl 16} | Opts]). -%% assoc/1 +%% report/2 -assoc(Sock) -> - receive - ?SCTP(Sock, {_, #sctp_assoc_change{state = S, - assoc_id = Id}}) -> - comm_up = S, %% assert - Id - end. +report(T0, Ts) -> + ct:pal("~p~n", [lists:sort([sort([{diameter_lib:micro_diff(T1,T0), + OT, + IT, + diameter_lib:micro_diff(T4,T0)} + || {T1,OT,IT,T4} <- L]) + || L <- Ts])]). + +%% sort/1 + +sort(L) -> + lists:sort(fun({_,{N,_,_},_,_}, {_,{M,_,_},_,_}) -> N =< M end, L). diff --git a/lib/diameter/test/diameter_gen_tcp_SUITE.erl b/lib/diameter/test/diameter_gen_tcp_SUITE.erl index 2be2cf4b35..db42ea813e 100644 --- a/lib/diameter/test/diameter_gen_tcp_SUITE.erl +++ b/lib/diameter/test/diameter_gen_tcp_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2014-2015. All Rights Reserved. +%% Copyright Ericsson AB 2014-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -54,7 +54,7 @@ all() -> send_long(_) -> {Sock, SendF} = connection(), - B = list_to_binary(lists:duplicate(1 bsl 20, $X)), + B = binary:copy(<<$X>>, 1 bsl 20), ok = SendF(B), B = recv(Sock, size(B), []). diff --git a/lib/diameter/test/diameter_relay_SUITE.erl b/lib/diameter/test/diameter_relay_SUITE.erl index 5353688bf4..5d74e63b8d 100644 --- a/lib/diameter/test/diameter_relay_SUITE.erl +++ b/lib/diameter/test/diameter_relay_SUITE.erl @@ -302,7 +302,7 @@ stats(?RELAY1, L) -> %% RAR x 2 (send_timeout_[12]) {{{0,257,0},recv},3}, %% CEA {{{0,257,0},send},1}, %% " - {{{0,257,1},recv},1}, %% CER + {{{0,257,1},recv},1}, %% CER {{{0,257,1},send},3}, %% " {{{relay,0},recv,{'Result-Code',2001}},2}, %% STA x 2 (send[34]) {{{relay,0},recv,{'Result-Code',3005}},1}, %% ASA (send_loop) diff --git a/lib/diameter/test/diameter_traffic_SUITE.erl b/lib/diameter/test/diameter_traffic_SUITE.erl index 4c82d4dee2..84b41f14b7 100644 --- a/lib/diameter/test/diameter_traffic_SUITE.erl +++ b/lib/diameter/test/diameter_traffic_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -27,6 +27,8 @@ -export([suite/0, all/0, groups/0, + init_per_suite/1, + end_per_suite/1, init_per_group/2, end_per_group/2, init_per_testcase/2, @@ -90,7 +92,6 @@ send_multiple_filters_2/1, send_multiple_filters_3/1, send_anything/1, - outstanding/1, remove_transports/1, empty/1, stop_services/1, @@ -106,6 +107,9 @@ handle_error/6, handle_request/3]). +%% diameter_{tcp,sctp} callbacks +-export([message/3]). + -include("diameter.hrl"). -include("diameter_gen_base_rfc3588.hrl"). -include("diameter_gen_base_accounting.hrl"). @@ -145,21 +149,29 @@ %% Whether to decode stringish Diameter types to strings, or leave %% them as binary. --define(STRING_DECODES, [true, false]). +-define(STRING_DECODES, [false, true]). %% Which transport protocol to use. -define(TRANSPORTS, [tcp, sctp]). +%% Send from a dedicated process? +-define(SENDERS, [true, false]). + +%% Message callbacks from diameter_{tcp,sctp}? +-define(CALLBACKS, [true, false]). + -record(group, {transport, + strings, client_service, client_encoding, client_dict0, - client_strings, + client_sender, server_service, server_encoding, server_container, - server_strings}). + server_sender, + server_throttle}). %% Not really what we should be setting unless the message is sent in %% the common application but diameter doesn't care. @@ -190,8 +202,7 @@ {'Acct-Application-Id', [?DIAMETER_APP_ID_ACCOUNTING]}, {restrict_connections, false}, {string_decode, Decode}, - {incoming_maxlen, 1 bsl 21}, - {spawn_opt, [{min_heap_size, 5000}]} + {incoming_maxlen, 1 bsl 21} | [{application, [{dictionary, D}, {module, ?MODULE}, {answer_errors, callback}]} @@ -243,76 +254,128 @@ suite() -> [{timetrap, {seconds, 10}}]. all() -> - [start, result_codes, {group, traffic}, outstanding, empty, stop]. + [start, result_codes, {group, traffic}, empty, stop]. groups() -> - Ts = tc(), - Sctp = ?util:have_sctp(), - [{B, [P], Ts} || {B,P} <- [{true, shuffle}, {false, parallel}]] + [{P, [P], Ts} || Ts <- [tc(tc())], P <- [shuffle, parallel]] ++ - [{?util:name([T,R,D,A,C,SD,CD]), + [{?util:name([T,R,D,A,C,S,SS,ST,CS]), [], - [start_services, - add_transports, - result_codes, - {group, SD orelse CD}, - remove_transports, - stop_services]} + [{group, if S -> shuffle; not S -> parallel end}]} || T <- ?TRANSPORTS, - T /= sctp orelse Sctp, R <- ?ENCODINGS, D <- ?RFCS, A <- ?ENCODINGS, C <- ?CONTAINERS, - SD <- ?STRING_DECODES, - CD <- ?STRING_DECODES] + S <- ?STRING_DECODES, + SS <- ?SENDERS, + ST <- ?CALLBACKS, + CS <- ?SENDERS] ++ - [{traffic, [], [{group, ?util:name([T,R,D,A,C,SD,CD])} - || T <- ?TRANSPORTS, - T /= sctp orelse Sctp, - R <- ?ENCODINGS, - D <- ?RFCS, - A <- ?ENCODINGS, - C <- ?CONTAINERS, - SD <- ?STRING_DECODES, - CD <- ?STRING_DECODES]}]. + [{T, [], groups([[T,R,D,A,C,S,SS,ST,CS] + || R <- ?ENCODINGS, + D <- ?RFCS, + A <- ?ENCODINGS, + C <- ?CONTAINERS, + S <- ?STRING_DECODES, + SS <- ?SENDERS, + ST <- ?CALLBACKS, + CS <- ?SENDERS, + SS orelse CS])} %% avoid deadlock + || T <- ?TRANSPORTS] + ++ + [{traffic, [], [{group, T} || T <- ?TRANSPORTS]}]. + +%groups(_) -> %% debug +% Name = [sctp,record,rfc6733,record,pkt,false,false,false,false], +% [{group, ?util:name(Name)}]; +groups(Names) -> + [{group, ?util:name(L)} || L <- Names]. + +%tc([N|_]) -> %% debug +% [N]; +tc(L) -> + L. + +%% -------------------- + +init_per_suite(Config) -> + [{sctp, ?util:have_sctp()} | Config]. + +end_per_suite(_Config) -> + ok. + +%% -------------------- + +init_per_group(Name, Config) + when Name == shuffle; + Name == parallel -> + start_services(Config), + add_transports(Config); + +init_per_group(sctp = Name, Config) -> + {_, Sctp} = lists:keyfind(Name, 1, Config), + if Sctp -> + Config; + true -> + {skip, Name} + end; init_per_group(Name, Config) -> case ?util:name(Name) of - [T,R,D,A,C,SD,CD] -> + [T,R,D,A,C,S,SS,ST,CS] -> G = #group{transport = T, + strings = S, client_service = [$C|?util:unique_string()], client_encoding = R, client_dict0 = dict0(D), - client_strings = CD, + client_sender = CS, server_service = [$S|?util:unique_string()], server_encoding = A, server_container = C, - server_strings = SD}, - [{group, G} | Config]; + server_sender = SS, + server_throttle = ST}, + %% Limit the number of testcase, since the number of + %% groups is large. + All = ?util:scramble(tc()), + TCs = lists:sublist(All, rand:uniform(32)), + [{group, G}, {runlist, TCs} | Config]; _ -> Config end. +end_per_group(Name, Config) + when Name == shuffle; + Name == parallel -> + remove_transports(Config), + stop_services(Config); + end_per_group(_, _) -> ok. +%% -------------------- + %% Skip testcases that can reasonably fail under SCTP. init_per_testcase(Name, Config) -> - case [skip || #group{transport = sctp} - <- [proplists:get_value(group, Config)], - send_maxlen == Name - orelse send_long == Name] + TCs = proplists:get_value(runlist, Config, []), + Run = [] == TCs orelse lists:member(Name, TCs), + case [G || #group{transport = sctp} = G + <- [proplists:get_value(group, Config)]] of - [skip] -> + [_] when Name == send_maxlen; + Name == send_long -> {skip, sctp}; - [] -> + _ when not Run -> + {skip, random}; + _ -> [{testcase, Name} | Config] end. end_per_testcase(_, _) -> ok. +%% -------------------- + %% Testcases to run when services are started and connections %% established. tc() -> @@ -377,28 +440,34 @@ start(_Config) -> ok = diameter:start(). start_services(Config) -> - #group{client_service = CN, - client_strings = CD, - server_service = SN, - server_strings = SD} + #group{strings = S, + client_service = CN, + server_service = SN} = group(Config), - ok = diameter:start_service(SN, ?SERVICE(SN, SD)), + ok = diameter:start_service(SN, ?SERVICE(SN, S)), ok = diameter:start_service(CN, [{sequence, ?CLIENT_MASK} - | ?SERVICE(CN, CD)]). + | ?SERVICE(CN, S)]). add_transports(Config) -> #group{transport = T, client_service = CN, - server_service = SN} - = group(Config), + client_sender = CS, + server_service = SN, + server_sender = SS, + server_throttle = ST} + = group(Config), LRef = ?util:listen(SN, - T, + [T, + {sender, SS}, + {message_cb, ST andalso {?MODULE, message, [4]}} + | [{packet, hd(?util:scramble([false, raw]))} + || T == sctp andalso CS]], [{capabilities_cb, fun capx/2}, {pool_size, 8}, - {spawn_opt, [{min_heap_size, 8096}]}, - {applications, apps(rfc3588)}]), + {applications, apps(rfc3588)}] + ++ [{spawn_opt, {erlang, spawn, []}} || CS]), Cs = [?util:connect(CN, - T, + [T, {sender, CS}], LRef, [{id, Id}, {capabilities, [{'Origin-State-Id', origin(Id)}]}, @@ -415,11 +484,6 @@ apps(D0) -> D = dict0(D0), [acct(D), D]. -%% Ensure there are no outstanding requests in request table. -outstanding(_Config) -> - [] = [T || T <- ets:tab2list(diameter_request), - is_atom(element(1,T))]. - remove_transports(Config) -> #group{client_service = CN, server_service = SN} @@ -689,14 +753,14 @@ send_unexpected_mandatory(Config) -> %% Send something long that will be fragmented by TCP. send_long(Config) -> Req = ['STR', {'Termination-Cause', ?LOGOUT}, - {'User-Name', [lists:duplicate(1 bsl 20, $X)]}], + {'User-Name', [binary:copy(<<$X>>, 1 bsl 20)]}], ['STA', {'Session-Id', _}, {'Result-Code', ?SUCCESS} | _] = call(Config, Req). %% Send something longer than the configure incoming_maxlen. send_maxlen(Config) -> Req = ['STR', {'Termination-Cause', ?LOGOUT}, - {'User-Name', [lists:duplicate(1 bsl 21, $X)]}], + {'User-Name', [binary:copy(<<$X>>, 1 bsl 21)]}], {timeout, _} = call(Config, Req). %% Send something for which pick_peer finds no suitable peer. @@ -875,7 +939,7 @@ group(Config) -> #group{} = proplists:get_value(group, Config). string(V, Config) -> - #group{client_strings = B} = group(Config), + #group{strings = B} = group(Config), decode(V,B). decode(S, true) @@ -995,7 +1059,7 @@ pick_peer(Peers, _, [$C|_], _State, {send_detach, Group}, _, {_,_}) -> find(#group{client_service = CN, server_encoding = A, server_container = C}, - Peers) -> + [_|_] = Peers) -> Id = {A,C}, [P] = [P || P <- Peers, id(Id, P, CN)], {ok, P}. @@ -1429,3 +1493,33 @@ request(#diameter_base_STR{'Session-Id' = SId}, %% send_error/send_timeout request(#diameter_base_RAR{}, _Caps) -> receive after 2000 -> {protocol_error, ?TOO_BUSY} end. + +%% message/3 +%% +%% Limit the number of messages received. More can be received if read +%% in the same packet. + +message(recv = D, {[_], Bin}, N) -> + message(D, Bin, N); +message(Dir, #diameter_packet{bin = Bin}, N) -> + message(Dir, Bin, N); + +%% incoming request +message(recv, <<_:32, 1, _/bits>> = Bin, N) -> + [Bin, 1 < N, fun ?MODULE:message/3, N-1]; + +%% incoming answer +message(recv, Bin, _) -> + [Bin]; + +%% outgoing +message(send, Bin, _) -> + [Bin]; + +%% sent request +message(ack, <<_:32, 1, _/bits>>, _) -> + []; + +%% sent answer or discarded request +message(ack, _, N) -> + [0 =< N, fun ?MODULE:message/3, N+1]. diff --git a/lib/diameter/test/diameter_transport_SUITE.erl b/lib/diameter/test/diameter_transport_SUITE.erl index c94f46b7a5..9d981d0a2b 100644 --- a/lib/diameter/test/diameter_transport_SUITE.erl +++ b/lib/diameter/test/diameter_transport_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -294,10 +294,17 @@ init(gen_accept, {Prot, Ref}) -> {ok, PortNr} = inet:port(LSock), true = diameter_reg:add_new(?TEST_LISTENER(Ref, PortNr)), - %% Accept a connection, receive a message and send it back. + %% Accept a connection, receive a message send it back, and wait + %% for the peer to close the connection. {ok, Sock} = gen_accept(Prot, LSock), Bin = gen_recv(Prot, Sock), - ok = gen_send(Prot, Sock, Bin); + ok = gen_send(Prot, Sock, Bin), + receive + {tcp_closed, Sock} = T -> + T; + ?SCTP(Sock, {_, #sctp_assoc_change{}}) = T -> + T + end; init(connect, {Prot, Ref}) -> %% Lookup the peer's listening socket. @@ -311,12 +318,7 @@ init(connect, {Prot, Ref}) -> %% Send a message and receive it back. Bin = make_msg(), TPid ! ?TMSG({send, Bin}), - Bin = bin(Prot, ?RECV(?TMSG({recv, P}), P)), - - %% Expect the transport process to die as a result of the peer - %% closing the connection. - MRef = erlang:monitor(process, TPid), - ?RECV({'DOWN', MRef, process, _, _}). + Bin = bin(Prot, ?RECV(?TMSG({recv, P}), P)). bin(sctp, #diameter_packet{bin = Bin}) -> Bin; @@ -336,15 +338,11 @@ make_msg() -> <<1:8, Len:24, Bin/binary>>. %% crypto:rand_bytes/1 isn't available on all platforms (since openssl -%% isn't) so roll our own. +%% isn't) so roll our own. Not particularly random, but less verbose +%% in trace. rand_bytes(N) -> - rand_bytes(N, <<>>). - -rand_bytes(0, Bin) -> - Bin; -rand_bytes(N, Bin) -> Oct = rand:uniform(256) - 1, - rand_bytes(N-1, <<Oct, Bin/binary>>). + binary:copy(<<Oct>>, N). %% =========================================================================== diff --git a/lib/diameter/test/diameter_util.erl b/lib/diameter/test/diameter_util.erl index cca28dd23c..03f79096ac 100644 --- a/lib/diameter/test/diameter_util.erl +++ b/lib/diameter/test/diameter_util.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2016. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -172,18 +172,7 @@ recvl([{MRef, F} | L], Ref, Fun, Acc) -> %% Sort a list into random order. scramble(L) -> - foldl(fun(true, _, S, false) -> S end, - false, - [[fun s/1, L]]). - -s(L) -> - s([], L). - -s(Acc, []) -> - Acc; -s(Acc, L) -> - {H, [T|Rest]} = lists:split(rand:uniform(length(L)) - 1, L), - s([T|Acc], H ++ Rest). + [X || {_,X} <- lists:sort([{rand:uniform(), T} || T <- L])]. %% --------------------------------------------------------------------------- %% unique_string/0 @@ -195,21 +184,22 @@ unique_string() -> %% have_sctp/0 have_sctp() -> - case erlang:system_info(system_architecture) of - %% We do not support the sctp version present in solaris - %% version "sparc-sun-solaris2.10", that behaves differently - %% from later versions and linux - "sparc-sun-solaris2.10" -> - false; - _-> - case gen_sctp:open() of - {ok, Sock} -> - gen_sctp:close(Sock), - true; - {error, E} when E == eprotonosupport; - E == esocktnosupport -> %% fail on any other reason - false - end + have_sctp(erlang:system_info(system_architecture)). + +%% Don't run SCTP on platforms where it's either known to be flakey or +%% isn't available. + +have_sctp("sparc-sun-solaris2.10") -> + false; + +have_sctp(_) -> + case gen_sctp:open() of + {ok, Sock} -> + gen_sctp:close(Sock), + true; + {error, E} when E == eprotonosupport; + E == esocktnosupport -> %% fail on any other reason + false end. %% --------------------------------------------------------------------------- @@ -313,17 +303,23 @@ listen(SvcName, Prot, Opts) -> connect(Client, Prot, LRef) -> connect(Client, Prot, LRef, []). -connect(Client, Prot, LRef, Opts) -> +connect(Client, ProtOpts, LRef, Opts) -> + Prot = head(ProtOpts), [PortNr] = lport(Prot, LRef), Client = diameter:service_info(Client, name), %% assert true = diameter:subscribe(Client), - Ref = add_transport(Client, {connect, opts(Prot, PortNr) ++ Opts}), + Ref = add_transport(Client, {connect, opts(ProtOpts, PortNr) ++ Opts}), true = transport(Client, Ref), %% assert diameter_lib:for_n(fun(_) -> ok = up(Client, Ref, Prot, PortNr) end, proplists:get_value(pool_size, Opts, 1)), Ref. +head([T|_]) -> + T; +head(T) -> + T. + up(Client, Ref, Prot, PortNr) -> receive {diameter_event, Client, {up, Ref, _, _, _}} -> ok @@ -366,10 +362,13 @@ tmod(sctp) -> tmod(any) -> [diameter_sctp, diameter_tcp]. -opts(Prot, T) -> - tmo(T, lists:append([[{transport_module, M}, {transport_config, C}] +opts([Prot | Opts], T) -> + tmo(T, lists:append([[{transport_module, M}, {transport_config, C ++ Opts}] || M <- tmod(Prot), - C <- [cfg(M,T) ++ cfg(M) ++ cfg(T)]])). + C <- [cfg(M,T) ++ cfg(M) ++ cfg(T)]])); + +opts(Prot, T) -> + opts([Prot], T). tmo(listen, Opts) -> Opts; diff --git a/lib/diameter/test/diameter_watchdog_SUITE.erl b/lib/diameter/test/diameter_watchdog_SUITE.erl index 6d22ddcc18..39c4f051a5 100644 --- a/lib/diameter/test/diameter_watchdog_SUITE.erl +++ b/lib/diameter/test/diameter_watchdog_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2010-2015. All Rights Reserved. +%% Copyright Ericsson AB 2010-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -44,13 +44,8 @@ -export([peer_up/3, peer_down/3]). -%% gen_tcp-ish interface --export([listen/2, - accept/1, - connect/3, - send/2, - setopts/2, - close/1]). +%% diameter_tcp message_cb +-export([message/3]). -include("diameter.hrl"). -include("diameter_ct.hrl"). @@ -161,9 +156,9 @@ reopen(Type, Test, Ref, Wd, N, M) -> reopen(Type, Test, SvcName, TRef, Wd, N, M). cfg(Type, Type, Wd) -> - {Wd, [], []}; + {Wd, [], false}; cfg(_Type, _Test, _Wd) -> - {?WD(?PEER_WD), [{okay, 0}], [{module, ?MODULE}]}. + {?WD(?PEER_WD), [{okay, 0}], true}. %% reopen/7 @@ -346,7 +341,7 @@ recv_reopen(listen, Ref) -> %% reg/3 %% %% Lookup the pid of the transport process and publish a term for -%% send/2 to lookup. +%% message/3 to lookup. reg(TRef, SvcName, T) -> TPid = tpid(TRef, diameter:service_info(SvcName, transport)), true = diameter_reg:add_new({?MODULE, TPid, T}). @@ -394,7 +389,7 @@ suspect(_) -> suspect(Type, Fake, Ref, N) when is_reference(Ref) -> {SvcName, TRef} - = start(Type, Ref, {?WD(10000), [{suspect, N}], mod(Fake)}), + = start(Type, Ref, {?WD(10000), [{suspect, N}], Fake}), {initial, okay} = ?WD_EVENT(TRef), suspect(TRef, Fake, SvcName, N); @@ -436,11 +431,6 @@ abuse([F|A], Test) -> abuse(F, Test) -> abuse([F], Test). -mod(true) -> - [{module, ?MODULE}]; -mod(false) -> - []. - %% =========================================================================== %% # okay/1 %% =========================================================================== @@ -456,7 +446,7 @@ okay(Type, Fake, Ref, N) {SvcName, TRef} = start(Type, Ref, {?WD(10000), [{okay, choose(Fake, 0, N)}], - mod(Fake)}), + Fake}), {initial, okay} = ?WD_EVENT(TRef), okay(TRef, Fake, @@ -515,12 +505,17 @@ start(Type, Ref, T) -> true = diameter_reg:add_new({Type, Ref, Name}), {Name, TRef}. -opts(Type, Ref, {Timer, Config, Mod}) -> +opts(Type, Ref, {Timer, Config, Fake}) + when is_boolean(Fake) -> [{transport_module, diameter_tcp}, - {transport_config, Mod ++ [{ip, ?ADDR}, {port, 0}] ++ cfg(Type, Ref)}, + {transport_config, mod(Fake) ++ [{ip, ?ADDR}, {port, 0}] + ++ cfg(Type, Ref)}, {watchdog_timer, Timer}, {watchdog_config, Config}]. +mod(B) -> + [{message_cb, [fun message/3, capx]} || B]. + cfg(listen, _) -> []; cfg(connect, Ref) -> @@ -531,37 +526,29 @@ cfg(connect, Ref) -> %% =========================================================================== -listen(PortNr, Opts) -> - gen_tcp:listen(PortNr, Opts). - -accept(LSock) -> - gen_tcp:accept(LSock). +%% message/3 -connect(Addr, Port, Opts) -> - gen_tcp:connect(Addr, Port, Opts). +message(send, Bin, X) -> + send(Bin, X); -setopts(Sock, Opts) -> - inet:setopts(Sock, Opts). +message(recv, Bin, _) -> + [Bin]; -send(Sock, Bin) -> - send(getr(config), Sock, Bin). - -close(Sock) -> - gen_tcp:close(Sock). +message(_, _, _) -> + []. -%% send/3 +%% send/2 %% First outgoing message from a new transport process is CER/CEA. %% Remaining outgoing messages are either DWR or DWA. -send(undefined, Sock, Bin) -> - <<_:32, _:8, 257:24, _/binary>> = Bin, - putr(config, init), - gen_tcp:send(Sock, Bin); +send(Bin, capx) -> + <<_:32, _:8, 257:24, _/binary>> = Bin, %% assert on CER/CEA + [Bin, fun message/3, init]; %% Outgoing DWR: fake reception of DWA. Use the fact that AVP values %% are ignored. This is to ensure that the peer's watchdog state %% transitions are only induced by responses to messages it sends. -send(_, Sock, <<_:32, 1:1, _:7, 280:24, _:32, EId:32, HId:32, _/binary>>) -> +send(<<_:32, 1:1, _:7, 280:24, _:32, EId:32, HId:32, _/binary>>, _) -> Pkt = #diameter_packet{header = #diameter_header{version = 1, end_to_end_id = EId, hop_by_hop_id = HId}, @@ -569,47 +556,36 @@ send(_, Sock, <<_:32, 1:1, _:7, 280:24, _:32, EId:32, HId:32, _/binary>>) -> {'Origin-Host', "XXX"}, {'Origin-Realm', ?REALM}]}, #diameter_packet{bin = Bin} = diameter_codec:encode(?BASE, Pkt), - self() ! {tcp, Sock, Bin}, - ok; + [recv, Bin]; %% First outgoing DWA. -send(init, Sock, Bin) -> +send(Bin, init) -> [{{?MODULE, _, T}, _}] = diameter_reg:wait({?MODULE, self(), '_'}), - putr(config, T), - send(Sock, Bin); + send(Bin, T); %% First transport process. -send({SvcName, {_,_,_} = T}, Sock, Bin) -> +send(Bin, {SvcName, {_,_,_} = T}) -> [{'Origin-Host', _} = OH, {'Origin-Realm', _} = OR | _] = ?SERVICE(SvcName), putr(origin, [OH, OR]), - putr(config, T), - send(Sock, Bin); + send(Bin, T); %% Discard DWA, failback after another timeout in the peer. -send({Wd, 0 = No, Msg}, Sock, Bin) -> +send(Bin, {Wd, 0 = No, Msg}) -> Origin = getr(origin), - spawn(fun() -> failback(?ONE_WD(Wd), Msg, Sock, Bin, Origin) end), - putr(config, No), - ok; + [{defer, ?ONE_WD(Wd), [msg(Msg, Bin, Origin)]}, fun message/3, No]; %% Send DWA while we're in the mood (aka 0 < N). -send({Wd, N, Msg}, Sock, Bin) -> - putr(config, {Wd, N-1, Msg}), - gen_tcp:send(Sock, Bin); +send(Bin, {Wd, N, Msg}) -> + [Bin, fun message/3, {Wd, N-1, Msg}]; %% Discard DWA. -send(0, _Sock, _Bin) -> - ok; +send(_Bin, 0 = No) -> + [fun message/3, No]; %% Send DWA. -send(N, Sock, <<_:32, 0:1, _:7, 280:24, _/binary>> = Bin) -> - putr(config, N-1), - gen_tcp:send(Sock, Bin). - -failback(Tmo, Msg, Sock, Bin, Origin) -> - timer:sleep(Tmo), - ok = gen_tcp:send(Sock, msg(Msg, Bin, Origin)). +send(<<_:32, 0:1, _:7, 280:24, _/binary>> = DWA, N) -> + [DWA, fun message/3, N-1]. %% msg/2 diff --git a/lib/diameter/test/modules.mk b/lib/diameter/test/modules.mk index 80d0f8d59c..0c73adca12 100644 --- a/lib/diameter/test/modules.mk +++ b/lib/diameter/test/modules.mk @@ -1,7 +1,7 @@ # %CopyrightBegin% # -# Copyright Ericsson AB 2010-2015. All Rights Reserved. +# Copyright Ericsson AB 2010-2017. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -31,7 +31,6 @@ MODULES = \ diameter_codec_test \ diameter_config_SUITE \ diameter_compiler_SUITE \ - diameter_dict_SUITE \ diameter_distribution_SUITE \ diameter_dpr_SUITE \ diameter_event_SUITE \ diff --git a/lib/diameter/vsn.mk b/lib/diameter/vsn.mk index 94d9d72a48..4801f542fb 100644 --- a/lib/diameter/vsn.mk +++ b/lib/diameter/vsn.mk @@ -17,5 +17,5 @@ # %CopyrightEnd% APPLICATION = diameter -DIAMETER_VSN = 1.12.2 +DIAMETER_VSN = 2.0 APP_VSN = $(APPLICATION)-$(DIAMETER_VSN)$(PRE_VSN) diff --git a/lib/hipe/cerl/erl_types.erl b/lib/hipe/cerl/erl_types.erl index 4d7f1be513..0883a69918 100644 --- a/lib/hipe/cerl/erl_types.erl +++ b/lib/hipe/cerl/erl_types.erl @@ -228,7 +228,7 @@ -export([t_is_identifier/1]). -endif. --export_type([erl_type/0, opaques/0, type_table/0, mod_records/0, +-export_type([erl_type/0, opaques/0, type_table/0, var_table/0, cache/0]). %%-define(DEBUG, true). @@ -366,15 +366,17 @@ -type opaques() :: [erl_type()] | 'universe'. +-type file_line() :: {file:name(), erl_anno:line()}. -type record_key() :: {'record', atom()}. -type type_key() :: {'type' | 'opaque', mfa()}. --type record_value() :: [{atom(), erl_parse:abstract_expr(), erl_type()}]. --type type_value() :: {{module(), {file:name(), erl_anno:line()}, +-type field() :: {atom(), erl_parse:abstract_expr(), erl_type()}. +-type record_value() :: {file_line(), + [{RecordSize :: non_neg_integer(), [field()]}]}. +-type type_value() :: {{module(), file_line(), erl_parse:abstract_type(), ArgNames :: [atom()]}, erl_type()}. -type type_table() :: #{record_key() | type_key() => record_value() | type_value()}. --type mod_records() :: dict:dict(module(), type_table()). -opaque var_table() :: #{atom() => erl_type()}. @@ -528,7 +530,9 @@ list_contains_opaque(List, Opaques) -> 'error' | {'ok', erl_type(), erl_type()}. t_find_opaque_mismatch(T1, T2, Opaques) -> - catch t_find_opaque_mismatch(T1, T2, T2, Opaques). + try t_find_opaque_mismatch(T1, T2, T2, Opaques) + catch throw:error -> error + end. t_find_opaque_mismatch(?any, _Type, _TopType, _Opaques) -> error; t_find_opaque_mismatch(?none, _Type, _TopType, _Opaques) -> throw(error); @@ -580,8 +584,9 @@ t_find_opaque_mismatch_ordlists(L1, L2, TopType, Opaques) -> t_find_opaque_mismatch_list(List). t_find_opaque_mismatch_lists(L1, L2, _TopType, Opaques) -> - List = [catch t_find_opaque_mismatch(T1, T2, T2, Opaques) || - T1 <- L1, T2 <- L2], + List = [try t_find_opaque_mismatch(T1, T2, T2, Opaques) + catch throw:error -> error + end || T1 <- L1, T2 <- L2], t_find_opaque_mismatch_list(List). t_find_opaque_mismatch_list([]) -> throw(error); @@ -611,7 +616,9 @@ t_find_unknown_opaque(T1, T2, Opaques) -> %% is assumed to be taken from the contract. t_decorate_with_opaque(T1, T2, Opaques) -> - case t_is_equal(T1, T2) orelse not t_contains_opaque(T2) of + case + Opaques =:= [] orelse t_is_equal(T1, T2) orelse not t_contains_opaque(T2) + of true -> T1; false -> T = t_inf(T1, T2), @@ -4447,11 +4454,11 @@ mod_name(Mod, Name) -> -type cache_key() :: {module(), atom(), expand_depth(), [erl_type()], type_names()}. -type mod_type_table() :: ets:tid(). +-type mod_records() :: dict:dict(module(), type_table()). -record(cache, { types = maps:new() :: #{cache_key() => {erl_type(), expand_limit()}}, - mod_recs = {mrecs, dict:new()} :: 'undefined' - | {'mrecs', mod_records()} + mod_recs = {mrecs, dict:new()} :: {'mrecs', mod_records()} }). -opaque cache() :: #cache{}. @@ -5331,21 +5338,17 @@ is_erl_type(_) -> false. 'error' | {type_table(), cache()}. lookup_module_types(Module, CodeTable, Cache) -> - #cache{mod_recs = ModRecs} = Cache, - case ModRecs of - undefined -> error; - {mrecs, MRecs} -> - case dict:find(Module, MRecs) of - {ok, R} -> - {R, Cache}; - error -> - try ets:lookup_element(CodeTable, Module, 2) of - R -> - NewMRecs = dict:store(Module, R, MRecs), - {R, Cache#cache{mod_recs = {mrecs, NewMRecs}}} - catch - _:_ -> error - end + #cache{mod_recs = {mrecs, MRecs}} = Cache, + case dict:find(Module, MRecs) of + {ok, R} -> + {R, Cache}; + error -> + try ets:lookup_element(CodeTable, Module, 2) of + R -> + NewMRecs = dict:store(Module, R, MRecs), + {R, Cache#cache{mod_recs = {mrecs, NewMRecs}}} + catch + _:_ -> error end end. |