diff options
Diffstat (limited to 'lib')
27 files changed, 2834 insertions, 770 deletions
diff --git a/lib/common_test/src/test_server_node.erl b/lib/common_test/src/test_server_node.erl index 0b406c54cc..92c610730e 100644 --- a/lib/common_test/src/test_server_node.erl +++ b/lib/common_test/src/test_server_node.erl @@ -18,11 +18,11 @@ %% %CopyrightEnd% %% -module(test_server_node). --compile(r12). +-compile(r16). %%% %%% The same compiled code for this module must be possible to load -%%% in R12B and later. +%%% in R16B and later. %%% %% Test Controller interface diff --git a/lib/compiler/src/beam_type.erl b/lib/compiler/src/beam_type.erl index 2b5d558ee4..7e9a243ada 100644 --- a/lib/compiler/src/beam_type.erl +++ b/lib/compiler/src/beam_type.erl @@ -26,6 +26,8 @@ -import(lists, [filter/2,foldl/3,keyfind/3,member/2, reverse/1,reverse/2,sort/1]). +-define(UNICODE_INT, {integer,{0,16#10FFFF}}). + -spec module(beam_utils:module_code(), [compile:option()]) -> {'ok',beam_utils:module_code()}. @@ -494,6 +496,10 @@ update({test,test_arity,_Fail,[Src,Arity]}, Ts0) -> tdb_update([{Src,{tuple,Arity,[]}}], Ts0); update({test,is_map,_Fail,[Src]}, Ts0) -> tdb_update([{Src,map}], Ts0); +update({get_map_elements,_,Src,{list,Elems0}}, Ts0) -> + {_Ss,Ds} = beam_utils:split_even(Elems0), + Elems = [{Dst,kill} || Dst <- Ds], + tdb_update([{Src,map}|Elems], Ts0); update({test,is_nonempty_list,_Fail,[Src]}, Ts0) -> tdb_update([{Src,nonempty_list}], Ts0); update({test,is_eq_exact,_,[Reg,{atom,_}=Atom]}, Ts) -> @@ -507,10 +513,39 @@ update({test,is_eq_exact,_,[Reg,{atom,_}=Atom]}, Ts) -> end; update({test,is_record,_Fail,[Src,Tag,{integer,Arity}]}, Ts) -> tdb_update([{Src,{tuple,Arity,[Tag]}}], Ts); -update({test,_Test,_Fail,_Other}, Ts) -> - Ts; + +%% Binary matching + update({test,bs_get_integer2,_,_,Args,Dst}, Ts) -> tdb_update([{Dst,get_bs_integer_type(Args)}], Ts); +update({test,bs_get_utf8,_,_,_,Dst}, Ts) -> + tdb_update([{Dst,?UNICODE_INT}], Ts); +update({test,bs_get_utf16,_,_,_,Dst}, Ts) -> + tdb_update([{Dst,?UNICODE_INT}], Ts); +update({test,bs_get_utf32,_,_,_,Dst}, Ts) -> + tdb_update([{Dst,?UNICODE_INT}], Ts); +update({bs_init,_,_,_,_,Dst}, Ts) -> + tdb_update([{Dst,kill}], Ts); +update({bs_put,_,_,_}, Ts) -> + Ts; +update({bs_save2,_,_}, Ts) -> + Ts; +update({bs_restore2,_,_}, Ts) -> + Ts; +update({bs_context_to_binary,Dst}, Ts) -> + tdb_update([{Dst,kill}], Ts); +update({test,bs_start_match2,_,_,_,Dst}, Ts) -> + tdb_update([{Dst,kill}], Ts); +update({test,bs_get_binary2,_,_,_,Dst}, Ts) -> + tdb_update([{Dst,kill}], Ts); +update({test,bs_get_float2,_,_,_,Dst}, Ts) -> + tdb_update([{Dst,float}], Ts); + +update({test,_Test,_Fail,_Other}, Ts) -> + Ts; + +%% Calls + update({call_ext,Ar,{extfunc,math,Math,Ar}}, Ts) -> case is_math_bif(Math, Ar) of true -> tdb_update([{{x,0},float}], Ts); @@ -537,9 +572,10 @@ update({call_ext,3,{extfunc,erlang,setelement,3}}, Ts0) -> update({call,_Arity,_Func}, Ts) -> tdb_kill_xregs(Ts); update({call_ext,_Arity,_Func}, Ts) -> tdb_kill_xregs(Ts); update({make_fun2,_,_,_,_}, Ts) -> tdb_kill_xregs(Ts); +update({call_fun, _}, Ts) -> tdb_kill_xregs(Ts); +update({apply, _}, Ts) -> tdb_kill_xregs(Ts); + update({line,_}, Ts) -> Ts; -update({bs_save2,_,_}, Ts) -> Ts; -update({bs_restore2,_,_}, Ts) -> Ts; %% The instruction is unknown. Kill all information. update(_I, _Ts) -> tdb_new(). diff --git a/lib/compiler/src/beam_validator.erl b/lib/compiler/src/beam_validator.erl index c26e5719aa..ca60e1b2de 100644 --- a/lib/compiler/src/beam_validator.erl +++ b/lib/compiler/src/beam_validator.erl @@ -623,17 +623,17 @@ valfun_4({test,bs_skip_utf16,{f,Fail},[Ctx,Live,_]}, Vst) -> valfun_4({test,bs_skip_utf32,{f,Fail},[Ctx,Live,_]}, Vst) -> validate_bs_skip_utf(Fail, Ctx, Live, Vst); valfun_4({test,bs_get_integer2,{f,Fail},Live,[Ctx,_,_,_],Dst}, Vst) -> - validate_bs_get(Fail, Ctx, Live, Dst, Vst); + validate_bs_get(Fail, Ctx, Live, {integer, []}, Dst, Vst); valfun_4({test,bs_get_float2,{f,Fail},Live,[Ctx,_,_,_],Dst}, Vst) -> - validate_bs_get(Fail, Ctx, Live, Dst, Vst); + validate_bs_get(Fail, Ctx, Live, {float, []}, Dst, Vst); valfun_4({test,bs_get_binary2,{f,Fail},Live,[Ctx,_,_,_],Dst}, Vst) -> - validate_bs_get(Fail, Ctx, Live, Dst, Vst); + validate_bs_get(Fail, Ctx, Live, term, Dst, Vst); valfun_4({test,bs_get_utf8,{f,Fail},Live,[Ctx,_],Dst}, Vst) -> - validate_bs_get(Fail, Ctx, Live, Dst, Vst); + validate_bs_get(Fail, Ctx, Live, {integer, []}, Dst, Vst); valfun_4({test,bs_get_utf16,{f,Fail},Live,[Ctx,_],Dst}, Vst) -> - validate_bs_get(Fail, Ctx, Live, Dst, Vst); + validate_bs_get(Fail, Ctx, Live, {integer, []}, Dst, Vst); valfun_4({test,bs_get_utf32,{f,Fail},Live,[Ctx,_],Dst}, Vst) -> - validate_bs_get(Fail, Ctx, Live, Dst, Vst); + validate_bs_get(Fail, Ctx, Live, {integer, []}, Dst, Vst); valfun_4({bs_save2,Ctx,SavePoint}, Vst) -> bsm_save(Ctx, SavePoint, Vst); valfun_4({bs_restore2,Ctx,SavePoint}, Vst) -> @@ -794,12 +794,12 @@ verify_put_map(Fail, Src, Dst, Live, List, Vst0) -> %% %% Common code for validating bs_get* instructions. %% -validate_bs_get(Fail, Ctx, Live, Dst, Vst0) -> +validate_bs_get(Fail, Ctx, Live, Type, Dst, Vst0) -> bsm_validate_context(Ctx, Vst0), verify_live(Live, Vst0), Vst1 = prune_x_regs(Live, Vst0), Vst = branch_state(Fail, Vst1), - set_type_reg(term, Dst, Vst). + set_type_reg(Type, Dst, Vst). %% %% Common code for validating bs_skip_utf* instructions. diff --git a/lib/compiler/src/compile.erl b/lib/compiler/src/compile.erl index 03b52932d1..019d8ba864 100644 --- a/lib/compiler/src/compile.erl +++ b/lib/compiler/src/compile.erl @@ -213,14 +213,6 @@ expand_opt(report, Os) -> [report_errors,report_warnings|Os]; expand_opt(return, Os) -> [return_errors,return_warnings|Os]; -expand_opt(r12, Os) -> - [no_recv_opt,no_line_info,no_utf8_atoms|Os]; -expand_opt(r13, Os) -> - [no_record_opt,no_recv_opt,no_line_info,no_utf8_atoms|Os]; -expand_opt(r14, Os) -> - [no_record_opt,no_line_info,no_utf8_atoms|Os]; -expand_opt(r15, Os) -> - [no_record_opt,no_utf8_atoms|Os]; expand_opt(r16, Os) -> [no_record_opt,no_utf8_atoms|Os]; expand_opt(r17, Os) -> diff --git a/lib/compiler/test/beam_type_SUITE.erl b/lib/compiler/test/beam_type_SUITE.erl index 7ca544a537..c11883d5ff 100644 --- a/lib/compiler/test/beam_type_SUITE.erl +++ b/lib/compiler/test/beam_type_SUITE.erl @@ -22,7 +22,7 @@ -export([all/0,suite/0,groups/0,init_per_suite/1,end_per_suite/1, init_per_group/2,end_per_group/2, integers/1,coverage/1,booleans/1,setelement/1,cons/1, - tuple/1,record_float/1]). + tuple/1,record_float/1,binary_float/1]). suite() -> [{ct_hooks,[ts_install_cth]}]. @@ -38,7 +38,8 @@ groups() -> setelement, cons, tuple, - record_float + record_float, + binary_float ]}]. init_per_suite(Config) -> @@ -143,6 +144,12 @@ record_float(R, N0) -> N end. +binary_float(_Config) -> + <<-1/float>> = binary_negate_float(<<1/float>>), + ok. + +binary_negate_float(<<Float/float>>) -> + <<-Float/float>>. id(I) -> I. diff --git a/lib/crypto/src/crypto.erl b/lib/crypto/src/crypto.erl index 1287ec6176..765998b85d 100644 --- a/lib/crypto/src/crypto.erl +++ b/lib/crypto/src/crypto.erl @@ -35,7 +35,6 @@ -export([rand_plugin_next/1]). -export([rand_plugin_uniform/1]). -export([rand_plugin_uniform/2]). --export([rand_plugin_jump/1]). -export([rand_uniform/2]). -export([block_encrypt/3, block_decrypt/3, block_encrypt/4, block_decrypt/4]). -export([next_iv/2, next_iv/3]). @@ -316,11 +315,10 @@ rand_seed() -> rand_seed_s() -> {#{ type => ?MODULE, - max => infinity, + bits => 64, next => fun ?MODULE:rand_plugin_next/1, uniform => fun ?MODULE:rand_plugin_uniform/1, - uniform_n => fun ?MODULE:rand_plugin_uniform/2, - jump => fun ?MODULE:rand_plugin_jump/1}, + uniform_n => fun ?MODULE:rand_plugin_uniform/2}, no_seed}. rand_plugin_next(Seed) -> @@ -332,8 +330,6 @@ rand_plugin_uniform(State) -> rand_plugin_uniform(Max, State) -> {bytes_to_integer(strong_rand_range(Max)) + 1, State}. -rand_plugin_jump(State) -> - State. strong_rand_range(Range) when is_integer(Range), Range > 0 -> BinRange = int_to_bin(Range), diff --git a/lib/dialyzer/doc/src/Makefile b/lib/dialyzer/doc/src/Makefile index 77d0a6fc68..8fe6cd30eb 100644 --- a/lib/dialyzer/doc/src/Makefile +++ b/lib/dialyzer/doc/src/Makefile @@ -34,7 +34,7 @@ RELSYSDIR = $(RELEASE_PATH)/lib/$(APPLICATION)-$(VSN) # Target Specs # ---------------------------------------------------- XML_APPLICATION_FILES = ref_man.xml -XML_REF3_FILES = dialyzer.xml +XML_REF3_FILES = dialyzer.xml typer.xml XML_PART_FILES = part.xml part_notes.xml XML_CHAPTER_FILES = dialyzer_chapter.xml notes.xml diff --git a/lib/dialyzer/doc/src/ref_man.xml b/lib/dialyzer/doc/src/ref_man.xml index ddac047f2e..d820fc5e00 100644 --- a/lib/dialyzer/doc/src/ref_man.xml +++ b/lib/dialyzer/doc/src/ref_man.xml @@ -31,5 +31,6 @@ <description> </description> <xi:include href="dialyzer.xml"/> + <xi:include href="typer.xml"/> </application> diff --git a/lib/dialyzer/doc/src/typer.xml b/lib/dialyzer/doc/src/typer.xml new file mode 100644 index 0000000000..abd7f07ccf --- /dev/null +++ b/lib/dialyzer/doc/src/typer.xml @@ -0,0 +1,157 @@ +<?xml version="1.0" encoding="utf-8" ?> +<!DOCTYPE erlref SYSTEM "erlref.dtd"> + +<erlref> + <header> + <copyright> + <year>2006</year><year>2016</year> + <holder>Ericsson AB. All Rights Reserved.</holder> + </copyright> + <legalnotice> + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. + + </legalnotice> + + <title>typer</title> + <prepared></prepared> + <docno></docno> + <date>2017-04-13</date> + <rev></rev> + <file>type.xml</file> + </header> + <module>typer</module> + <modulesummary>Typer, a Type annotator for ERlang programs. + </modulesummary> + <description> + <p>TypEr shows type information for Erlang modules to the user. + Additionally, it can annotate the code of files with such type + information.</p> + </description> + + <section> + <marker id="command_line"></marker> + <title>Using TypEr from the Command Line</title> + <p>TypEr is used from the command-line. This section provides a + brief description of the options. The same information can be + obtained by writing the following in a shell:</p> + + <code type="none"> +typer --help</code> + + <p><em>Usage:</em></p> + + <code type="none"> +typer [--help] [--version] [--plt PLT] [--edoc] + [--show | --show-exported | --annotate | --annotate-inc-files] + [-Ddefine]* [-I include_dir]* [-pa dir]* [-pz dir]* + [-T application]* [-r] file*</code> + + <note> + <p>* denotes that multiple occurrences of the option are possible.</p> + </note> + + <p><em>Options:</em></p> + + <taglist> + + <tag><c>--r</c></tag> + <item> + <p>Search directories recursively for .erl files below them.</p> + </item> + <tag><c>--show</c></tag> + <item> + <p>Print type specifications for all functions on stdout. + (This is the default behaviour; this option is not really + needed.)</p> + </item> + + <tag><c>--show-exported</c> (or <c>show_exported</c>)</tag> + <item> + <p>Same as <c>--show</c>, but print specifications for + exported functions only. Specs are displayed sorted + alphabetically on the function's name.</p> + </item> + + <tag><c>--annotate</c></tag> + <item> + <p>Annotate the specified files with type specifications.</p> + </item> + + <tag><c>--annotate-inc-files</c></tag> + <item> + <p>Same as <c>--annotate</c> but annotates all + <c>-include()</c> files as well as all .erl files. (Use this + option with caution - it has not been tested much).</p> + </item> + + <tag><c>--edoc</c></tag> + <item> + <p>Print type information as Edoc <c>@spec</c> comments, not + as type specs.</p> + </item> + + <tag><c>--plt</c></tag> + <item> + <p>Use the specified dialyzer PLT file rather than the default one.</p> + </item> + + <tag><c>-T file*</c></tag> + <item> + <p>The specified file(s) already contain type specifications + and these are to be trusted in order to print specs for the + rest of the files. (Multiple files or dirs, separated by + spaces, can be specified.)</p> + </item> + + <tag><c>-Dname</c> (or <c>-Dname=value</c>)</tag> + <item> + <p>Pass the defined name(s) to TypEr. (**)</p> + </item> + + <tag><c>-I</c></tag> + <item> + <p>Pass the include_dir to TypEr. (**)</p> + </item> + + <tag><c>-pa dir</c></tag> + <item> + <p>Include <c>dir</c> in the path for Erlang. This is useful + when analyzing files that have <c>-include_lib()</c> + directives or use parse transforms.</p> + </item> + + <tag><c>-pz dir</c></tag> + <item> + <p>Include <c>dir</c> in the path for Erlang. This is useful + when analyzing files that have <c>-include_lib()</c> + directives or use parse transforms.</p> + </item> + + <tag><c>--version</c> (or <c>-v</c>)</tag> + <item> + <p>Print the TypEr version and some more information and + exit.</p> + </item> + + </taglist> + + <note> + <p>** options <c>-D</c> and <c>-I</c> work both + from the command line and in the TypEr GUI; the syntax of + defines and includes is the same as that used by + <seealso marker="erts:erlc">erlc(1)</seealso>.</p> + </note> + + </section> + +</erlref> diff --git a/lib/dialyzer/src/Makefile b/lib/dialyzer/src/Makefile index 256f20f549..28f74ed441 100644 --- a/lib/dialyzer/src/Makefile +++ b/lib/dialyzer/src/Makefile @@ -68,7 +68,8 @@ MODULES = \ dialyzer_typesig \ dialyzer_coordinator \ dialyzer_worker \ - dialyzer_utils + dialyzer_utils \ + typer HRL_FILES= dialyzer.hrl dialyzer_gui_wx.hrl ERL_FILES= $(MODULES:%=%.erl) @@ -117,6 +118,9 @@ $(EBIN)/dialyzer_plt.$(EMULATOR): dialyzer_plt.erl ../vsn.mk $(EBIN)/dialyzer_gui_wx.$(EMULATOR): dialyzer_gui_wx.erl ../vsn.mk $(erlc_verbose)erlc -W $(ERL_COMPILE_FLAGS) -DVSN="\"v$(VSN)\"" -o$(EBIN) dialyzer_gui_wx.erl +$(EBIN)/typer.$(EMULATOR): typer.erl ../vsn.mk + $(erlc_verbose)erlc -W $(ERL_COMPILE_FLAGS) -DVSN="\"v$(VSN)\"" -o$(EBIN) typer.erl + $(APP_TARGET): $(APP_SRC) ../vsn.mk $(vsn_verbose)sed -e 's;%VSN%;$(VSN);' $< > $@ diff --git a/lib/dialyzer/src/dialyzer.app.src b/lib/dialyzer/src/dialyzer.app.src index f517c51ec1..5f803875b0 100644 --- a/lib/dialyzer/src/dialyzer.app.src +++ b/lib/dialyzer/src/dialyzer.app.src @@ -43,7 +43,8 @@ dialyzer_typesig, dialyzer_utils, dialyzer_timing, - dialyzer_worker]}, + dialyzer_worker, + typer]}, {registered, []}, {applications, [compiler, hipe, kernel, stdlib, wx]}, {env, []}, diff --git a/lib/dialyzer/src/typer.erl b/lib/dialyzer/src/typer.erl new file mode 100644 index 0000000000..18c4fe902d --- /dev/null +++ b/lib/dialyzer/src/typer.erl @@ -0,0 +1,1110 @@ +%% -*- erlang-indent-level: 2 -*- +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. + +%%----------------------------------------------------------------------- +%% File : typer.erl +%% Author(s) : The first version of typer was written by Bingwen He +%% with guidance from Kostis Sagonas and Tobias Lindahl. +%% Since June 2008 typer is maintained by Kostis Sagonas. +%% Description : An Erlang/OTP application that shows type information +%% for Erlang modules to the user. Additionally, it can +%% annotate the code of files with such type information. +%%----------------------------------------------------------------------- + +-module(typer). + +-export([start/0]). + +%%----------------------------------------------------------------------- + +-define(SHOW, show). +-define(SHOW_EXPORTED, show_exported). +-define(ANNOTATE, annotate). +-define(ANNOTATE_INC_FILES, annotate_inc_files). + +-type mode() :: ?SHOW | ?SHOW_EXPORTED | ?ANNOTATE | ?ANNOTATE_INC_FILES. + +%%----------------------------------------------------------------------- + +-type files() :: [file:filename()]. +-type callgraph() :: dialyzer_callgraph:callgraph(). +-type codeserver() :: dialyzer_codeserver:codeserver(). +-type plt() :: dialyzer_plt:plt(). + +-record(analysis, + {mode :: mode() | 'undefined', + macros = [] :: [{atom(), term()}], + includes = [] :: files(), + codeserver = dialyzer_codeserver:new():: codeserver(), + callgraph = dialyzer_callgraph:new() :: callgraph(), + files = [] :: files(), % absolute names + plt = none :: 'none' | file:filename(), + no_spec = false :: boolean(), + show_succ = false :: boolean(), + %% For choosing between specs or edoc @spec comments + edoc = false :: boolean(), + %% Files in 'fms' are compilable with option 'to_pp'; we keep them + %% as {FileName, ModuleName} in case the ModuleName is different + fms = [] :: [{file:filename(), module()}], + ex_func = map__new() :: map_dict(), + record = map__new() :: map_dict(), + func = map__new() :: map_dict(), + inc_func = map__new() :: map_dict(), + trust_plt = dialyzer_plt:new() :: plt()}). +-type analysis() :: #analysis{}. + +-record(args, {files = [] :: files(), + files_r = [] :: files(), + trusted = [] :: files()}). +-type args() :: #args{}. + +%%-------------------------------------------------------------------- + +-spec start() -> no_return(). + +start() -> + {Args, Analysis} = process_cl_args(), + %% io:format("Args: ~p\n", [Args]), + %% io:format("Analysis: ~p\n", [Analysis]), + Timer = dialyzer_timing:init(false), + TrustedFiles = filter_fd(Args#args.trusted, [], fun is_erl_file/1), + Analysis2 = extract(Analysis, TrustedFiles), + All_Files = get_all_files(Args), + %% io:format("All_Files: ~p\n", [All_Files]), + Analysis3 = Analysis2#analysis{files = All_Files}, + Analysis4 = collect_info(Analysis3), + %% io:format("Final: ~p\n", [Analysis4#analysis.fms]), + TypeInfo = get_type_info(Analysis4), + dialyzer_timing:stop(Timer), + show_or_annotate(TypeInfo), + %% io:format("\nTyper analysis finished\n"), + erlang:halt(0). + +%%-------------------------------------------------------------------- + +-spec extract(analysis(), files()) -> analysis(). + +extract(#analysis{macros = Macros, + includes = Includes, + trust_plt = TrustPLT} = Analysis, TrustedFiles) -> + %% io:format("--- Extracting trusted typer_info... "), + Ds = [{d, Name, Value} || {Name, Value} <- Macros], + CodeServer = dialyzer_codeserver:new(), + Fun = + fun(File, CS) -> + %% We include one more dir; the one above the one we are trusting + %% E.g, for /home/tests/typer_ann/test.ann.erl, we should include + %% /home/tests/ rather than /home/tests/typer_ann/ + AllIncludes = [filename:dirname(filename:dirname(File)) | Includes], + Is = [{i, Dir} || Dir <- AllIncludes], + CompOpts = dialyzer_utils:src_compiler_opts() ++ Is ++ Ds, + case dialyzer_utils:get_abstract_code_from_src(File, CompOpts) of + {ok, AbstractCode} -> + case dialyzer_utils:get_record_and_type_info(AbstractCode) of + {ok, RecDict} -> + Mod = list_to_atom(filename:basename(File, ".erl")), + case dialyzer_utils:get_spec_info(Mod, AbstractCode, RecDict) of + {ok, SpecDict, CbDict} -> + CS1 = dialyzer_codeserver:store_temp_records(Mod, RecDict, CS), + dialyzer_codeserver:store_temp_contracts(Mod, SpecDict, CbDict, CS1); + {error, Reason} -> compile_error([Reason]) + end; + {error, Reason} -> compile_error([Reason]) + end; + {error, Reason} -> compile_error(Reason) + end + end, + CodeServer1 = lists:foldl(Fun, CodeServer, TrustedFiles), + %% Process remote types + NewCodeServer = + try + CodeServer2 = + dialyzer_utils:merge_types(CodeServer1, + TrustPLT), % XXX change to the PLT? + NewExpTypes = dialyzer_codeserver:get_temp_exported_types(CodeServer1), + case sets:size(NewExpTypes) of 0 -> ok end, + CodeServer3 = dialyzer_codeserver:finalize_exported_types(NewExpTypes, CodeServer2), + CodeServer4 = dialyzer_utils:process_record_remote_types(CodeServer3), + dialyzer_contracts:process_contract_remote_types(CodeServer4) + catch + throw:{error, ErrorMsg} -> + compile_error(ErrorMsg) + end, + %% Create TrustPLT + ContractsDict = dialyzer_codeserver:get_contracts(NewCodeServer), + Contracts = orddict:from_list(dict:to_list(ContractsDict)), + NewTrustPLT = dialyzer_plt:insert_contract_list(TrustPLT, Contracts), + Analysis#analysis{trust_plt = NewTrustPLT}. + +%%-------------------------------------------------------------------- + +-spec get_type_info(analysis()) -> analysis(). + +get_type_info(#analysis{callgraph = CallGraph, + trust_plt = TrustPLT, + codeserver = CodeServer} = Analysis) -> + StrippedCallGraph = remove_external(CallGraph, TrustPLT), + %% io:format("--- Analyzing callgraph... "), + try + NewMiniPlt = dialyzer_succ_typings:analyze_callgraph(StrippedCallGraph, + TrustPLT, + CodeServer), + NewPlt = dialyzer_plt:restore_full_plt(NewMiniPlt), + Analysis#analysis{callgraph = StrippedCallGraph, trust_plt = NewPlt} + catch + error:What -> + fatal_error(io_lib:format("Analysis failed with message: ~p", + [{What, erlang:get_stacktrace()}])); + throw:{dialyzer_succ_typing_error, Msg} -> + fatal_error(io_lib:format("Analysis failed with message: ~s", [Msg])) + end. + +-spec remove_external(callgraph(), plt()) -> callgraph(). + +remove_external(CallGraph, PLT) -> + {StrippedCG0, Ext} = dialyzer_callgraph:remove_external(CallGraph), + case get_external(Ext, PLT) of + [] -> ok; + Externals -> + msg(io_lib:format(" Unknown functions: ~p\n", [lists:usort(Externals)])), + ExtTypes = rcv_ext_types(), + case ExtTypes of + [] -> ok; + _ -> msg(io_lib:format(" Unknown types: ~p\n", [ExtTypes])) + end + end, + StrippedCG0. + +-spec get_external([{mfa(), mfa()}], plt()) -> [mfa()]. + +get_external(Exts, Plt) -> + Fun = fun ({_From, To = {M, F, A}}, Acc) -> + case dialyzer_plt:contains_mfa(Plt, To) of + false -> + case erl_bif_types:is_known(M, F, A) of + true -> Acc; + false -> [To|Acc] + end; + true -> Acc + end + end, + lists:foldl(Fun, [], Exts). + +%%-------------------------------------------------------------------- +%% Showing type information or annotating files with such information. +%%-------------------------------------------------------------------- + +-define(TYPER_ANN_DIR, "typer_ann"). + +-type line() :: non_neg_integer(). +-type fa() :: {atom(), arity()}. +-type func_info() :: {line(), atom(), arity()}. + +-record(info, {records = maps:new() :: erl_types:type_table(), + functions = [] :: [func_info()], + types = map__new() :: map_dict(), + edoc = false :: boolean()}). +-record(inc, {map = map__new() :: map_dict(), filter = [] :: files()}). +-type inc() :: #inc{}. + +-spec show_or_annotate(analysis()) -> 'ok'. + +show_or_annotate(#analysis{mode = Mode, fms = Files} = Analysis) -> + case Mode of + ?SHOW -> show(Analysis); + ?SHOW_EXPORTED -> show(Analysis); + ?ANNOTATE -> + Fun = fun ({File, Module}) -> + Info = get_final_info(File, Module, Analysis), + write_typed_file(File, Info) + end, + lists:foreach(Fun, Files); + ?ANNOTATE_INC_FILES -> + IncInfo = write_and_collect_inc_info(Analysis), + write_inc_files(IncInfo) + end. + +write_and_collect_inc_info(Analysis) -> + Fun = fun ({File, Module}, Inc) -> + Info = get_final_info(File, Module, Analysis), + write_typed_file(File, Info), + IncFuns = get_functions(File, Analysis), + collect_imported_functions(IncFuns, Info#info.types, Inc) + end, + NewInc = lists:foldl(Fun, #inc{}, Analysis#analysis.fms), + clean_inc(NewInc). + +write_inc_files(Inc) -> + Fun = + fun (File) -> + Val = map__lookup(File, Inc#inc.map), + %% Val is function with its type info + %% in form [{{Line,F,A},Type}] + Functions = [Key || {Key, _} <- Val], + Val1 = [{{F,A},Type} || {{_Line,F,A},Type} <- Val], + Info = #info{types = map__from_list(Val1), + records = maps:new(), + %% Note we need to sort functions here! + functions = lists:keysort(1, Functions)}, + %% io:format("Types ~p\n", [Info#info.types]), + %% io:format("Functions ~p\n", [Info#info.functions]), + %% io:format("Records ~p\n", [Info#info.records]), + write_typed_file(File, Info) + end, + lists:foreach(Fun, dict:fetch_keys(Inc#inc.map)). + +show(Analysis) -> + Fun = fun ({File, Module}) -> + Info = get_final_info(File, Module, Analysis), + show_type_info(File, Info) + end, + lists:foreach(Fun, Analysis#analysis.fms). + +get_final_info(File, Module, Analysis) -> + Records = get_records(File, Analysis), + Types = get_types(Module, Analysis, Records), + Functions = get_functions(File, Analysis), + Edoc = Analysis#analysis.edoc, + #info{records = Records, functions = Functions, types = Types, edoc = Edoc}. + +collect_imported_functions(Functions, Types, Inc) -> + %% Coming from other sourses, including: + %% FIXME: How to deal with yecc-generated file???? + %% --.yrl (yecc-generated file)??? + %% -- yeccpre.hrl (yecc-generated file)??? + %% -- other cases + Fun = fun ({File, _} = Obj, I) -> + case is_yecc_gen(File, I) of + {true, NewI} -> NewI; + {false, NewI} -> + check_imported_functions(Obj, NewI, Types) + end + end, + lists:foldl(Fun, Inc, Functions). + +-spec is_yecc_gen(file:filename(), inc()) -> {boolean(), inc()}. + +is_yecc_gen(File, #inc{filter = Fs} = Inc) -> + case lists:member(File, Fs) of + true -> {true, Inc}; + false -> + case filename:extension(File) of + ".yrl" -> + Rootname = filename:rootname(File, ".yrl"), + Obj = Rootname ++ ".erl", + case lists:member(Obj, Fs) of + true -> {true, Inc}; + false -> + NewInc = Inc#inc{filter = [Obj|Fs]}, + {true, NewInc} + end; + _ -> + case filename:basename(File) of + "yeccpre.hrl" -> {true, Inc}; + _ -> {false, Inc} + end + end + end. + +check_imported_functions({File, {Line, F, A}}, Inc, Types) -> + IncMap = Inc#inc.map, + FA = {F, A}, + Type = get_type_info(FA, Types), + case map__lookup(File, IncMap) of + none -> %% File is not added. Add it + Obj = {File,[{FA, {Line, Type}}]}, + NewMap = map__insert(Obj, IncMap), + Inc#inc{map = NewMap}; + Val -> %% File is already in. Check. + case lists:keyfind(FA, 1, Val) of + false -> + %% Function is not in; add it + Obj = {File, Val ++ [{FA, {Line, Type}}]}, + NewMap = map__insert(Obj, IncMap), + Inc#inc{map = NewMap}; + Type -> + %% Function is in and with same type + Inc; + _ -> + %% Function is in but with diff type + inc_warning(FA, File), + Elem = lists:keydelete(FA, 1, Val), + NewMap = case Elem of + [] -> map__remove(File, IncMap); + _ -> map__insert({File, Elem}, IncMap) + end, + Inc#inc{map = NewMap} + end + end. + +inc_warning({F, A}, File) -> + io:format(" ***Warning: Skip function ~p/~p ", [F, A]), + io:format("in file ~p because of inconsistent type\n", [File]). + +clean_inc(Inc) -> + Inc1 = remove_yecc_generated_file(Inc), + normalize_obj(Inc1). + +remove_yecc_generated_file(#inc{filter = Filter} = Inc) -> + Fun = fun (Key, #inc{map = Map} = I) -> + I#inc{map = map__remove(Key, Map)} + end, + lists:foldl(Fun, Inc, Filter). + +normalize_obj(TmpInc) -> + Fun = fun (Key, Val, Inc) -> + NewVal = [{{Line,F,A},Type} || {{F,A},{Line,Type}} <- Val], + map__insert({Key, NewVal}, Inc) + end, + TmpInc#inc{map = map__fold(Fun, map__new(), TmpInc#inc.map)}. + +get_records(File, Analysis) -> + map__lookup(File, Analysis#analysis.record). + +get_types(Module, Analysis, Records) -> + TypeInfoPlt = Analysis#analysis.trust_plt, + TypeInfo = + case dialyzer_plt:lookup_module(TypeInfoPlt, Module) of + none -> []; + {value, List} -> List + end, + CodeServer = Analysis#analysis.codeserver, + TypeInfoList = + case Analysis#analysis.show_succ of + true -> + [convert_type_info(I) || I <- TypeInfo]; + false -> + [get_type(I, CodeServer, Records) || I <- TypeInfo] + end, + map__from_list(TypeInfoList). + +convert_type_info({{_M, F, A}, Range, Arg}) -> + {{F, A}, {Range, Arg}}. + +get_type({{M, F, A} = MFA, Range, Arg}, CodeServer, Records) -> + case dialyzer_codeserver:lookup_mfa_contract(MFA, CodeServer) of + error -> + {{F, A}, {Range, Arg}}; + {ok, {_FileLine, Contract, _Xtra}} -> + Sig = erl_types:t_fun(Arg, Range), + case dialyzer_contracts:check_contract(Contract, Sig) of + ok -> {{F, A}, {contract, Contract}}; + {error, {extra_range, _, _}} -> + {{F, A}, {contract, Contract}}; + {error, {overlapping_contract, []}} -> + {{F, A}, {contract, Contract}}; + {error, invalid_contract} -> + CString = dialyzer_contracts:contract_to_string(Contract), + SigString = dialyzer_utils:format_sig(Sig, Records), + Msg = io_lib:format("Error in contract of function ~w:~w/~w\n" + "\t The contract is: " ++ CString ++ "\n" ++ + "\t but the inferred signature is: ~s", + [M, F, A, SigString]), + fatal_error(Msg); + {error, ErrorStr} when is_list(ErrorStr) -> % ErrorStr is a string() + Msg = io_lib:format("Error in contract of function ~w:~w/~w: ~s", + [M, F, A, ErrorStr]), + fatal_error(Msg) + end + end. + +get_functions(File, Analysis) -> + case Analysis#analysis.mode of + ?SHOW -> + Funcs = map__lookup(File, Analysis#analysis.func), + Inc_Funcs = map__lookup(File, Analysis#analysis.inc_func), + remove_module_info(Funcs) ++ normalize_incFuncs(Inc_Funcs); + ?SHOW_EXPORTED -> + Ex_Funcs = map__lookup(File, Analysis#analysis.ex_func), + remove_module_info(Ex_Funcs); + ?ANNOTATE -> + Funcs = map__lookup(File, Analysis#analysis.func), + remove_module_info(Funcs); + ?ANNOTATE_INC_FILES -> + map__lookup(File, Analysis#analysis.inc_func) + end. + +normalize_incFuncs(Functions) -> + [FunInfo || {_FileName, FunInfo} <- Functions]. + +-spec remove_module_info([func_info()]) -> [func_info()]. + +remove_module_info(FunInfoList) -> + F = fun ({_,module_info,0}) -> false; + ({_,module_info,1}) -> false; + ({Line,F,A}) when is_integer(Line), is_atom(F), is_integer(A) -> true + end, + lists:filter(F, FunInfoList). + +write_typed_file(File, Info) -> + io:format(" Processing file: ~p\n", [File]), + Dir = filename:dirname(File), + RootName = filename:basename(filename:rootname(File)), + Ext = filename:extension(File), + TyperAnnDir = filename:join(Dir, ?TYPER_ANN_DIR), + TmpNewFilename = lists:concat([RootName, ".ann", Ext]), + NewFileName = filename:join(TyperAnnDir, TmpNewFilename), + case file:make_dir(TyperAnnDir) of + {error, Reason} -> + case Reason of + eexist -> %% TypEr dir exists; remove old typer files if they exist + case file:delete(NewFileName) of + ok -> ok; + {error, enoent} -> ok; + {error, _} -> + Msg = io_lib:format("Error in deleting file ~s\n", [NewFileName]), + fatal_error(Msg) + end, + write_typed_file(File, Info, NewFileName); + enospc -> + Msg = io_lib:format("Not enough space in ~p\n", [Dir]), + fatal_error(Msg); + eacces -> + Msg = io_lib:format("No write permission in ~p\n", [Dir]), + fatal_error(Msg); + _ -> + Msg = io_lib:format("Unhandled error ~s when writing ~p\n", + [Reason, Dir]), + fatal_error(Msg) + end; + ok -> %% Typer dir does NOT exist + write_typed_file(File, Info, NewFileName) + end. + +write_typed_file(File, Info, NewFileName) -> + {ok, Binary} = file:read_file(File), + Chars = binary_to_list(Binary), + write_typed_file(Chars, NewFileName, Info, 1, []), + io:format(" Saved as: ~p\n", [NewFileName]). + +write_typed_file(Chars, File, #info{functions = []}, _LNo, _Acc) -> + ok = file:write_file(File, list_to_binary(Chars), [append]); +write_typed_file([Ch|Chs] = Chars, File, Info, LineNo, Acc) -> + [{Line,F,A}|RestFuncs] = Info#info.functions, + case Line of + 1 -> %% This will happen only for inc files + ok = raw_write(F, A, Info, File, []), + NewInfo = Info#info{functions = RestFuncs}, + NewAcc = [], + write_typed_file(Chars, File, NewInfo, Line, NewAcc); + _ -> + case Ch of + 10 -> + NewLineNo = LineNo + 1, + {NewInfo, NewAcc} = + case NewLineNo of + Line -> + ok = raw_write(F, A, Info, File, [Ch|Acc]), + {Info#info{functions = RestFuncs}, []}; + _ -> + {Info, [Ch|Acc]} + end, + write_typed_file(Chs, File, NewInfo, NewLineNo, NewAcc); + _ -> + write_typed_file(Chs, File, Info, LineNo, [Ch|Acc]) + end + end. + +raw_write(F, A, Info, File, Content) -> + TypeInfo = get_type_string(F, A, Info, file), + ContentList = lists:reverse(Content) ++ TypeInfo ++ "\n", + ContentBin = list_to_binary(ContentList), + file:write_file(File, ContentBin, [append]). + +get_type_string(F, A, Info, Mode) -> + Type = get_type_info({F,A}, Info#info.types), + TypeStr = + case Type of + {contract, C} -> + dialyzer_contracts:contract_to_string(C); + {RetType, ArgType} -> + Sig = erl_types:t_fun(ArgType, RetType), + dialyzer_utils:format_sig(Sig, Info#info.records) + end, + case Info#info.edoc of + false -> + case {Mode, Type} of + {file, {contract, _}} -> ""; + _ -> + Prefix = lists:concat(["-spec ", erl_types:atom_to_string(F)]), + lists:concat([Prefix, TypeStr, "."]) + end; + true -> + Prefix = lists:concat(["%% @spec ", F]), + lists:concat([Prefix, TypeStr, "."]) + end. + +show_type_info(File, Info) -> + io:format("\n%% File: ~p\n%% ", [File]), + OutputString = lists:concat(["~.", length(File)+8, "c~n"]), + io:fwrite(OutputString, [$-]), + Fun = fun ({_LineNo, F, A}) -> + TypeInfo = get_type_string(F, A, Info, show), + io:format("~s\n", [TypeInfo]) + end, + lists:foreach(Fun, Info#info.functions). + +get_type_info(Func, Types) -> + case map__lookup(Func, Types) of + none -> + %% Note: Typeinfo of any function should exist in + %% the result offered by dialyzer, otherwise there + %% *must* be something wrong with the analysis + Msg = io_lib:format("No type info for function: ~p\n", [Func]), + fatal_error(Msg); + {contract, _Fun} = C -> C; + {_RetType, _ArgType} = RA -> RA + end. + +%%-------------------------------------------------------------------- +%% Processing of command-line options and arguments. +%%-------------------------------------------------------------------- + +-spec process_cl_args() -> {args(), analysis()}. + +process_cl_args() -> + ArgList = init:get_plain_arguments(), + %% io:format("Args is ~p\n", [ArgList]), + {Args, Analysis} = analyze_args(ArgList, #args{}, #analysis{}), + %% if the mode has not been set, set it to the default mode (show) + {Args, case Analysis#analysis.mode of + undefined -> Analysis#analysis{mode = ?SHOW}; + Mode when is_atom(Mode) -> Analysis + end}. + +analyze_args([], Args, Analysis) -> + {Args, Analysis}; +analyze_args(ArgList, Args, Analysis) -> + {Result, Rest} = cl(ArgList), + {NewArgs, NewAnalysis} = analyze_result(Result, Args, Analysis), + analyze_args(Rest, NewArgs, NewAnalysis). + +cl(["-h"|_]) -> help_message(); +cl(["--help"|_]) -> help_message(); +cl(["-v"|_]) -> version_message(); +cl(["--version"|_]) -> version_message(); +cl(["--edoc"|Opts]) -> {edoc, Opts}; +cl(["--show"|Opts]) -> {{mode, ?SHOW}, Opts}; +cl(["--show_exported"|Opts]) -> {{mode, ?SHOW_EXPORTED}, Opts}; +cl(["--show-exported"|Opts]) -> {{mode, ?SHOW_EXPORTED}, Opts}; +cl(["--show_success_typings"|Opts]) -> {show_succ, Opts}; +cl(["--show-success-typings"|Opts]) -> {show_succ, Opts}; +cl(["--annotate"|Opts]) -> {{mode, ?ANNOTATE}, Opts}; +cl(["--annotate-inc-files"|Opts]) -> {{mode, ?ANNOTATE_INC_FILES}, Opts}; +cl(["--no_spec"|Opts]) -> {no_spec, Opts}; +cl(["--plt",Plt|Opts]) -> {{plt, Plt}, Opts}; +cl(["-D"++Def|Opts]) -> + case Def of + "" -> fatal_error("no variable name specified after -D"); + _ -> + DefPair = process_def_list(re:split(Def, "=", [{return, list}])), + {{def, DefPair}, Opts} + end; +cl(["-I",Dir|Opts]) -> {{inc, Dir}, Opts}; +cl(["-I"++Dir|Opts]) -> + case Dir of + "" -> fatal_error("no include directory specified after -I"); + _ -> {{inc, Dir}, Opts} + end; +cl(["-T"|Opts]) -> + {Files, RestOpts} = dialyzer_cl_parse:collect_args(Opts), + case Files of + [] -> fatal_error("no file or directory specified after -T"); + [_|_] -> {{trusted, Files}, RestOpts} + end; +cl(["-r"|Opts]) -> + {Files, RestOpts} = dialyzer_cl_parse:collect_args(Opts), + {{files_r, Files}, RestOpts}; +cl(["-pa",Dir|Opts]) -> {{pa,Dir}, Opts}; +cl(["-pz",Dir|Opts]) -> {{pz,Dir}, Opts}; +cl(["-"++H|_]) -> fatal_error("unknown option -"++H); +cl(Opts) -> + {Files, RestOpts} = dialyzer_cl_parse:collect_args(Opts), + {{files, Files}, RestOpts}. + +process_def_list(L) -> + case L of + [Name, Value] -> + {ok, Tokens, _} = erl_scan:string(Value ++ "."), + {ok, ErlValue} = erl_parse:parse_term(Tokens), + {list_to_atom(Name), ErlValue}; + [Name] -> + {list_to_atom(Name), true} + end. + +%% Get information about files that the user trusts and wants to analyze +analyze_result({files, Val}, Args, Analysis) -> + NewVal = Args#args.files ++ Val, + {Args#args{files = NewVal}, Analysis}; +analyze_result({files_r, Val}, Args, Analysis) -> + NewVal = Args#args.files_r ++ Val, + {Args#args{files_r = NewVal}, Analysis}; +analyze_result({trusted, Val}, Args, Analysis) -> + NewVal = Args#args.trusted ++ Val, + {Args#args{trusted = NewVal}, Analysis}; +analyze_result(edoc, Args, Analysis) -> + {Args, Analysis#analysis{edoc = true}}; +%% Get useful information for actual analysis +analyze_result({mode, Mode}, Args, Analysis) -> + case Analysis#analysis.mode of + undefined -> {Args, Analysis#analysis{mode = Mode}}; + OldMode -> mode_error(OldMode, Mode) + end; +analyze_result({def, Val}, Args, Analysis) -> + NewVal = Analysis#analysis.macros ++ [Val], + {Args, Analysis#analysis{macros = NewVal}}; +analyze_result({inc, Val}, Args, Analysis) -> + NewVal = Analysis#analysis.includes ++ [Val], + {Args, Analysis#analysis{includes = NewVal}}; +analyze_result({plt, Plt}, Args, Analysis) -> + {Args, Analysis#analysis{plt = Plt}}; +analyze_result(show_succ, Args, Analysis) -> + {Args, Analysis#analysis{show_succ = true}}; +analyze_result(no_spec, Args, Analysis) -> + {Args, Analysis#analysis{no_spec = true}}; +analyze_result({pa, Dir}, Args, Analysis) -> + true = code:add_patha(Dir), + {Args, Analysis}; +analyze_result({pz, Dir}, Args, Analysis) -> + true = code:add_pathz(Dir), + {Args, Analysis}. + +%%-------------------------------------------------------------------- +%% File processing. +%%-------------------------------------------------------------------- + +-spec get_all_files(args()) -> [file:filename(),...]. + +get_all_files(#args{files = Fs, files_r = Ds}) -> + case filter_fd(Fs, Ds, fun test_erl_file_exclude_ann/1) of + [] -> fatal_error("no file(s) to analyze"); + AllFiles -> AllFiles + end. + +-spec test_erl_file_exclude_ann(file:filename()) -> boolean(). + +test_erl_file_exclude_ann(File) -> + case is_erl_file(File) of + true -> %% Exclude files ending with ".ann.erl" + case re:run(File, "[\.]ann[\.]erl$") of + {match, _} -> false; + nomatch -> true + end; + false -> false + end. + +-spec is_erl_file(file:filename()) -> boolean(). + +is_erl_file(File) -> + filename:extension(File) =:= ".erl". + +-type test_file_fun() :: fun((file:filename()) -> boolean()). + +-spec filter_fd(files(), files(), test_file_fun()) -> files(). + +filter_fd(File_Dir, Dir_R, Fun) -> + All_File_1 = process_file_and_dir(File_Dir, Fun), + All_File_2 = process_dir_rec(Dir_R, Fun), + remove_dup(All_File_1 ++ All_File_2). + +-spec process_file_and_dir(files(), test_file_fun()) -> files(). + +process_file_and_dir(File_Dir, TestFun) -> + Fun = + fun (Elem, Acc) -> + case filelib:is_regular(Elem) of + true -> process_file(Elem, TestFun, Acc); + false -> check_dir(Elem, false, Acc, TestFun) + end + end, + lists:foldl(Fun, [], File_Dir). + +-spec process_dir_rec(files(), test_file_fun()) -> files(). + +process_dir_rec(Dirs, TestFun) -> + Fun = fun (Dir, Acc) -> check_dir(Dir, true, Acc, TestFun) end, + lists:foldl(Fun, [], Dirs). + +-spec check_dir(file:filename(), boolean(), files(), test_file_fun()) -> files(). + +check_dir(Dir, Recursive, Acc, Fun) -> + case file:list_dir(Dir) of + {ok, Files} -> + {TmpDirs, TmpFiles} = split_dirs_and_files(Files, Dir), + case Recursive of + false -> + FinalFiles = process_file_and_dir(TmpFiles, Fun), + Acc ++ FinalFiles; + true -> + TmpAcc1 = process_file_and_dir(TmpFiles, Fun), + TmpAcc2 = process_dir_rec(TmpDirs, Fun), + Acc ++ TmpAcc1 ++ TmpAcc2 + end; + {error, eacces} -> + fatal_error("no access permission to dir \""++Dir++"\""); + {error, enoent} -> + fatal_error("cannot access "++Dir++": No such file or directory"); + {error, _Reason} -> + fatal_error("error involving a use of file:list_dir/1") + end. + +%% Same order as the input list +-spec process_file(file:filename(), test_file_fun(), files()) -> files(). + +process_file(File, TestFun, Acc) -> + case TestFun(File) of + true -> Acc ++ [File]; + false -> Acc + end. + +%% Same order as the input list +-spec split_dirs_and_files(files(), file:filename()) -> {files(), files()}. + +split_dirs_and_files(Elems, Dir) -> + Test_Fun = + fun (Elem, {DirAcc, FileAcc}) -> + File = filename:join(Dir, Elem), + case filelib:is_regular(File) of + false -> {[File|DirAcc], FileAcc}; + true -> {DirAcc, [File|FileAcc]} + end + end, + {Dirs, Files} = lists:foldl(Test_Fun, {[], []}, Elems), + {lists:reverse(Dirs), lists:reverse(Files)}. + +%% Removes duplicate filenames but keeps the order of the input list +-spec remove_dup(files()) -> files(). + +remove_dup(Files) -> + Test_Dup = fun (File, Acc) -> + case lists:member(File, Acc) of + true -> Acc; + false -> [File|Acc] + end + end, + Reversed_Elems = lists:foldl(Test_Dup, [], Files), + lists:reverse(Reversed_Elems). + +%%-------------------------------------------------------------------- +%% Collect information. +%%-------------------------------------------------------------------- + +-type inc_file_info() :: {file:filename(), func_info()}. + +-record(tmpAcc, {file :: file:filename(), + module :: atom(), + funcAcc = [] :: [func_info()], + incFuncAcc = [] :: [inc_file_info()], + dialyzerObj = [] :: [{mfa(), {_, _}}]}). + +-spec collect_info(analysis()) -> analysis(). + +collect_info(Analysis) -> + NewPlt = + try get_dialyzer_plt(Analysis) of + DialyzerPlt -> + dialyzer_plt:merge_plts([Analysis#analysis.trust_plt, DialyzerPlt]) + catch + throw:{dialyzer_error,_Reason} -> + fatal_error("Dialyzer's PLT is missing or is not up-to-date; please (re)create it") + end, + NewAnalysis = lists:foldl(fun collect_one_file_info/2, + Analysis#analysis{trust_plt = NewPlt}, + Analysis#analysis.files), + %% Process Remote Types + TmpCServer = NewAnalysis#analysis.codeserver, + NewCServer = + try + TmpCServer1 = dialyzer_utils:merge_types(TmpCServer, NewPlt), + NewExpTypes = dialyzer_codeserver:get_temp_exported_types(TmpCServer), + OldExpTypes = dialyzer_plt:get_exported_types(NewPlt), + MergedExpTypes = sets:union(NewExpTypes, OldExpTypes), + TmpCServer2 = + dialyzer_codeserver:finalize_exported_types(MergedExpTypes, TmpCServer1), + TmpCServer3 = dialyzer_utils:process_record_remote_types(TmpCServer2), + dialyzer_contracts:process_contract_remote_types(TmpCServer3) + catch + throw:{error, ErrorMsg} -> + fatal_error(ErrorMsg) + end, + NewAnalysis#analysis{codeserver = NewCServer}. + +collect_one_file_info(File, Analysis) -> + Ds = [{d,Name,Val} || {Name,Val} <- Analysis#analysis.macros], + %% Current directory should also be included in "Includes". + Includes = [filename:dirname(File)|Analysis#analysis.includes], + Is = [{i,Dir} || Dir <- Includes], + Options = dialyzer_utils:src_compiler_opts() ++ Is ++ Ds, + case dialyzer_utils:get_abstract_code_from_src(File, Options) of + {error, Reason} -> + %% io:format("File=~p\n,Options=~p\n,Error=~p\n", [File,Options,Reason]), + compile_error(Reason); + {ok, AbstractCode} -> + case dialyzer_utils:get_core_from_abstract_code(AbstractCode, Options) of + error -> compile_error(["Could not get core erlang for "++File]); + {ok, Core} -> + case dialyzer_utils:get_record_and_type_info(AbstractCode) of + {error, Reason} -> compile_error([Reason]); + {ok, Records} -> + Mod = cerl:concrete(cerl:module_name(Core)), + case dialyzer_utils:get_spec_info(Mod, AbstractCode, Records) of + {error, Reason} -> compile_error([Reason]); + {ok, SpecInfo, CbInfo} -> + ExpTypes = get_exported_types_from_core(Core), + analyze_core_tree(Core, Records, SpecInfo, CbInfo, + ExpTypes, Analysis, File) + end + end + end + end. + +analyze_core_tree(Core, Records, SpecInfo, CbInfo, ExpTypes, Analysis, File) -> + Module = cerl:concrete(cerl:module_name(Core)), + TmpTree = cerl:from_records(Core), + CS1 = Analysis#analysis.codeserver, + NextLabel = dialyzer_codeserver:get_next_core_label(CS1), + {Tree, NewLabel} = cerl_trees:label(TmpTree, NextLabel), + CS2 = dialyzer_codeserver:insert(Module, Tree, CS1), + CS3 = dialyzer_codeserver:set_next_core_label(NewLabel, CS2), + CS4 = dialyzer_codeserver:store_temp_records(Module, Records, CS3), + CS5 = + case Analysis#analysis.no_spec of + true -> CS4; + false -> + dialyzer_codeserver:store_temp_contracts(Module, SpecInfo, CbInfo, CS4) + end, + OldExpTypes = dialyzer_codeserver:get_temp_exported_types(CS5), + MergedExpTypes = sets:union(ExpTypes, OldExpTypes), + CS6 = dialyzer_codeserver:insert_temp_exported_types(MergedExpTypes, CS5), + Ex_Funcs = [{0,F,A} || {_,_,{F,A}} <- cerl:module_exports(Tree)], + CG = Analysis#analysis.callgraph, + {V, E} = dialyzer_callgraph:scan_core_tree(Tree, CG), + dialyzer_callgraph:add_edges(E, V, CG), + Fun = fun analyze_one_function/2, + All_Defs = cerl:module_defs(Tree), + Acc = lists:foldl(Fun, #tmpAcc{file = File, module = Module}, All_Defs), + Exported_FuncMap = map__insert({File, Ex_Funcs}, Analysis#analysis.ex_func), + %% we must sort all functions in the file which + %% originate from this file by *numerical order* of lineNo + Sorted_Functions = lists:keysort(1, Acc#tmpAcc.funcAcc), + FuncMap = map__insert({File, Sorted_Functions}, Analysis#analysis.func), + %% we do not need to sort functions which are imported from included files + IncFuncMap = map__insert({File, Acc#tmpAcc.incFuncAcc}, + Analysis#analysis.inc_func), + FMs = Analysis#analysis.fms ++ [{File, Module}], + RecordMap = map__insert({File, Records}, Analysis#analysis.record), + Analysis#analysis{fms = FMs, + callgraph = CG, + codeserver = CS6, + ex_func = Exported_FuncMap, + inc_func = IncFuncMap, + record = RecordMap, + func = FuncMap}. + +analyze_one_function({Var, FunBody} = Function, Acc) -> + F = cerl:fname_id(Var), + A = cerl:fname_arity(Var), + TmpDialyzerObj = {{Acc#tmpAcc.module, F, A}, Function}, + NewDialyzerObj = Acc#tmpAcc.dialyzerObj ++ [TmpDialyzerObj], + Anno = cerl:get_ann(FunBody), + LineNo = get_line(Anno), + FileName = get_file(Anno), + BaseName = filename:basename(FileName), + FuncInfo = {LineNo, F, A}, + OriginalName = Acc#tmpAcc.file, + {FuncAcc, IncFuncAcc} = + case (FileName =:= OriginalName) orelse (BaseName =:= OriginalName) of + true -> %% Coming from original file + %% io:format("Added function ~p\n", [{LineNo, F, A}]), + {Acc#tmpAcc.funcAcc ++ [FuncInfo], Acc#tmpAcc.incFuncAcc}; + false -> + %% Coming from other sourses, including: + %% -- .yrl (yecc-generated file) + %% -- yeccpre.hrl (yecc-generated file) + %% -- other cases + {Acc#tmpAcc.funcAcc, Acc#tmpAcc.incFuncAcc ++ [{FileName, FuncInfo}]} + end, + Acc#tmpAcc{funcAcc = FuncAcc, + incFuncAcc = IncFuncAcc, + dialyzerObj = NewDialyzerObj}. + +get_line([Line|_]) when is_integer(Line) -> Line; +get_line([_|T]) -> get_line(T); +get_line([]) -> none. + +get_file([{file,File}|_]) -> File; +get_file([_|T]) -> get_file(T); +get_file([]) -> "no_file". % should not happen + +-spec get_dialyzer_plt(analysis()) -> plt(). + +get_dialyzer_plt(#analysis{plt = PltFile0}) -> + PltFile = + case PltFile0 =:= none of + true -> dialyzer_plt:get_default_plt(); + false -> PltFile0 + end, + dialyzer_plt:from_file(PltFile). + +%% Exported Types + +get_exported_types_from_core(Core) -> + Attrs = cerl:module_attrs(Core), + ExpTypes1 = [cerl:concrete(L2) || {L1, L2} <- Attrs, + cerl:is_literal(L1), + cerl:is_literal(L2), + cerl:concrete(L1) =:= 'export_type'], + ExpTypes2 = lists:flatten(ExpTypes1), + M = cerl:atom_val(cerl:module_name(Core)), + sets:from_list([{M, F, A} || {F, A} <- ExpTypes2]). + +%%-------------------------------------------------------------------- +%% Utilities for error reporting. +%%-------------------------------------------------------------------- + +-spec fatal_error(string()) -> no_return(). + +fatal_error(Slogan) -> + msg(io_lib:format("typer: ~s\n", [Slogan])), + erlang:halt(1). + +-spec mode_error(mode(), mode()) -> no_return(). + +mode_error(OldMode, NewMode) -> + Msg = io_lib:format("Mode was previously set to '~s'; " + "can not set it to '~s' now", + [OldMode, NewMode]), + fatal_error(Msg). + +-spec compile_error([string()]) -> no_return(). + +compile_error(Reason) -> + JoinedString = lists:flatten([X ++ "\n" || X <- Reason]), + Msg = "Analysis failed with error report:\n" ++ JoinedString, + fatal_error(Msg). + +-spec msg(string()) -> 'ok'. + +msg(Msg) -> + io:format(standard_error, "~s", [Msg]). + +%%-------------------------------------------------------------------- +%% Version and help messages. +%%-------------------------------------------------------------------- + +-spec version_message() -> no_return(). + +version_message() -> + io:format("TypEr version "++?VSN++"\n"), + erlang:halt(0). + +-spec help_message() -> no_return(). + +help_message() -> + S = <<" Usage: typer [--help] [--version] [--plt PLT] [--edoc] + [--show | --show-exported | --annotate | --annotate-inc-files] + [-Ddefine]* [-I include_dir]* [-pa dir]* [-pz dir]* + [-T application]* [-r] file* + + Options: + -r dir* + search directories recursively for .erl files below them + --show + Prints type specifications for all functions on stdout. + (this is the default behaviour; this option is not really needed) + --show-exported (or --show_exported) + Same as --show, but prints specifications for exported functions only + Specs are displayed sorted alphabetically on the function's name + --annotate + Annotates the specified files with type specifications + --annotate-inc-files + Same as --annotate but annotates all -include() files as well as + all .erl files (use this option with caution - has not been tested much) + --edoc + Prints type information as Edoc @spec comments, not as type specs + --plt PLT + Use the specified dialyzer PLT file rather than the default one + -T file* + The specified file(s) already contain type specifications and these + are to be trusted in order to print specs for the rest of the files + (Multiple files or dirs, separated by spaces, can be specified.) + -Dname (or -Dname=value) + pass the defined name(s) to TypEr + (The syntax of defines is the same as that used by \"erlc\".) + -I include_dir + pass the include_dir to TypEr + (The syntax of includes is the same as that used by \"erlc\".) + -pa dir + -pz dir + Set code path options to TypEr + (This is useful for files that use parse tranforms.) + --version (or -v) + prints the Typer version and exits + --help (or -h) + prints this message and exits + + Note: + * denotes that multiple occurrences of these options are possible. +">>, + io:put_chars(S), + erlang:halt(0). + +%%-------------------------------------------------------------------- +%% Handle messages. +%%-------------------------------------------------------------------- + +rcv_ext_types() -> + Self = self(), + Self ! {Self, done}, + rcv_ext_types(Self, []). + +rcv_ext_types(Self, ExtTypes) -> + receive + {Self, ext_types, ExtType} -> + rcv_ext_types(Self, [ExtType|ExtTypes]); + {Self, done} -> + lists:usort(ExtTypes) + end. + +%%-------------------------------------------------------------------- +%% A convenient abstraction of a Key-Value mapping data structure +%% specialized for the uses in this module +%%-------------------------------------------------------------------- + +-type map_dict() :: dict:dict(). + +-spec map__new() -> map_dict(). +map__new() -> + dict:new(). + +-spec map__insert({term(), term()}, map_dict()) -> map_dict(). +map__insert(Object, Map) -> + {Key, Value} = Object, + dict:store(Key, Value, Map). + +-spec map__lookup(term(), map_dict()) -> term(). +map__lookup(Key, Map) -> + try dict:fetch(Key, Map) catch error:_ -> none end. + +-spec map__from_list([{fa(), term()}]) -> map_dict(). +map__from_list(List) -> + dict:from_list(List). + +-spec map__remove(term(), map_dict()) -> map_dict(). +map__remove(Key, Dict) -> + dict:erase(Key, Dict). + +-spec map__fold(fun((term(), term(), term()) -> map_dict()), map_dict(), map_dict()) -> map_dict(). +map__fold(Fun, Acc0, Dict) -> + dict:fold(Fun, Acc0, Dict). diff --git a/lib/dialyzer/test/Makefile b/lib/dialyzer/test/Makefile index 0d8fba438c..43c8a61ce1 100644 --- a/lib/dialyzer/test/Makefile +++ b/lib/dialyzer/test/Makefile @@ -13,7 +13,8 @@ AUXILIARY_FILES=\ file_utils.erl\ dialyzer_SUITE.erl\ abstract_SUITE.erl\ - plt_SUITE.erl + plt_SUITE.erl\ + typer_SUITE.erl # ---------------------------------------------------- # Release directory specification diff --git a/lib/dialyzer/test/typer_SUITE.erl b/lib/dialyzer/test/typer_SUITE.erl new file mode 100644 index 0000000000..da5b961643 --- /dev/null +++ b/lib/dialyzer/test/typer_SUITE.erl @@ -0,0 +1,158 @@ +%% +%% %CopyrightBegin% +%% +%% Copyright Ericsson AB 2017. All Rights Reserved. +%% +%% Licensed under the Apache License, Version 2.0 (the "License"); +%% you may not use this file except in compliance with the License. +%% You may obtain a copy of the License at +%% +%% http://www.apache.org/licenses/LICENSE-2.0 +%% +%% Unless required by applicable law or agreed to in writing, software +%% distributed under the License is distributed on an "AS IS" BASIS, +%% WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +%% See the License for the specific language governing permissions and +%% limitations under the License. +%% +%% %CopyrightEnd% +%% +-module(typer_SUITE). + +-export([all/0,suite/0,groups/0,init_per_suite/1,end_per_suite/1, + init_per_group/2,end_per_group/2, + smoke/1]). + +-include_lib("common_test/include/ct.hrl"). + +suite() -> [{ct_hooks,[ts_install_cth]}]. + +all() -> + [smoke]. + +groups() -> + []. + +init_per_suite(Config) -> + OutDir = proplists:get_value(priv_dir, Config), + case dialyzer_common:check_plt(OutDir) of + fail -> {skip, "Plt creation/check failed."}; + ok -> [{dialyzer_options, []}|Config] + end. + +end_per_suite(_Config) -> + ok. + +init_per_group(_GroupName, Config) -> + Config. + +end_per_group(_GroupName, Config) -> + Config. + +smoke(Config) -> + Code = <<"-module(typer_test_module). + -compile([export_all,nowarn_export_all]). + a(L) -> + L ++ [1,2,3].">>, + PrivDir = proplists:get_value(priv_dir, Config), + Src = filename:join(PrivDir, "typer_test_module.erl"), + ok = file:write_file(Src, Code), + Args = "--plt " ++ PrivDir ++ "dialyzer_plt", + Res = ["^$", + "^%% File:", + "^%% ----", + "^-spec a", + "^_OK_"], + run(Config, Args, Src, Res), + ok. + +typer() -> + case os:find_executable("typer") of + false -> + ct:fail("Can't find typer"); + Typer -> + Typer + end. + +%% Runs a command. + +run(Config, Args0, Name, Expect) -> + Args = Args0 ++ " " ++ Name, + Result = run_command(Config, Args), + verify_result(Result, Expect). + +verify_result(Result, Expect) -> + Messages = split(Result, [], []), + io:format("Result: ~p", [Messages]), + io:format("Expected: ~p", [Expect]), + match_messages(Messages, Expect). + +split([$\n|Rest], Current, Lines) -> + split(Rest, [], [lists:reverse(Current)|Lines]); +split([$\r|Rest], Current, Lines) -> + split(Rest, Current, Lines); +split([Char|Rest], Current, Lines) -> + split(Rest, [Char|Current], Lines); +split([], [], Lines) -> + lists:reverse(Lines); +split([], Current, Lines) -> + split([], [], [lists:reverse(Current)|Lines]). + +match_messages([Msg|Rest1], [Regexp|Rest2]) -> + case re:run(Msg, Regexp, [{capture,none}, unicode]) of + match -> + ok; + nomatch -> + io:format("Not matching: ~s\n", [Msg]), + io:format("Regexp : ~s\n", [Regexp]), + ct:fail(message_mismatch) + end, + match_messages(Rest1, Rest2); +match_messages([], [Expect|Rest]) -> + ct:fail({too_few_messages, [Expect|Rest]}); +match_messages([Msg|Rest], []) -> + ct:fail({too_many_messages, [Msg|Rest]}); +match_messages([], []) -> + ok. + +%% Runs the command using os:cmd/1. +%% +%% Returns the output from the command (as a list of characters with +%% embedded newlines). The very last line will indicate the +%% exit status of the command, where _OK_ means zero, and _ERROR_ +%% a non-zero exit status. + +run_command(Config, Args) -> + TmpDir = filename:join(proplists:get_value(priv_dir, Config), "tmp"), + file:make_dir(TmpDir), + {RunFile, Run, Script} = run_command(TmpDir, os:type(), Args), + ok = file:write_file(filename:join(TmpDir, RunFile), + unicode:characters_to_binary(Script)), + io:format("~ts\n", [Script]), + os:cmd(Run). + +run_command(Dir, {win32, _}, Args) -> + BatchFile = filename:join(Dir, "run.bat"), + Run = re:replace(filename:rootname(BatchFile), "/", "\\", + [global,{return,list}]), + Typer = typer(), + {BatchFile, + Run, + ["@echo off\r\n", + "\"",Typer,"\" ",Args, "\r\n", + "if errorlevel 1 echo _ERROR_\r\n", + "if not errorlevel 1 echo _OK_\r\n"]}; +run_command(Dir, {unix, _}, Args) -> + TyperDir = filename:dirname(typer()), + Name = filename:join(Dir, "run"), + {Name, + "/bin/sh " ++ Name, + ["#!/bin/sh\n", + "PATH=\"",TyperDir,":$PATH\"\n", + "typer ",Args,"\n", + "case $? in\n", + " 0) echo '_OK_';;\n", + " *) echo '_ERROR_';;\n", + "esac\n"]}; +run_command(_Dir, Other, _Args) -> + ct:fail("Don't know how to test exit code for ~p", [Other]). diff --git a/lib/observer/test/crashdump_helper.erl b/lib/observer/test/crashdump_helper.erl index e57c8162e4..fce15bca89 100644 --- a/lib/observer/test/crashdump_helper.erl +++ b/lib/observer/test/crashdump_helper.erl @@ -20,7 +20,7 @@ -module(crashdump_helper). -export([n1_proc/2,remote_proc/2]). --compile(r13). +-compile(r18). -include_lib("common_test/include/ct.hrl"). n1_proc(N2,Creator) -> diff --git a/lib/ssl/src/dtls_record.erl b/lib/ssl/src/dtls_record.erl index 0ee51c24b6..049f83e49e 100644 --- a/lib/ssl/src/dtls_record.erl +++ b/lib/ssl/src/dtls_record.erl @@ -439,43 +439,59 @@ encode_dtls_cipher_text(Type, {MajVer, MinVer}, Fragment, encode_plain_text(Type, Version, Data, #{compression_state := CompS0, epoch := Epoch, sequence_number := Seq, + cipher_state := CipherS0, security_parameters := #security_parameters{ cipher_type = ?AEAD, + bulk_cipher_algorithm = + BulkCipherAlgo, compression_algorithm = CompAlg} } = WriteState0) -> {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), - WriteState1 = WriteState0#{compression_state => CompS1}, AAD = calc_aad(Type, Version, Epoch, Seq), - ssl_record:cipher_aead(dtls_v1:corresponding_tls_version(Version), Comp, WriteState1, AAD); -encode_plain_text(Type, Version, Data, #{compression_state := CompS0, + TLSVersion = dtls_v1:corresponding_tls_version(Version), + {CipherFragment, CipherS1} = + ssl_cipher:cipher_aead(BulkCipherAlgo, CipherS0, Seq, AAD, Comp, TLSVersion), + {CipherFragment, WriteState0#{compression_state => CompS1, + cipher_state => CipherS1}}; +encode_plain_text(Type, Version, Fragment, #{compression_state := CompS0, epoch := Epoch, sequence_number := Seq, + cipher_state := CipherS0, security_parameters := - #security_parameters{compression_algorithm = CompAlg} + #security_parameters{compression_algorithm = CompAlg, + bulk_cipher_algorithm = + BulkCipherAlgo} }= WriteState0) -> - {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), + {Comp, CompS1} = ssl_record:compress(CompAlg, Fragment, CompS0), WriteState1 = WriteState0#{compression_state => CompS1}, - MacHash = calc_mac_hash(Type, Version, WriteState1, Epoch, Seq, Comp), - ssl_record:cipher(dtls_v1:corresponding_tls_version(Version), Comp, WriteState1, MacHash). + MAC = calc_mac_hash(Type, Version, WriteState1, Epoch, Seq, Comp), + TLSVersion = dtls_v1:corresponding_tls_version(Version), + {CipherFragment, CipherS1} = + ssl_cipher:cipher(BulkCipherAlgo, CipherS0, MAC, Fragment, TLSVersion), + {CipherFragment, WriteState0#{cipher_state => CipherS1}}. decode_cipher_text(#ssl_tls{type = Type, version = Version, epoch = Epoch, sequence_number = Seq, fragment = CipherFragment} = CipherText, #{compression_state := CompressionS0, + cipher_state := CipherS0, security_parameters := #security_parameters{ cipher_type = ?AEAD, + bulk_cipher_algorithm = + BulkCipherAlgo, compression_algorithm = CompAlg}} = ReadState0, ConnnectionStates0) -> AAD = calc_aad(Type, Version, Epoch, Seq), - case ssl_record:decipher_aead(dtls_v1:corresponding_tls_version(Version), - CipherFragment, ReadState0, AAD) of - {PlainFragment, ReadState1} -> + TLSVersion = dtls_v1:corresponding_tls_version(Version), + case ssl_cipher:decipher_aead(BulkCipherAlgo, CipherS0, Seq, AAD, CipherFragment, TLSVersion) of + {PlainFragment, CipherState} -> {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, PlainFragment, CompressionS0), - ReadState = ReadState1#{compression_state => CompressionS1}, + ReadState = ReadState0#{compression_state => CompressionS1, + cipher_state => CipherState}, ConnnectionStates = set_connection_state_by_epoch(ReadState, Epoch, ConnnectionStates0, read), {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; #alert{} = Alert -> @@ -528,5 +544,4 @@ mac_hash(Version, MacAlg, MacSecret, SeqNo, Type, Length, Fragment) -> Length, Fragment). calc_aad(Type, {MajVer, MinVer}, Epoch, SeqNo) -> - NewSeq = (Epoch bsl 48) + SeqNo, - <<NewSeq:64/integer, ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. + <<?UINT16(Epoch), ?UINT48(SeqNo), ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. diff --git a/lib/ssl/src/ssl_cipher.erl b/lib/ssl/src/ssl_cipher.erl index 8e6860e9dc..d04f09efdc 100644 --- a/lib/ssl/src/ssl_cipher.erl +++ b/lib/ssl/src/ssl_cipher.erl @@ -40,7 +40,7 @@ ec_keyed_suites/0, anonymous_suites/1, psk_suites/1, srp_suites/0, rc4_suites/1, des_suites/1, openssl_suite/1, openssl_suite_name/1, filter/2, filter_suites/1, hash_algorithm/1, sign_algorithm/1, is_acceptable_hash/2, is_fallback/1, - random_bytes/1, calc_aad/3, calc_mac_hash/4, + random_bytes/1, calc_mac_hash/4, is_stream_ciphersuite/1]). -export_type([cipher_suite/0, @@ -157,7 +157,7 @@ cipher_aead(?CHACHA20_POLY1305, CipherState, SeqNo, AAD, Fragment, Version) -> aead_cipher(chacha20_poly1305, #cipher_state{key=Key} = CipherState, SeqNo, AAD0, Fragment, _Version) -> CipherLen = erlang:iolist_size(Fragment), AAD = <<AAD0/binary, ?UINT16(CipherLen)>>, - Nonce = <<SeqNo:64/integer>>, + Nonce = ?uint64(SeqNo), {Content, CipherTag} = crypto:block_encrypt(chacha20_poly1305, Key, Nonce, {AAD, Fragment}), {<<Content/binary, CipherTag/binary>>, CipherState}; aead_cipher(Type, #cipher_state{key=Key, iv = IV0, nonce = Nonce} = CipherState, _SeqNo, AAD0, Fragment, _Version) -> @@ -280,7 +280,7 @@ aead_ciphertext_to_state(chacha20_poly1305, SeqNo, _IV, AAD0, Fragment, _Version CipherLen = size(Fragment) - 16, <<CipherText:CipherLen/bytes, CipherTag:16/bytes>> = Fragment, AAD = <<AAD0/binary, ?UINT16(CipherLen)>>, - Nonce = <<SeqNo:64/integer>>, + Nonce = ?uint64(SeqNo), {Nonce, AAD, CipherText, CipherTag}; aead_ciphertext_to_state(_, _SeqNo, <<Salt:4/bytes, _/binary>>, AAD0, Fragment, _Version) -> CipherLen = size(Fragment) - 24, @@ -1531,10 +1531,6 @@ is_fallback(CipherSuites)-> random_bytes(N) -> crypto:strong_rand_bytes(N). -calc_aad(Type, {MajVer, MinVer}, - #{sequence_number := SeqNo}) -> - <<SeqNo:64/integer, ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. - calc_mac_hash(Type, Version, PlainFragment, #{sequence_number := SeqNo, mac_secret := MacSecret, diff --git a/lib/ssl/src/ssl_record.erl b/lib/ssl/src/ssl_record.erl index 539e189c4f..24e52655b0 100644 --- a/lib/ssl/src/ssl_record.erl +++ b/lib/ssl/src/ssl_record.erl @@ -45,11 +45,7 @@ -export([compress/3, uncompress/3, compressions/0]). %% Payload encryption/decryption --export([cipher/4, decipher/4, is_correct_mac/2, - cipher_aead/4, decipher_aead/4]). - -%% Encoding --export([encode_plain_text/4]). +-export([cipher/4, decipher/4, cipher_aead/4, is_correct_mac/2]). -export_type([ssl_version/0, ssl_atom_version/0, connection_states/0, connection_state/0]). @@ -271,26 +267,6 @@ set_pending_cipher_state(#{pending_read := Read, pending_read => Read#{cipher_state => ServerState}, pending_write => Write#{cipher_state => ClientState}}. -encode_plain_text(Type, Version, Data, #{compression_state := CompS0, - security_parameters := - #security_parameters{ - cipher_type = ?AEAD, - compression_algorithm = CompAlg} - } = WriteState0) -> - {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), - WriteState1 = WriteState0#{compression_state => CompS1}, - AAD = ssl_cipher:calc_aad(Type, Version, WriteState1), - ssl_record:cipher_aead(Version, Comp, WriteState1, AAD); -encode_plain_text(Type, Version, Data, #{compression_state := CompS0, - security_parameters := - #security_parameters{compression_algorithm = CompAlg} - }= WriteState0) -> - {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), - WriteState1 = WriteState0#{compression_state => CompS1}, - MacHash = ssl_cipher:calc_mac_hash(Type, Version, Comp, WriteState1), - ssl_record:cipher(Version, Comp, WriteState1, MacHash); -encode_plain_text(_,_,_,CS) -> - exit({cs, CS}). uncompress(?NULL, Data, CS) -> {Data, CS}. @@ -322,12 +298,12 @@ cipher(Version, Fragment, {CipherFragment, CipherS1} = ssl_cipher:cipher(BulkCipherAlgo, CipherS0, MacHash, Fragment, Version), {CipherFragment, WriteState0#{cipher_state => CipherS1}}. -%%-------------------------------------------------------------------- --spec cipher_aead(ssl_version(), iodata(), connection_state(), MacHash::binary()) -> - {CipherFragment::binary(), connection_state()}. -%% -%% Description: Payload encryption -%%-------------------------------------------------------------------- +%% %%-------------------------------------------------------------------- +%% -spec cipher_aead(ssl_version(), iodata(), connection_state(), MacHash::binary()) -> +%% {CipherFragment::binary(), connection_state()}. +%% %% +%% %% Description: Payload encryption +%% %%-------------------------------------------------------------------- cipher_aead(Version, Fragment, #{cipher_state := CipherS0, sequence_number := SeqNo, @@ -341,7 +317,8 @@ cipher_aead(Version, Fragment, {CipherFragment, WriteState0#{cipher_state => CipherS1}}. %%-------------------------------------------------------------------- --spec decipher(ssl_version(), binary(), connection_state(), boolean()) -> {binary(), binary(), connection_state} | #alert{}. +-spec decipher(ssl_version(), binary(), connection_state(), boolean()) -> + {binary(), binary(), connection_state} | #alert{}. %% %% Description: Payload decryption %%-------------------------------------------------------------------- @@ -359,26 +336,7 @@ decipher(Version, CipherFragment, #alert{} = Alert -> Alert end. -%%-------------------------------------------------------------------- --spec decipher_aead(ssl_version(), binary(), connection_state(), binary()) -> - {binary(), binary(), connection_state()} | #alert{}. -%% -%% Description: Payload decryption -%%-------------------------------------------------------------------- -decipher_aead(Version, CipherFragment, - #{sequence_number := SeqNo, - security_parameters := - #security_parameters{bulk_cipher_algorithm = - BulkCipherAlgo}, - cipher_state := CipherS0 - } = ReadState, AAD) -> - case ssl_cipher:decipher_aead(BulkCipherAlgo, CipherS0, SeqNo, AAD, CipherFragment, Version) of - {PlainFragment, CipherS1} -> - CS1 = ReadState#{cipher_state => CipherS1}, - {PlainFragment, CS1}; - #alert{} = Alert -> - Alert - end. + %%-------------------------------------------------------------------- %%% Internal functions %%-------------------------------------------------------------------- diff --git a/lib/ssl/src/tls_record.erl b/lib/ssl/src/tls_record.erl index 993a1622fe..065c6dc8a7 100644 --- a/lib/ssl/src/tls_record.erl +++ b/lib/ssl/src/tls_record.erl @@ -372,7 +372,7 @@ get_tls_records_aux(Data, Acc) -> end. encode_plain_text(Type, Version, Data, #{current_write := Write0} = ConnectionStates) -> - {CipherFragment, Write1} = ssl_record:encode_plain_text(Type, Version, Data, Write0), + {CipherFragment, Write1} = do_encode_plain_text(Type, Version, Data, Write0), {CipherText, Write} = encode_tls_cipher_text(Type, Version, CipherFragment, Write1), {CipherText, ConnectionStates#{current_write => Write}}. @@ -446,19 +446,24 @@ decode_cipher_text(#ssl_tls{type = Type, version = Version, #{current_read := #{compression_state := CompressionS0, sequence_number := Seq, + cipher_state := CipherS0, security_parameters := #security_parameters{ cipher_type = ?AEAD, + bulk_cipher_algorithm = + BulkCipherAlgo, compression_algorithm = CompAlg} } = ReadState0} = ConnnectionStates0, _) -> - AAD = ssl_cipher:calc_aad(Type, Version, ReadState0), - case ssl_record:decipher_aead(Version, CipherFragment, ReadState0, AAD) of - {PlainFragment, ReadState1} -> + AAD = calc_aad(Type, Version, ReadState0), + case ssl_cipher:decipher_aead(BulkCipherAlgo, CipherS0, Seq, AAD, CipherFragment, Version) of + {PlainFragment, CipherS1} -> {Plain, CompressionS1} = ssl_record:uncompress(CompAlg, PlainFragment, CompressionS0), ConnnectionStates = ConnnectionStates0#{ - current_read => ReadState1#{sequence_number => Seq + 1, - compression_state => CompressionS1}}, + current_read => ReadState0#{ + cipher_state => CipherS1, + sequence_number => Seq + 1, + compression_state => CompressionS1}}, {CipherText#ssl_tls{fragment = Plain}, ConnnectionStates}; #alert{} = Alert -> Alert @@ -489,4 +494,29 @@ decode_cipher_text(#ssl_tls{type = Type, version = Version, end; #alert{} = Alert -> Alert - end. + end. + +do_encode_plain_text(Type, Version, Data, #{compression_state := CompS0, + security_parameters := + #security_parameters{ + cipher_type = ?AEAD, + compression_algorithm = CompAlg} + } = WriteState0) -> + {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), + WriteState1 = WriteState0#{compression_state => CompS1}, + AAD = calc_aad(Type, Version, WriteState1), + ssl_record:cipher_aead(Version, Comp, WriteState1, AAD); +do_encode_plain_text(Type, Version, Data, #{compression_state := CompS0, + security_parameters := + #security_parameters{compression_algorithm = CompAlg} + }= WriteState0) -> + {Comp, CompS1} = ssl_record:compress(CompAlg, Data, CompS0), + WriteState1 = WriteState0#{compression_state => CompS1}, + MacHash = ssl_cipher:calc_mac_hash(Type, Version, Comp, WriteState1), + ssl_record:cipher(Version, Comp, WriteState1, MacHash); +do_encode_plain_text(_,_,_,CS) -> + exit({cs, CS}). + +calc_aad(Type, {MajVer, MinVer}, + #{sequence_number := SeqNo}) -> + <<?UINT64(SeqNo), ?BYTE(Type), ?BYTE(MajVer), ?BYTE(MinVer)>>. diff --git a/lib/ssl/test/ssl_ECC_SUITE.erl b/lib/ssl/test/ssl_ECC_SUITE.erl index b77f909dfa..b05e2c74db 100644 --- a/lib/ssl/test/ssl_ECC_SUITE.erl +++ b/lib/ssl/test/ssl_ECC_SUITE.erl @@ -91,11 +91,7 @@ init_per_suite(Config0) -> end_per_suite(Config0), try crypto:start() of ok -> - %% make rsa certs using oppenssl - Config1 = ssl_test_lib:make_rsa_cert(Config0), - Config2 = ssl_test_lib:make_ecdsa_cert(Config1), - Config = ssl_test_lib:make_ecdh_rsa_cert(Config2), - ssl_test_lib:cert_options(Config) + Config0 catch _:_ -> {skip, "Crypto did not start"} end. diff --git a/lib/ssl/test/ssl_certificate_verify_SUITE.erl b/lib/ssl/test/ssl_certificate_verify_SUITE.erl index 66b0c09b73..45bcdf1f78 100644 --- a/lib/ssl/test/ssl_certificate_verify_SUITE.erl +++ b/lib/ssl/test/ssl_certificate_verify_SUITE.erl @@ -74,7 +74,7 @@ tests() -> cert_expired, invalid_signature_client, invalid_signature_server, - extended_key_usage_verify_client, + extended_key_usage_verify_both, extended_key_usage_verify_server, critical_extension_verify_client, critical_extension_verify_server, @@ -88,18 +88,14 @@ error_handling_tests()-> unknown_server_ca_accept_verify_peer, unknown_server_ca_accept_backwardscompatibility, no_authority_key_identifier, - no_authority_key_identifier_and_nonstandard_encoding]. + no_authority_key_identifier_keyEncipherment]. -init_per_suite(Config0) -> +init_per_suite(Config) -> catch crypto:stop(), try crypto:start() of ok -> - ssl_test_lib:clean_start(), - %% make rsa certs using oppenssl - {ok, _} = make_certs:all(proplists:get_value(data_dir, Config0), - proplists:get_value(priv_dir, Config0)), - Config = ssl_test_lib:make_dsa_cert(Config0), - ssl_test_lib:cert_options(Config) + ssl_test_lib:clean_start(), + ssl_test_lib:make_rsa_cert(Config) catch _:_ -> {skip, "Crypto did not start"} end. @@ -108,49 +104,39 @@ end_per_suite(_Config) -> ssl:stop(), application:stop(crypto). -init_per_group(tls, Config) -> +init_per_group(tls, Config0) -> Version = tls_record:protocol_version(tls_record:highest_protocol_version([])), ssl:stop(), application:load(ssl), application:set_env(ssl, protocol_version, Version), - application:set_env(ssl, bypass_pem_cache, Version), ssl:start(), - NewConfig = proplists:delete(protocol, Config), - [{protocol, tls}, {version, tls_record:protocol_version(Version)} | NewConfig]; + Config = proplists:delete(protocol, Config0), + [{protocol, tls}, {version, tls_record:protocol_version(Version)} | Config]; -init_per_group(dtls, Config) -> +init_per_group(dtls, Config0) -> Version = dtls_record:protocol_version(dtls_record:highest_protocol_version([])), ssl:stop(), application:load(ssl), application:set_env(ssl, protocol_version, Version), - application:set_env(ssl, bypass_pem_cache, Version), ssl:start(), - NewConfig = proplists:delete(protocol_opts, proplists:delete(protocol, Config)), - [{protocol, dtls}, {protocol_opts, [{protocol, dtls}]}, {version, dtls_record:protocol_version(Version)} | NewConfig]; + Config = proplists:delete(protocol_opts, proplists:delete(protocol, Config0)), + [{protocol, dtls}, {protocol_opts, [{protocol, dtls}]}, {version, dtls_record:protocol_version(Version)} | Config]; init_per_group(active, Config) -> - [{active, true}, {receive_function, send_recv_result_active} | Config]; + [{active, true}, {receive_function, send_recv_result_active} | Config]; init_per_group(active_once, Config) -> - [{active, once}, {receive_function, send_recv_result_active_once} | Config]; + [{active, once}, {receive_function, send_recv_result_active_once} | Config]; init_per_group(passive, Config) -> - [{active, false}, {receive_function, send_recv_result} | Config]; + [{active, false}, {receive_function, send_recv_result} | Config]; +init_per_group(error_handling, Config) -> + [{active, false}, {receive_function, send_recv_result} | Config]; + init_per_group(_, Config) -> Config. end_per_group(_GroupName, Config) -> Config. -init_per_testcase(TestCase, Config) when TestCase == cert_expired; - TestCase == invalid_signature_client; - TestCase == invalid_signature_server; - TestCase == extended_key_usage_verify_none; - TestCase == extended_key_usage_verify_peer; - TestCase == critical_extension_verify_none; - TestCase == critical_extension_verify_peer; - TestCase == no_authority_key_identifier; - TestCase == no_authority_key_identifier_and_nonstandard_encoding-> - ssl:clear_pem_cache(), - init_per_testcase(common, Config); init_per_testcase(_TestCase, Config) -> ssl:stop(), ssl:start(), @@ -168,23 +154,23 @@ end_per_testcase(_TestCase, Config) -> verify_peer() -> [{doc,"Test option verify_peer"}]. verify_peer(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_opts, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, - {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{active, Active}, {verify, verify_peer} - | ServerOpts]}]), + {mfa, {ssl_test_lib, ReceiveFunction, []}}, + {options, [{active, Active}, {verify, verify_peer} + | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client([{node, ClientNode}, {port, Port}, {host, Hostname}, - {from, self()}, - {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{active, Active} | ClientOpts]}]), - + {from, self()}, + {mfa, {ssl_test_lib, ReceiveFunction, []}}, + {options, [{active, Active}, {verify, verify_peer} | ClientOpts]}]), + ssl_test_lib:check_result(Server, ok, Client, ok), ssl_test_lib:close(Server), ssl_test_lib:close(Client). @@ -194,23 +180,24 @@ verify_none() -> [{doc,"Test option verify_none"}]. verify_none(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_opts, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, - {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{active, Active}, {verify, verify_none} - | ServerOpts]}]), + {mfa, {ssl_test_lib, ReceiveFunction, []}}, + {options, [{active, Active}, {verify, verify_none} + | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client([{node, ClientNode}, {port, Port}, {host, Hostname}, - {from, self()}, - {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{active, Active} | ClientOpts]}]), + {from, self()}, + {mfa, {ssl_test_lib, ReceiveFunction, []}}, + {options, [{active, Active}, + {verify, verify_none} | ClientOpts]}]), ssl_test_lib:check_result(Server, ok, Client, ok), ssl_test_lib:close(Server), @@ -222,8 +209,8 @@ server_verify_client_once() -> [{doc,"Test server option verify_client_once"}]. server_verify_client_once(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_opts, []), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, []), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_opts, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), @@ -239,7 +226,7 @@ server_verify_client_once(Config) when is_list(Config) -> {host, Hostname}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{active, Active} | ClientOpts]}]), + {options, [{active, Active} | ClientOpts]}]), ssl_test_lib:check_result(Server, ok, Client0, ok), Server ! {listen, {mfa, {ssl_test_lib, no_result, []}}}, @@ -261,8 +248,8 @@ server_require_peer_cert_ok() -> server_require_peer_cert_ok(Config) when is_list(Config) -> ServerOpts = [{verify, verify_peer}, {fail_if_no_peer_cert, true} - | ssl_test_lib:ssl_options(server_verification_opts, Config)], - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), + | ssl_test_lib:ssl_options(server_rsa_opts, Config)], + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), @@ -290,20 +277,21 @@ server_require_peer_cert_fail() -> server_require_peer_cert_fail(Config) when is_list(Config) -> ServerOpts = [{verify, verify_peer}, {fail_if_no_peer_cert, true} - | ssl_test_lib:ssl_options(server_verification_opts, Config)], + | ssl_test_lib:ssl_options(server_rsa_opts, Config)], BadClientOpts = ssl_test_lib:ssl_options(empty_client_opts, Config), + Active = proplists:get_value(active, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server_error([{node, ServerNode}, {port, 0}, {from, self()}, - {options, [{active, false} | ServerOpts]}]), + {options, [{active, Active} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client_error([{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, - {options, [{active, false} | BadClientOpts]}]), + {options, [{active, Active} | BadClientOpts]}]), receive {Server, {error, {tls_alert, "handshake failure"}}} -> receive @@ -321,24 +309,25 @@ server_require_peer_cert_partial_chain() -> server_require_peer_cert_partial_chain(Config) when is_list(Config) -> ServerOpts = [{verify, verify_peer}, {fail_if_no_peer_cert, true} - | ssl_test_lib:ssl_options(server_verification_opts, Config)], - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), + | ssl_test_lib:ssl_options(server_rsa_opts, Config)], + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), + Active = proplists:get_value(active, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), {ok, ClientCAs} = file:read_file(proplists:get_value(cacertfile, ClientOpts)), - [{_,RootCA,_}, {_, _, _}] = public_key:pem_decode(ClientCAs), + [{_,RootCA,_} | _] = public_key:pem_decode(ClientCAs), Server = ssl_test_lib:start_server_error([{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, no_result, []}}, - {options, [{active, false} | ServerOpts]}]), + {options, [{active, Active} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client_error([{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, {mfa, {ssl_test_lib, no_result, []}}, - {options, [{active, false}, + {options, [{active, Active}, {cacerts, [RootCA]} | proplists:delete(cacertfile, ClientOpts)]}]), receive @@ -356,14 +345,14 @@ server_require_peer_cert_allow_partial_chain() -> server_require_peer_cert_allow_partial_chain(Config) when is_list(Config) -> ServerOpts = [{verify, verify_peer}, {fail_if_no_peer_cert, true} - | ssl_test_lib:ssl_options(server_verification_opts, Config)], - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), + | ssl_test_lib:ssl_options(server_rsa_opts, Config)], + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), {ok, ClientCAs} = file:read_file(proplists:get_value(cacertfile, ClientOpts)), - [{_,_,_}, {_, IntermidiateCA, _}] = public_key:pem_decode(ClientCAs), + [{_,_,_}, {_, IntermidiateCA, _} | _] = public_key:pem_decode(ClientCAs), PartialChain = fun(CertChain) -> case lists:member(IntermidiateCA, CertChain) of @@ -398,12 +387,12 @@ server_require_peer_cert_do_not_allow_partial_chain() -> server_require_peer_cert_do_not_allow_partial_chain(Config) when is_list(Config) -> ServerOpts = [{verify, verify_peer}, {fail_if_no_peer_cert, true} - | ssl_test_lib:ssl_options(server_verification_opts, Config)], - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), + | ssl_test_lib:ssl_options(server_rsa_opts, Config)], + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), {ok, ServerCAs} = file:read_file(proplists:get_value(cacertfile, ServerOpts)), - [{_,_,_}, {_, IntermidiateCA, _}] = public_key:pem_decode(ServerCAs), + [{_,_,_}, {_, IntermidiateCA, _} | _] = public_key:pem_decode(ServerCAs), PartialChain = fun(_CertChain) -> unknown_ca @@ -439,12 +428,12 @@ server_require_peer_cert_partial_chain_fun_fail() -> server_require_peer_cert_partial_chain_fun_fail(Config) when is_list(Config) -> ServerOpts = [{verify, verify_peer}, {fail_if_no_peer_cert, true} - | ssl_test_lib:ssl_options(server_verification_opts, Config)], - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), + | ssl_test_lib:ssl_options(server_rsa_opts, Config)], + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), {ok, ServerCAs} = file:read_file(proplists:get_value(cacertfile, ServerOpts)), - [{_,_,_}, {_, IntermidiateCA, _}] = public_key:pem_decode(ServerCAs), + [{_,_,_}, {_, IntermidiateCA, _} | _] = public_key:pem_decode(ServerCAs), PartialChain = fun(_CertChain) -> ture = false %% crash on purpose @@ -479,8 +468,8 @@ verify_fun_always_run_client() -> [{doc,"Verify that user verify_fun is always run (for valid and valid_peer not only unknown_extension)"}]. verify_fun_always_run_client(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_opts, Config), + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server_error([{node, ServerNode}, {port, 0}, {from, self()}, @@ -524,8 +513,8 @@ verify_fun_always_run_client(Config) when is_list(Config) -> verify_fun_always_run_server() -> [{doc,"Verify that user verify_fun is always run (for valid and valid_peer not only unknown_extension)"}]. verify_fun_always_run_server(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), %% If user verify fun is called correctly we fail the connection. @@ -573,63 +562,28 @@ cert_expired() -> [{doc,"Test server with expired certificate"}]. cert_expired(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), - - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), - Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), - - ServerCertFile = proplists:get_value(certfile, ServerOpts), - NewServerCertFile = filename:join(PrivDir, "server/expired_cert.pem"), - [{'Certificate', DerCert, _}] = ssl_test_lib:pem_to_der(ServerCertFile), - OTPCert = public_key:pkix_decode_cert(DerCert, otp), - OTPTbsCert = OTPCert#'OTPCertificate'.tbsCertificate, - {Year, Month, Day} = date(), - {Hours, Min, Sec} = time(), - NotBeforeStr = lists:flatten(io_lib:format("~p~s~s~s~s~sZ",[Year-2, - two_digits_str(Month), - two_digits_str(Day), - two_digits_str(Hours), - two_digits_str(Min), - two_digits_str(Sec)])), - NotAfterStr = lists:flatten(io_lib:format("~p~s~s~s~s~sZ",[Year-1, - two_digits_str(Month), - two_digits_str(Day), - two_digits_str(Hours), - two_digits_str(Min), - two_digits_str(Sec)])), - NewValidity = {'Validity', {generalTime, NotBeforeStr}, {generalTime, NotAfterStr}}, - - ct:log("Validity: ~p ~n NewValidity: ~p ~n", - [OTPTbsCert#'OTPTBSCertificate'.validity, NewValidity]), - - NewOTPTbsCert = OTPTbsCert#'OTPTBSCertificate'{validity = NewValidity}, - NewServerDerCert = public_key:pkix_sign(NewOTPTbsCert, Key), - ssl_test_lib:der_to_pem(NewServerCertFile, [{'Certificate', NewServerDerCert, not_encrypted}]), - NewServerOpts = [{certfile, NewServerCertFile} | proplists:delete(certfile, ServerOpts)], - + Active = proplists:get_value(active, Config), + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{server_ca_0, + [{validity, {{Year-2, Month, Day}, + {Year-1, Month, Day}}}]}], + Config, "_expired"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), + {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server_error([{node, ServerNode}, {port, 0}, {from, self()}, - {options, NewServerOpts}]), + {options, [{active, Active}| ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client_error([{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, - {options, [{verify, verify_peer} | ClientOpts]}]), - receive - {Client, {error, {tls_alert, "certificate expired"}}} -> - receive - {Server, {error, {tls_alert, "certificate expired"}}} -> - ok; - {Server, {error, closed}} -> - ok - end - end. + {options, [{verify, verify_peer}, {active, Active} | ClientOpts]}]), + + tcp_delivery_workaround(Server, {error, {tls_alert, "certificate expired"}}, + Client, {error, {tls_alert, "certificate expired"}}). two_digits_str(N) when N < 10 -> lists:flatten(io_lib:format("0~p", [N])); @@ -638,60 +592,32 @@ two_digits_str(N) -> %%-------------------------------------------------------------------- extended_key_usage_verify_server() -> - [{doc,"Test cert that has a critical extended_key_usage extension in verify_peer mode for server"}]. - -extended_key_usage_verify_server(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), + [{doc,"Test cert that has a critical extended_key_usage extension in server cert"}]. + +extended_key_usage_verify_server(Config) when is_list(Config) -> + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{server_peer_opts, + [{extensions, + [{?'id-ce-extKeyUsage', + [?'id-kp-serverAuth'], true}] + }]}], Config, "_keyusage_server"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), - Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), - - ServerCertFile = proplists:get_value(certfile, ServerOpts), - NewServerCertFile = filename:join(PrivDir, "server/new_cert.pem"), - [{'Certificate', ServerDerCert, _}] = ssl_test_lib:pem_to_der(ServerCertFile), - ServerOTPCert = public_key:pkix_decode_cert(ServerDerCert, otp), - ServerExtKeyUsageExt = {'Extension', ?'id-ce-extKeyUsage', true, [?'id-kp-serverAuth']}, - ServerOTPTbsCert = ServerOTPCert#'OTPCertificate'.tbsCertificate, - ServerExtensions = ServerOTPTbsCert#'OTPTBSCertificate'.extensions, - NewServerOTPTbsCert = ServerOTPTbsCert#'OTPTBSCertificate'{extensions = - [ServerExtKeyUsageExt | - ServerExtensions]}, - NewServerDerCert = public_key:pkix_sign(NewServerOTPTbsCert, Key), - ssl_test_lib:der_to_pem(NewServerCertFile, [{'Certificate', NewServerDerCert, not_encrypted}]), - NewServerOpts = [{certfile, NewServerCertFile} | proplists:delete(certfile, ServerOpts)], - - ClientCertFile = proplists:get_value(certfile, ClientOpts), - NewClientCertFile = filename:join(PrivDir, "client/new_cert.pem"), - [{'Certificate', ClientDerCert, _}] = ssl_test_lib:pem_to_der(ClientCertFile), - ClientOTPCert = public_key:pkix_decode_cert(ClientDerCert, otp), - ClientExtKeyUsageExt = {'Extension', ?'id-ce-extKeyUsage', true, [?'id-kp-clientAuth']}, - ClientOTPTbsCert = ClientOTPCert#'OTPCertificate'.tbsCertificate, - ClientExtensions = ClientOTPTbsCert#'OTPTBSCertificate'.extensions, - NewClientOTPTbsCert = ClientOTPTbsCert#'OTPTBSCertificate'{extensions = - [ClientExtKeyUsageExt | - ClientExtensions]}, - NewClientDerCert = public_key:pkix_sign(NewClientOTPTbsCert, Key), - ssl_test_lib:der_to_pem(NewClientCertFile, [{'Certificate', NewClientDerCert, not_encrypted}]), - NewClientOpts = [{certfile, NewClientCertFile} | proplists:delete(certfile, ClientOpts)], - {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_peer}, {active, Active} | NewServerOpts]}]), + {options, [{verify, verify_none}, {active, Active} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client([{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_none}, {active, Active} | - NewClientOpts]}]), + {options, [{verify, verify_peer}, {active, Active} | + ClientOpts]}]), ssl_test_lib:check_result(Server, ok, Client, ok), @@ -699,60 +625,35 @@ extended_key_usage_verify_server(Config) when is_list(Config) -> ssl_test_lib:close(Client). %%-------------------------------------------------------------------- -extended_key_usage_verify_client() -> +extended_key_usage_verify_both() -> [{doc,"Test cert that has a critical extended_key_usage extension in client verify_peer mode"}]. -extended_key_usage_verify_client(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), +extended_key_usage_verify_both(Config) when is_list(Config) -> + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{server_peer_opts, + [{extensions, [{?'id-ce-extKeyUsage', + [?'id-kp-serverAuth'], true}] + }]}, + {client_peer_opts, + [{extensions, [{?'id-ce-extKeyUsage', + [?'id-kp-clientAuth'], true}] + }]}], Config, "_keyusage_both"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), - Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), - - ServerCertFile = proplists:get_value(certfile, ServerOpts), - NewServerCertFile = filename:join(PrivDir, "server/new_cert.pem"), - [{'Certificate', ServerDerCert, _}] = ssl_test_lib:pem_to_der(ServerCertFile), - ServerOTPCert = public_key:pkix_decode_cert(ServerDerCert, otp), - ServerExtKeyUsageExt = {'Extension', ?'id-ce-extKeyUsage', true, [?'id-kp-serverAuth']}, - ServerOTPTbsCert = ServerOTPCert#'OTPCertificate'.tbsCertificate, - ServerExtensions = ServerOTPTbsCert#'OTPTBSCertificate'.extensions, - NewServerOTPTbsCert = ServerOTPTbsCert#'OTPTBSCertificate'{extensions = - [ServerExtKeyUsageExt | - ServerExtensions]}, - NewServerDerCert = public_key:pkix_sign(NewServerOTPTbsCert, Key), - ssl_test_lib:der_to_pem(NewServerCertFile, [{'Certificate', NewServerDerCert, not_encrypted}]), - NewServerOpts = [{certfile, NewServerCertFile} | proplists:delete(certfile, ServerOpts)], - - ClientCertFile = proplists:get_value(certfile, ClientOpts), - NewClientCertFile = filename:join(PrivDir, "client/new_cert.pem"), - [{'Certificate', ClientDerCert, _}] = ssl_test_lib:pem_to_der(ClientCertFile), - ClientOTPCert = public_key:pkix_decode_cert(ClientDerCert, otp), - ClientExtKeyUsageExt = {'Extension', ?'id-ce-extKeyUsage', true, [?'id-kp-clientAuth']}, - ClientOTPTbsCert = ClientOTPCert#'OTPCertificate'.tbsCertificate, - ClientExtensions = ClientOTPTbsCert#'OTPTBSCertificate'.extensions, - NewClientOTPTbsCert = ClientOTPTbsCert#'OTPTBSCertificate'{extensions = - [ClientExtKeyUsageExt | - ClientExtensions]}, - NewClientDerCert = public_key:pkix_sign(NewClientOTPTbsCert, Key), - ssl_test_lib:der_to_pem(NewClientCertFile, [{'Certificate', NewClientDerCert, not_encrypted}]), - NewClientOpts = [{certfile, NewClientCertFile} | proplists:delete(certfile, ClientOpts)], - {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_none}, {active, Active} | NewServerOpts]}]), + {options, [{verify, verify_peer}, {active, Active} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client([{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_none}, {active, Active} | NewClientOpts]}]), + {options, [{verify, verify_peer}, {active, Active} | ClientOpts]}]), ssl_test_lib:check_result(Server, ok, Client, ok), @@ -764,132 +665,103 @@ critical_extension_verify_server() -> [{doc,"Test cert that has a critical unknown extension in verify_peer mode"}]. critical_extension_verify_server(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{client_peer_opts, + [{extensions, [{{2,16,840,1,113730,1,1}, + <<3,2,6,192>>, true}] + }]}], Config, "_client_unknown_extension"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - NewCertName = integer_to_list(erlang:unique_integer()) ++ ".pem", - - ServerCertFile = proplists:get_value(certfile, ServerOpts), - NewServerCertFile = filename:join([PrivDir, "server", NewCertName]), - add_critical_netscape_cert_type(ServerCertFile, NewServerCertFile, KeyFile), - NewServerOpts = [{certfile, NewServerCertFile} | proplists:delete(certfile, ServerOpts)], - - ClientCertFile = proplists:get_value(certfile, ClientOpts), - NewClientCertFile = filename:join([PrivDir, "client", NewCertName]), - add_critical_netscape_cert_type(ClientCertFile, NewClientCertFile, KeyFile), - NewClientOpts = [{certfile, NewClientCertFile} | proplists:delete(certfile, ClientOpts)], - {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server_error( [{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_peer}, {active, Active} | NewServerOpts]}]), + {options, [{verify, verify_peer}, {active, Active} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client_error( [{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_none}, {active, Active} | NewClientOpts]}]), + {options, [{verify, verify_none}, {active, Active} | ClientOpts]}]), %% This certificate has a critical extension that we don't - %% understand. Therefore, verification should fail. - tcp_delivery_workaround(Server, {error, {tls_alert, "unsupported certificate"}}, - Client, {error, {tls_alert, "unsupported certificate"}}), + %% understand. Therefore, verification should fail. - ssl_test_lib:close(Server), - ok. + tcp_delivery_workaround(Server, {error, {tls_alert, "unsupported certificate"}}, + Client, {error, {tls_alert, "unsupported certificate"}}), + + ssl_test_lib:close(Server). %%-------------------------------------------------------------------- critical_extension_verify_client() -> [{doc,"Test cert that has a critical unknown extension in verify_peer mode"}]. critical_extension_verify_client(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{server_peer_opts, + [{extensions, [{{2,16,840,1,113730,1,1}, + <<3,2,6,192>>, true}] + }]}], Config, "_server_unknown_extensions"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - NewCertName = integer_to_list(erlang:unique_integer()) ++ ".pem", - - ServerCertFile = proplists:get_value(certfile, ServerOpts), - NewServerCertFile = filename:join([PrivDir, "server", NewCertName]), - add_critical_netscape_cert_type(ServerCertFile, NewServerCertFile, KeyFile), - NewServerOpts = [{certfile, NewServerCertFile} | proplists:delete(certfile, ServerOpts)], - - ClientCertFile = proplists:get_value(certfile, ClientOpts), - NewClientCertFile = filename:join([PrivDir, "client", NewCertName]), - add_critical_netscape_cert_type(ClientCertFile, NewClientCertFile, KeyFile), - NewClientOpts = [{certfile, NewClientCertFile} | proplists:delete(certfile, ClientOpts)], - {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server_error( [{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_none}, {active, Active} | NewServerOpts]}]), + {options, [{verify, verify_none}, {active, Active} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client_error( [{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_peer}, {active, Active} | NewClientOpts]}]), + {options, [{verify, verify_peer}, {active, Active} | ClientOpts]}]), %% This certificate has a critical extension that we don't %% understand. Therefore, verification should fail. - tcp_delivery_workaround(Server, {error, {tls_alert, "unsupported certificate"}}, - Client, {error, {tls_alert, "unsupported certificate"}}), + ssl_test_lib:check_result(Server, {error, {tls_alert, "unsupported certificate"}}, + Client, {error, {tls_alert, "unsupported certificate"}}), + + ssl_test_lib:close(Server). - ssl_test_lib:close(Server), - ok. %%-------------------------------------------------------------------- critical_extension_verify_none() -> [{doc,"Test cert that has a critical unknown extension in verify_none mode"}]. critical_extension_verify_none(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{client_peer_opts, + [{extensions, + [{{2,16,840,1,113730,1,1}, + <<3,2,6,192>>, true}] + }]}], Config, "_unknown_extensions"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), Active = proplists:get_value(active, Config), ReceiveFunction = proplists:get_value(receive_function, Config), - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - NewCertName = integer_to_list(erlang:unique_integer()) ++ ".pem", - - ServerCertFile = proplists:get_value(certfile, ServerOpts), - NewServerCertFile = filename:join([PrivDir, "server", NewCertName]), - add_critical_netscape_cert_type(ServerCertFile, NewServerCertFile, KeyFile), - NewServerOpts = [{certfile, NewServerCertFile} | proplists:delete(certfile, ServerOpts)], - - ClientCertFile = proplists:get_value(certfile, ClientOpts), - NewClientCertFile = filename:join([PrivDir, "client", NewCertName]), - add_critical_netscape_cert_type(ClientCertFile, NewClientCertFile, KeyFile), - NewClientOpts = [{certfile, NewClientCertFile} | proplists:delete(certfile, ClientOpts)], - {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server( [{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_none}, {active, Active} | NewServerOpts]}]), + {options, [{verify, verify_none}, {active, Active} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client( [{node, ClientNode}, {port, Port}, {host, Hostname}, {from, self()}, {mfa, {ssl_test_lib, ReceiveFunction, []}}, - {options, [{verify, verify_none}, {active, Active} | NewClientOpts]}]), + {options, [{verify, verify_none}, {active, Active} | ClientOpts]}]), %% This certificate has a critical extension that we don't %% understand. But we're using `verify_none', so verification @@ -897,28 +769,7 @@ critical_extension_verify_none(Config) when is_list(Config) -> ssl_test_lib:check_result(Server, ok, Client, ok), ssl_test_lib:close(Server), - ssl_test_lib:close(Client), - ok. - -add_critical_netscape_cert_type(CertFile, NewCertFile, KeyFile) -> - [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), - Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), - - [{'Certificate', DerCert, _}] = ssl_test_lib:pem_to_der(CertFile), - OTPCert = public_key:pkix_decode_cert(DerCert, otp), - %% This is the "Netscape Cert Type" extension, telling us that the - %% certificate can be used for SSL clients and SSL servers. - NetscapeCertTypeExt = #'Extension'{ - extnID = {2,16,840,1,113730,1,1}, - critical = true, - extnValue = <<3,2,6,192>>}, - OTPTbsCert = OTPCert#'OTPCertificate'.tbsCertificate, - Extensions = OTPTbsCert#'OTPTBSCertificate'.extensions, - NewOTPTbsCert = OTPTbsCert#'OTPTBSCertificate'{ - extensions = [NetscapeCertTypeExt] ++ Extensions}, - NewDerCert = public_key:pkix_sign(NewOTPTbsCert, Key), - ssl_test_lib:der_to_pem(NewCertFile, [{'Certificate', NewDerCert, not_encrypted}]), - ok. + ssl_test_lib:close(Client). %%-------------------------------------------------------------------- no_authority_key_identifier() -> @@ -926,35 +777,21 @@ no_authority_key_identifier() -> " but are present in trusted certs db."}]. no_authority_key_identifier(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), - - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), - Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), - - CertFile = proplists:get_value(certfile, ServerOpts), - NewCertFile = filename:join(PrivDir, "server/new_cert.pem"), - [{'Certificate', DerCert, _}] = ssl_test_lib:pem_to_der(CertFile), - OTPCert = public_key:pkix_decode_cert(DerCert, otp), - OTPTbsCert = OTPCert#'OTPCertificate'.tbsCertificate, - Extensions = OTPTbsCert#'OTPTBSCertificate'.extensions, - NewExtensions = delete_authority_key_extension(Extensions, []), - NewOTPTbsCert = OTPTbsCert#'OTPTBSCertificate'{extensions = NewExtensions}, - - ct:log("Extensions ~p~n, NewExtensions: ~p~n", [Extensions, NewExtensions]), - - NewDerCert = public_key:pkix_sign(NewOTPTbsCert, Key), - ssl_test_lib:der_to_pem(NewCertFile, [{'Certificate', NewDerCert, not_encrypted}]), - NewServerOpts = [{certfile, NewCertFile} | proplists:delete(certfile, ServerOpts)], + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{server_peer_opts, + [{extensions, [{auth_key_id, undefined}] + }]}, + {client_peer_opts, + [{extensions, [{auth_key_id, undefined}] + }]}], Config, "_peer_no_auth_key_id"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, send_recv_result_active, []}}, - {options, NewServerOpts}]), + {options, ServerOpts}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client([{node, ClientNode}, {port, Port}, {host, Hostname}, @@ -970,53 +807,35 @@ no_authority_key_identifier(Config) when is_list(Config) -> delete_authority_key_extension([], Acc) -> lists:reverse(Acc); delete_authority_key_extension([#'Extension'{extnID = ?'id-ce-authorityKeyIdentifier'} | Rest], - Acc) -> + Acc) -> delete_authority_key_extension(Rest, Acc); delete_authority_key_extension([Head | Rest], Acc) -> delete_authority_key_extension(Rest, [Head | Acc]). %%-------------------------------------------------------------------- -no_authority_key_identifier_and_nonstandard_encoding() -> - [{doc, "Test cert with nonstandard encoding that does not have" - " authorityKeyIdentifier extension but are present in trusted certs db."}]. - -no_authority_key_identifier_and_nonstandard_encoding(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), - PrivDir = proplists:get_value(priv_dir, Config), - - KeyFile = filename:join(PrivDir, "otpCA/private/key.pem"), - [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), - Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), - - CertFile = proplists:get_value(certfile, ServerOpts), - NewCertFile = filename:join(PrivDir, "server/new_cert.pem"), - [{'Certificate', DerCert, _}] = ssl_test_lib:pem_to_der(CertFile), - ServerCert = public_key:pkix_decode_cert(DerCert, plain), - ServerTbsCert = ServerCert#'Certificate'.tbsCertificate, - Extensions0 = ServerTbsCert#'TBSCertificate'.extensions, - %% need to remove authorityKeyIdentifier extension to cause DB lookup by signature - Extensions = delete_authority_key_extension(Extensions0, []), - NewExtensions = replace_key_usage_extension(Extensions, []), - NewServerTbsCert = ServerTbsCert#'TBSCertificate'{extensions = NewExtensions}, - - ct:log("Extensions ~p~n, NewExtensions: ~p~n", [Extensions, NewExtensions]), - - TbsDer = public_key:pkix_encode('TBSCertificate', NewServerTbsCert, plain), - Sig = public_key:sign(TbsDer, md5, Key), - NewServerCert = ServerCert#'Certificate'{tbsCertificate = NewServerTbsCert, signature = Sig}, - NewDerCert = public_key:pkix_encode('Certificate', NewServerCert, plain), - ssl_test_lib:der_to_pem(NewCertFile, [{'Certificate', NewDerCert, not_encrypted}]), - NewServerOpts = [{certfile, NewCertFile} | proplists:delete(certfile, ServerOpts)], - +no_authority_key_identifier_keyEncipherment() -> + [{doc, "Test cert with keyEncipherment key_usage an no" + " authorityKeyIdentifier extension, but are present in trusted certs db."}]. + +no_authority_key_identifier_keyEncipherment(Config) when is_list(Config) -> + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{server_peer_opts, + [{extensions, [{auth_key_id, undefined}, + {key_usage, [digitalSignature, + keyEncipherment]}] + }]}, + {client_peer_opts, + [{extensions, [{auth_key_id, undefined}] + }]}], Config, "_peer_keyEncipherment"), + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, {mfa, {ssl_test_lib, send_recv_result_active, []}}, - {options, [{active, true} | NewServerOpts]}]), + {options, [{active, true} | ServerOpts]}]), Port = ssl_test_lib:inet_port(Server), Client = ssl_test_lib:start_client([{node, ClientNode}, {port, Port}, {host, Hostname}, @@ -1028,14 +847,6 @@ no_authority_key_identifier_and_nonstandard_encoding(Config) when is_list(Config ssl_test_lib:close(Server), ssl_test_lib:close(Client). -replace_key_usage_extension([], Acc) -> - lists:reverse(Acc); -replace_key_usage_extension([#'Extension'{extnID = ?'id-ce-keyUsage'} = E | Rest], Acc) -> - %% A nonstandard DER encoding of [digitalSignature, keyEncipherment] - Val = <<3, 2, 0, 16#A0>>, - replace_key_usage_extension(Rest, [E#'Extension'{extnValue = Val} | Acc]); -replace_key_usage_extension([Head | Rest], Acc) -> - replace_key_usage_extension(Rest, [Head | Acc]). %%-------------------------------------------------------------------- @@ -1043,16 +854,16 @@ invalid_signature_server() -> [{doc,"Test client with invalid signature"}]. invalid_signature_server(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_opts, Config), PrivDir = proplists:get_value(priv_dir, Config), - KeyFile = filename:join(PrivDir, "server/key.pem"), + KeyFile = proplists:get_value(keyfile, ServerOpts), [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), ServerCertFile = proplists:get_value(certfile, ServerOpts), - NewServerCertFile = filename:join(PrivDir, "server/invalid_cert.pem"), + NewServerCertFile = filename:join(PrivDir, "server_invalid_cert.pem"), [{'Certificate', ServerDerCert, _}] = ssl_test_lib:pem_to_der(ServerCertFile), ServerOTPCert = public_key:pkix_decode_cert(ServerDerCert, otp), ServerOTPTbsCert = ServerOTPCert#'OTPCertificate'.tbsCertificate, @@ -1071,8 +882,8 @@ invalid_signature_server(Config) when is_list(Config) -> {from, self()}, {options, [{verify, verify_peer} | ClientOpts]}]), - tcp_delivery_workaround(Server, {error, {tls_alert, "bad certificate"}}, - Client, {error, {tls_alert, "bad certificate"}}). + tcp_delivery_workaround(Server, {error, {tls_alert, "unknown ca"}}, + Client, {error, {tls_alert, "unknown ca"}}). %%-------------------------------------------------------------------- @@ -1080,16 +891,16 @@ invalid_signature_client() -> [{doc,"Test server with invalid signature"}]. invalid_signature_client(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ClientOpts = ssl_test_lib:ssl_options(client_rsa_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_opts, Config), PrivDir = proplists:get_value(priv_dir, Config), - KeyFile = filename:join(PrivDir, "client/key.pem"), + KeyFile = proplists:get_value(keyfile, ClientOpts), [KeyEntry] = ssl_test_lib:pem_to_der(KeyFile), Key = ssl_test_lib:public_key(public_key:pem_entry_decode(KeyEntry)), ClientCertFile = proplists:get_value(certfile, ClientOpts), - NewClientCertFile = filename:join(PrivDir, "client/invalid_cert.pem"), + NewClientCertFile = filename:join(PrivDir, "client_invalid_cert.pem"), [{'Certificate', ClientDerCert, _}] = ssl_test_lib:pem_to_der(ClientCertFile), ClientOTPCert = public_key:pkix_decode_cert(ClientDerCert, otp), ClientOTPTbsCert = ClientOTPCert#'OTPCertificate'.tbsCertificate, @@ -1108,8 +919,8 @@ invalid_signature_client(Config) when is_list(Config) -> {from, self()}, {options, NewClientOpts}]), - tcp_delivery_workaround(Server, {error, {tls_alert, "bad certificate"}}, - Client, {error, {tls_alert, "bad certificate"}}). + tcp_delivery_workaround(Server, {error, {tls_alert, "unknown ca"}}, + Client, {error, {tls_alert, "unknown ca"}}). %%-------------------------------------------------------------------- @@ -1118,8 +929,14 @@ client_with_cert_cipher_suites_handshake() -> [{doc, "Test that client with a certificate without keyEncipherment usage " " extension can connect to a server with restricted cipher suites "}]. client_with_cert_cipher_suites_handshake(Config) when is_list(Config) -> - ClientOpts = ssl_test_lib:ssl_options(client_verification_opts_digital_signature_only, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + {ClientOpts0, ServerOpts0} = ssl_test_lib:make_rsa_cert_chains([{client_peer_opts, + [{extensions, + [{key_usage, [digitalSignature]}] + }]}], Config, "_sign_only_extensions"), + + + ClientOpts = ssl_test_lib:ssl_options(ClientOpts0, Config), + ServerOpts = ssl_test_lib:ssl_options(ServerOpts0, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, @@ -1148,7 +965,7 @@ client_with_cert_cipher_suites_handshake(Config) when is_list(Config) -> server_verify_no_cacerts() -> [{doc,"Test server must have cacerts if it wants to verify client"}]. server_verify_no_cacerts(Config) when is_list(Config) -> - ServerOpts = proplists:delete(cacertfile, ssl_test_lib:ssl_options(server_opts, Config)), + ServerOpts = proplists:delete(cacertfile, ssl_test_lib:ssl_options(server_rsa_opts, Config)), {_, ServerNode, _} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server_error([{node, ServerNode}, {port, 0}, {from, self()}, @@ -1163,7 +980,7 @@ unknown_server_ca_fail() -> [{doc,"Test that the client fails if the ca is unknown in verify_peer mode"}]. unknown_server_ca_fail(Config) when is_list(Config) -> ClientOpts = ssl_test_lib:ssl_options(empty_client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server_error([{node, ServerNode}, {port, 0}, {from, self()}, @@ -1207,7 +1024,7 @@ unknown_server_ca_accept_verify_none() -> [{doc,"Test that the client succeds if the ca is unknown in verify_none mode"}]. unknown_server_ca_accept_verify_none(Config) when is_list(Config) -> ClientOpts = ssl_test_lib:ssl_options(empty_client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, @@ -1232,7 +1049,7 @@ unknown_server_ca_accept_verify_peer() -> " with a verify_fun that accepts the unknown ca error"}]. unknown_server_ca_accept_verify_peer(Config) when is_list(Config) -> ClientOpts = ssl_test_lib:ssl_options(empty_client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, @@ -1271,7 +1088,7 @@ unknown_server_ca_accept_backwardscompatibility() -> [{doc,"Test that old style verify_funs will work"}]. unknown_server_ca_accept_backwardscompatibility(Config) when is_list(Config) -> ClientOpts = ssl_test_lib:ssl_options(empty_client_opts, Config), - ServerOpts = ssl_test_lib:ssl_options(server_verification_opts, Config), + ServerOpts = ssl_test_lib:ssl_options(server_rsa_verify_opts, Config), {ClientNode, ServerNode, Hostname} = ssl_test_lib:run_where(Config), Server = ssl_test_lib:start_server([{node, ServerNode}, {port, 0}, {from, self()}, diff --git a/lib/ssl/test/ssl_test_lib.erl b/lib/ssl/test/ssl_test_lib.erl index 302b5178a5..b8fd5dc975 100644 --- a/lib/ssl/test/ssl_test_lib.erl +++ b/lib/ssl/test/ssl_test_lib.erl @@ -485,6 +485,18 @@ make_dsa_cert(Config) -> {certfile, ClientCertFile}, {keyfile, ClientKeyFile}]} | Config]. +make_rsa_cert_chains(ChainConf, Config, Suffix) -> + CryptoSupport = crypto:supports(), + KeyGenSpec = key_gen_info(rsa, rsa), + ClientFileBase = filename:join([proplists:get_value(priv_dir, Config), "rsa" ++ Suffix]), + ServerFileBase = filename:join([proplists:get_value(priv_dir, Config), "rsa" ++ Suffix]), + GenCertData = x509_test:gen_test_certs([{digest, appropriate_sha(CryptoSupport)} | KeyGenSpec] ++ ChainConf), + [{server_config, ServerConf}, + {client_config, ClientConf}] = + x509_test:gen_pem_config_files(GenCertData, ClientFileBase, ServerFileBase), + {[{verify, verify_peer} | ClientConf], + [{reuseaddr, true}, {verify, verify_peer} | ServerConf] + }. make_ec_cert_chains(ClientChainType, ServerChainType, Config) -> CryptoSupport = crypto:supports(), @@ -524,6 +536,11 @@ key_gen_spec(Role, ecdhe_rsa) -> [{list_to_atom(Role ++ "_key_gen"), hardcode_rsa_key(1)}, {list_to_atom(Role ++ "_key_gen_chain"), [hardcode_rsa_key(2), hardcode_rsa_key(3)]} + ]; +key_gen_spec(Role, rsa) -> + [{list_to_atom(Role ++ "_key_gen"), hardcode_rsa_key(1)}, + {list_to_atom(Role ++ "_key_gen_chain"), [hardcode_rsa_key(2), + hardcode_rsa_key(3)]} ]. make_ecdsa_cert(Config) -> CryptoSupport = crypto:supports(), @@ -571,7 +588,8 @@ make_rsa_cert(Config) -> {server_rsa_verify_opts, [{ssl_imp, new}, {reuseaddr, true}, {verify, verify_peer} | ServerConf]}, - {client_rsa_opts, ClientConf} + {client_rsa_opts, ClientConf}, + {client_rsa_verify_opts, [{verify, verify_peer} |ClientConf]} | Config]; false -> Config @@ -935,9 +953,9 @@ available_suites(Version) -> rsa_non_signed_suites(Version) -> lists:filter(fun({rsa, _, _}) -> - true; + false; (_) -> - false + true end, available_suites(Version)). @@ -1398,10 +1416,13 @@ do_supports_ssl_tls_version(Port) -> true end. -ssl_options(Option, Config) -> +ssl_options(Option, Config) when is_atom(Option) -> ProtocolOpts = proplists:get_value(protocol_opts, Config, []), Opts = proplists:get_value(Option, Config, []), - Opts ++ ProtocolOpts. + Opts ++ ProtocolOpts; +ssl_options(Options, Config) -> + ProtocolOpts = proplists:get_value(protocol_opts, Config, []), + Options ++ ProtocolOpts. protocol_version(Config) -> protocol_version(Config, atom). diff --git a/lib/ssl/test/x509_test.erl b/lib/ssl/test/x509_test.erl index 13f8dfdaa9..c36e96013b 100644 --- a/lib/ssl/test/x509_test.erl +++ b/lib/ssl/test/x509_test.erl @@ -96,7 +96,7 @@ gen_pem_config_files(GenCertData, ClientBase, ServerBase) -> public_key:generate_key(KeyGen) end. - root_cert(Role, PrivKey, Opts) -> +root_cert(Role, PrivKey, Opts) -> TBS = cert_template(), Issuer = issuer("root", Role, " ROOT CA"), OTPTBS = TBS#'OTPTBSCertificate'{ @@ -105,7 +105,7 @@ gen_pem_config_files(GenCertData, ClientBase, ServerBase) -> validity = validity(Opts), subject = Issuer, subjectPublicKeyInfo = public_key(PrivKey), - extensions = extensions(Opts) + extensions = extensions(ca, Opts) }, public_key:pkix_sign(OTPTBS, PrivKey). @@ -175,32 +175,31 @@ validity(Opts) -> #'Validity'{notBefore={generalTime, Format(DefFrom)}, notAfter ={generalTime, Format(DefTo)}}. -extensions(Opts) -> - case proplists:get_value(extensions, Opts, []) of - false -> - asn1_NOVALUE; - Exts -> - lists:flatten([extension(Ext) || Ext <- default_extensions(Exts)]) - end. +extensions(Type, Opts) -> + Exts = proplists:get_value(extensions, Opts, []), + lists:flatten([extension(Ext) || Ext <- default_extensions(Type, Exts)]). + +%% Common extension: name_constraints, policy_constraints, ext_key_usage, inhibit_any, +%% auth_key_id, subject_key_id, policy_mapping, + +default_extensions(ca, Exts) -> + Def = [{key_usage, [keyCertSign, cRLSign]}, + {basic_constraints, default}], + add_default_extensions(Def, Exts); -default_extensions(Exts) -> - Def = [{key_usage,undefined}, - {subject_altname, undefined}, - {issuer_altname, undefined}, - {basic_constraints, default}, - {name_constraints, undefined}, - {policy_constraints, undefined}, - {ext_key_usage, undefined}, - {inhibit_any, undefined}, - {auth_key_id, undefined}, - {subject_key_id, undefined}, - {policy_mapping, undefined}], +default_extensions(peer, Exts) -> + Def = [{key_usage, [digitalSignature, keyAgreement]}], + add_default_extensions(Def, Exts). + +add_default_extensions(Def, Exts) -> Filter = fun({Key, _}, D) -> - lists:keydelete(Key, 1, D) + lists:keydelete(Key, 1, D); + ({Key, _, _}, D) -> + lists:keydelete(Key, 1, D) end, Exts ++ lists:foldl(Filter, Def, Exts). - -extension({_, undefined}) -> + +extension({_, undefined}) -> []; extension({basic_constraints, Data}) -> case Data of @@ -218,6 +217,17 @@ extension({basic_constraints, Data}) -> #'Extension'{extnID = ?'id-ce-basicConstraints', extnValue = Data} end; +extension({auth_key_id, {Oid, Issuer, SNr}}) -> + #'Extension'{extnID = ?'id-ce-authorityKeyIdentifier', + extnValue = #'AuthorityKeyIdentifier'{ + keyIdentifier = Oid, + authorityCertIssuer = Issuer, + authorityCertSerialNumber = SNr}, + critical = false}; +extension({key_usage, Value}) -> + #'Extension'{extnID = ?'id-ce-keyUsage', + extnValue = Value, + critical = false}; extension({Id, Data, Critical}) -> #'Extension'{extnID = Id, extnValue = Data, critical = Critical}. @@ -277,24 +287,31 @@ cert_chain(Role, Root, RootKey, Opts, Keys) -> cert_chain(Role, Root, RootKey, Opts, Keys, 0, []). cert_chain(Role, IssuerCert, IssuerKey, Opts, [Key], _, Acc) -> + PeerOpts = list_to_atom(atom_to_list(Role) ++ "_peer_opts"), Cert = cert(Role, public_key:pkix_decode_cert(IssuerCert, otp), - IssuerKey, Key, "admin", " Peer cert", Opts), + IssuerKey, Key, "admin", " Peer cert", Opts, PeerOpts, peer), [{Cert, Key}, {IssuerCert, IssuerKey} | Acc]; cert_chain(Role, IssuerCert, IssuerKey, Opts, [Key | Keys], N, Acc) -> + CAOpts = list_to_atom(atom_to_list(Role) ++ "_ca_" ++ integer_to_list(N)), Cert = cert(Role, public_key:pkix_decode_cert(IssuerCert, otp), IssuerKey, Key, "webadmin", - " Intermidiate CA " ++ integer_to_list(N), Opts), + " Intermidiate CA " ++ integer_to_list(N), Opts, CAOpts, ca), cert_chain(Role, Cert, Key, Opts, Keys, N+1, [{IssuerCert, IssuerKey} | Acc]). -cert(Role, #'OTPCertificate'{tbsCertificate = #'OTPTBSCertificate'{subject = Issuer}}, - PrivKey, Key, Contact, Name, Opts) -> +cert(Role, #'OTPCertificate'{tbsCertificate = #'OTPTBSCertificate'{subject = Issuer, + serialNumber = SNr + }}, + PrivKey, Key, Contact, Name, Opts, CertOptsName, Type) -> + CertOpts = proplists:get_value(CertOptsName, Opts, []), TBS = cert_template(), OTPTBS = TBS#'OTPTBSCertificate'{ signature = sign_algorithm(PrivKey, Opts), issuer = Issuer, - validity = validity(Opts), + validity = validity(CertOpts), subject = subject(Contact, atom_to_list(Role) ++ Name), subjectPublicKeyInfo = public_key(Key), - extensions = extensions(Opts) + extensions = extensions(Type, + add_default_extensions([{auth_key_id, {auth_key_oid(Role), Issuer, SNr}}], + CertOpts)) }, public_key:pkix_sign(OTPTBS, PrivKey). @@ -319,3 +336,8 @@ default_key_gen() -> [{namedCurve, hd(tls_v1:ecc_curves(0))}, {namedCurve, hd(tls_v1:ecc_curves(0))}] end. + +auth_key_oid(server) -> + ?'id-kp-serverAuth'; +auth_key_oid(client) -> + ?'id-kp-clientAuth'. diff --git a/lib/stdlib/doc/src/rand.xml b/lib/stdlib/doc/src/rand.xml index 2ddf3021ac..e06d7e467d 100644 --- a/lib/stdlib/doc/src/rand.xml +++ b/lib/stdlib/doc/src/rand.xml @@ -4,7 +4,7 @@ <erlref> <header> <copyright> - <year>2015</year><year>2016</year> + <year>2015</year><year>2017</year> <holder>Ericsson AB. All Rights Reserved.</holder> </copyright> <legalnotice> @@ -50,26 +50,73 @@ <p>The following algorithms are provided:</p> <taglist> - <tag><c>exsplus</c></tag> + <tag><c>exrop</c></tag> <item> - <p>Xorshift116+, 58 bits precision and period of 2^116-1</p> + <p>Xoroshiro116+, 58 bits precision and period of 2^116-1</p> <p>Jump function: equivalent to 2^64 calls</p> </item> - <tag><c>exs64</c></tag> - <item> - <p>Xorshift64*, 64 bits precision and a period of 2^64-1</p> - <p>Jump function: not available</p> - </item> - <tag><c>exs1024</c></tag> + <tag><c>exs1024s</c></tag> <item> <p>Xorshift1024*, 64 bits precision and a period of 2^1024-1</p> <p>Jump function: equivalent to 2^512 calls</p> </item> + <tag><c>exsp</c></tag> + <item> + <p>Xorshift116+, 58 bits precision and period of 2^116-1</p> + <p>Jump function: equivalent to 2^64 calls</p> + <p> + This is a corrected version of the previous default algorithm, + that now has been superseeded by Xoroshiro116+ (<c>exrop</c>). + Since there is no native 58 bit rotate instruction this + algorithm executes a little (say < 15%) faster than <c>exrop</c>. + See the + <url href="http://xorshift.di.unimi.it">algorithms' homepage</url>. + </p> + </item> </taglist> - <p>The default algorithm is <c>exsplus</c>. If a specific algorithm is + <p> + The default algorithm is <c>exrop</c> (Xoroshiro116+). + If a specific algorithm is required, ensure to always use <seealso marker="#seed-1"> - <c>seed/1</c></seealso> to initialize the state.</p> + <c>seed/1</c></seealso> to initialize the state. + </p> + + <p> + Undocumented (old) algorithms are deprecated but still implemented + so old code relying on them will produce + the same pseudo random sequences as before. + </p> + + <note> + <p> + There were a number of problems in the implementation + of the now undocumented algorithms, which is why + they are deprecated. The new algorithms are a bit slower + but do not have these problems: + </p> + <p> + Uniform integer ranges had a skew in the probability distribution + that was not noticable for small ranges but for large ranges + less than the generator's precision the probability to produce + a low number could be twice the probability for a high. + </p> + <p> + Uniform integer ranges larger than or equal to the generator's + precision used a floating point fallback that only calculated + with 52 bits which is smaller than the requested range + and therefore were not all numbers in the requested range + even possible to produce. + </p> + <p> + Uniform floats had a non-uniform density so small values + i.e less than 0.5 had got smaller intervals decreasing + as the generated value approached 0.0 although still uniformly + distributed for sufficiently large subranges. The new algorithms + produces uniformly distributed floats on the form N * 2.0^(-53) + hence equally spaced. + </p> + </note> <p>Every time a random number is requested, a state is used to calculate it and a new state is produced. The state can either be @@ -99,19 +146,19 @@ R1 = rand:uniform(),</pre> <p>Use a specified algorithm:</p> <pre> -_ = rand:seed(exs1024), +_ = rand:seed(exs1024s), R2 = rand:uniform(),</pre> <p>Use a specified algorithm with a constant seed:</p> <pre> -_ = rand:seed(exs1024, {123, 123534, 345345}), +_ = rand:seed(exs1024s, {123, 123534, 345345}), R3 = rand:uniform(),</pre> <p>Use the functional API with a non-constant seed:</p> <pre> -S0 = rand:seed_s(exsplus), +S0 = rand:seed_s(exrop), {R4, S1} = rand:uniform_s(S0),</pre> <p>Create a standard normal deviate:</p> @@ -119,6 +166,11 @@ S0 = rand:seed_s(exsplus), <pre> {SND0, S2} = rand:normal_s(S1),</pre> + <p>Create a normal deviate with mean -3 and variance 0.5:</p> + + <pre> +{ND0, S3} = rand:normal_s(-3, 0.5, S2),</pre> + <note> <p>The builtin random number generator algorithms are not cryptographically strong. If a cryptographically strong @@ -127,6 +179,39 @@ S0 = rand:seed_s(exsplus), </p> </note> + <p> + For all these generators the lowest bit(s) has got + a slightly less random behaviour than all other bits. + 1 bit for <c>exrop</c> (and <c>exsp</c>), + and 3 bits for <c>exs1024s</c>. + See for example the explanation in the + <url href="http://xoroshiro.di.unimi.it/xoroshiro128plus.c"> + Xoroshiro128+ + </url> + generator source code: + </p> + <pre> +Beside passing BigCrush, this generator passes the PractRand test suite +up to (and included) 16TB, with the exception of binary rank tests, +which fail due to the lowest bit being an LFSR; all other bits pass all +tests. We suggest to use a sign test to extract a random Boolean value.</pre> + <p> + If this is a problem; to generate a boolean + use something like this: + </p> + <pre>(rand:uniform(16) > 8)</pre> + <p> + And for a general range, with <c>N = 1</c> for <c>exrop</c>, + and <c>N = 3</c> for <c>exs1024s</c>: + </p> + <pre>(((rand:uniform(Range bsl N) - 1) bsr N) + 1)</pre> + <p> + The floating point generating functions in this module + waste the lowest bits when converting from an integer + so they avoid this snag. + </p> + + </description> <datatypes> <datatype> @@ -142,6 +227,18 @@ S0 = rand:seed_s(exsplus), <name name="alg_state"/> </datatype> <datatype> + <name name="state"/> + <desc><p>Algorithm-dependent state.</p></desc> + </datatype> + <datatype> + <name name="export_state"/> + <desc> + <p> + Algorithm-dependent state that can be printed or saved to file. + </p> + </desc> + </datatype> + <datatype> <name name="exs64_state"/> <desc><p>Algorithm specific internal state</p></desc> </datatype> @@ -154,16 +251,8 @@ S0 = rand:seed_s(exsplus), <desc><p>Algorithm specific internal state</p></desc> </datatype> <datatype> - <name name="state"/> - <desc><p>Algorithm-dependent state.</p></desc> - </datatype> - <datatype> - <name name="export_state"/> - <desc> - <p> - Algorithm-dependent state that can be printed or saved to file. - </p> - </desc> + <name name="exrop_state"/> + <desc><p>Algorithm specific internal state</p></desc> </datatype> </datatypes> @@ -224,6 +313,15 @@ S0 = rand:seed_s(exsplus), </func> <func> + <name name="normal" arity="2"/> + <fsummary>Return a normal distributed random float.</fsummary> + <desc> + <p>Returns a normal N(Mean, Variance) deviate float + and updates the state in the process dictionary.</p> + </desc> + </func> + + <func> <name name="normal_s" arity="1"/> <fsummary>Return a standard normal distributed random float.</fsummary> <desc> @@ -234,6 +332,15 @@ S0 = rand:seed_s(exsplus), </func> <func> + <name name="normal_s" arity="3"/> + <fsummary>Return a normal distributed random float.</fsummary> + <desc> + <p>Returns, for a specified state, a normal N(Mean, Variance) + deviate float and a new state.</p> + </desc> + </func> + + <func> <name name="seed" arity="1"/> <fsummary>Seed random number generator.</fsummary> <desc> diff --git a/lib/stdlib/src/rand.erl b/lib/stdlib/src/rand.erl index dfd102f9ef..ab9731180f 100644 --- a/lib/stdlib/src/rand.erl +++ b/lib/stdlib/src/rand.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2015-2016. All Rights Reserved. +%% Copyright Ericsson AB 2015-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -20,6 +20,9 @@ %% ===================================================================== %% Multiple PRNG module for Erlang/OTP %% Copyright (c) 2015-2016 Kenji Rikitake +%% +%% exrop (xoroshiro116+) added and statistical distribution +%% improvements by the Erlang/OTP team 2017 %% ===================================================================== -module(rand). @@ -28,48 +31,179 @@ export_seed/0, export_seed_s/1, uniform/0, uniform/1, uniform_s/1, uniform_s/2, jump/0, jump/1, - normal/0, normal_s/1 + normal/0, normal/2, normal_s/1, normal_s/3 ]). -compile({inline, [exs64_next/1, exsplus_next/1, - exsplus_jump/1, exs1024_next/1, exs1024_calc/2, - exs1024_jump/1, + exrop_next/1, exrop_next_s/2, get_52/1, normal_kiwi/1]}). --define(DEFAULT_ALG_HANDLER, exsplus). +-define(DEFAULT_ALG_HANDLER, exrop). -define(SEED_DICT, rand_seed). %% ===================================================================== +%% Bit fiddling macros +%% ===================================================================== + +-define(BIT(Bits), (1 bsl (Bits))). +-define(MASK(Bits), (?BIT(Bits) - 1)). +-define(MASK(Bits, X), ((X) band ?MASK(Bits))). +-define( + BSL(Bits, X, N), + %% N is evaluated 2 times + (?MASK((Bits)-(N), (X)) bsl (N))). +-define( + ROTL(Bits, X, N), + %% Bits is evaluated 2 times + %% X is evaluated 2 times + %% N i evaluated 3 times + (?BSL((Bits), (X), (N)) bor ((X) bsr ((Bits)-(N))))). + +%%-define(TWO_POW_MINUS53, (math:pow(2, -53))). +-define(TWO_POW_MINUS53, 1.11022302462515657e-16). + +%% ===================================================================== %% Types %% ===================================================================== +-type uint64() :: 0..?MASK(64). +-type uint58() :: 0..?MASK(58). + %% This depends on the algorithm handler function -type alg_state() :: - exs64_state() | exsplus_state() | exs1024_state() | term(). + exs64_state() | exsplus_state() | exs1024_state() | + exrop_state() | term(). -%% This is the algorithm handler function within this module +%% This is the algorithm handling definition within this module, +%% and the type to use for plugins. +%% +%% The 'type' field must be recognized by the module that implements +%% the algorithm, to interpret an exported state. +%% +%% The 'bits' field indicates how many bits the integer +%% returned from 'next' has got, i.e 'next' shall return +%% an random integer in the range 0..(2^Bits - 1). +%% At least 53 bits is required for the floating point +%% producing fallbacks. This field is only used when +%% the 'uniform' or 'uniform_n' fields are not defined. +%% +%% The fields 'next', 'uniform' and 'uniform_n' +%% implement the algorithm. If 'uniform' or 'uinform_n' +%% is not present there is a fallback using 'next' and either +%% 'bits' or the deprecated 'max'. +%% -type alg_handler() :: #{type := alg(), - max := integer() | infinity, + bits => non_neg_integer(), + weak_low_bits => non_neg_integer(), + max => non_neg_integer(), % Deprecated next := - fun((alg_state()) -> {non_neg_integer(), alg_state()}), - uniform := - fun((state()) -> {float(), state()}), - uniform_n := - fun((pos_integer(), state()) -> {pos_integer(), state()}), - jump := - fun((state()) -> state())}. + fun ((alg_state()) -> {non_neg_integer(), alg_state()}), + uniform => + fun ((state()) -> {float(), state()}), + uniform_n => + fun ((pos_integer(), state()) -> {pos_integer(), state()}), + jump => + fun ((state()) -> state())}. %% Algorithm state -type state() :: {alg_handler(), alg_state()}. --type builtin_alg() :: exs64 | exsplus | exs1024. +-type builtin_alg() :: exs64 | exsplus | exsp | exs1024 | exs1024s | exrop. -type alg() :: builtin_alg() | atom(). -type export_state() :: {alg(), alg_state()}. -export_type( [builtin_alg/0, alg/0, alg_handler/0, alg_state/0, state/0, export_state/0]). --export_type([exs64_state/0, exsplus_state/0, exs1024_state/0]). +-export_type( + [exs64_state/0, exsplus_state/0, exs1024_state/0, exrop_state/0]). + +%% ===================================================================== +%% Range macro and helper +%% ===================================================================== + +-define( + uniform_range(Range, Alg, R, V, MaxMinusRange, I), + if + 0 =< (MaxMinusRange) -> + if + %% Really work saving in odd cases; + %% large ranges in particular + (V) < (Range) -> + {(V) + 1, {(Alg), (R)}}; + true -> + (I) = (V) rem (Range), + if + (V) - (I) =< (MaxMinusRange) -> + {(I) + 1, {(Alg), (R)}}; + true -> + %% V in the truncated top range + %% - try again + ?FUNCTION_NAME((Range), {(Alg), (R)}) + end + end; + true -> + uniform_range((Range), (Alg), (R), (V)) + end). + +%% For ranges larger than the algorithm bit size +uniform_range(Range, #{next:=Next, bits:=Bits} = Alg, R, V) -> + WeakLowBits = + case Alg of + #{weak_low_bits:=WLB} -> WLB; + #{} -> 0 + end, + %% Maybe waste the lowest bit(s) when shifting in new bits + Shift = Bits - WeakLowBits, + ShiftMask = bnot ?MASK(WeakLowBits), + RangeMinus1 = Range - 1, + if + (Range band RangeMinus1) =:= 0 -> % Power of 2 + %% Generate at least the number of bits for the range + {V1, R1, _} = + uniform_range( + Range bsr Bits, Next, R, V, ShiftMask, Shift, Bits), + {(V1 band RangeMinus1) + 1, {Alg, R1}}; + true -> + %% Generate a value with at least two bits more than the range + %% and try that for a fit, otherwise recurse + %% + %% Just one bit more should ensure that the generated + %% number range is at least twice the size of the requested + %% range, which would make the probability to draw a good + %% number better than 0.5. And repeating that until + %% success i guess would take 2 times statistically amortized. + %% But since the probability for fairly many attemtpts + %% is not that low, use two bits more than the range which + %% should make the probability to draw a bad number under 0.25, + %% which decreases the bad case probability a lot. + {V1, R1, B} = + uniform_range( + Range bsr (Bits - 2), Next, R, V, ShiftMask, Shift, Bits), + I = V1 rem Range, + if + (V1 - I) =< (1 bsl B) - Range -> + {I + 1, {Alg, R1}}; + true -> + %% V1 drawn from the truncated top range + %% - try again + {V2, R2} = Next(R1), + uniform_range(Range, Alg, R2, V2) + end + end. +%% +uniform_range(Range, Next, R, V, ShiftMask, Shift, B) -> + if + Range =< 1 -> + {V, R, B}; + true -> + {V1, R1} = Next(R), + %% Waste the lowest bit(s) when shifting in new bits + uniform_range( + Range bsr Shift, Next, R1, + ((V band ShiftMask) bsl Shift) bor V1, + ShiftMask, Shift, B + Shift) + end. %% ===================================================================== %% API @@ -156,7 +290,16 @@ uniform(N) -> -spec uniform_s(State :: state()) -> {X :: float(), NewState :: state()}. uniform_s(State = {#{uniform:=Uniform}, _}) -> - Uniform(State). + Uniform(State); +uniform_s({#{bits:=Bits, next:=Next} = Alg, R0}) -> + {V, R1} = Next(R0), + %% Produce floats on the form N * 2^(-53) + {(V bsr (Bits - 53)) * ?TWO_POW_MINUS53, {Alg, R1}}; +uniform_s({#{max:=Max, next:=Next} = Alg, R0}) -> + {V, R1} = Next(R0), + %% Old broken algorithm with non-uniform density + {V / (Max + 1), {Alg, R1}}. + %% uniform_s/2: given an integer N >= 1 and a state, uniform_s/2 %% uniform_s/2 returns a random integer X where 1 =< X =< N, @@ -164,13 +307,26 @@ uniform_s(State = {#{uniform:=Uniform}, _}) -> -spec uniform_s(N :: pos_integer(), State :: state()) -> {X :: pos_integer(), NewState :: state()}. -uniform_s(N, State = {#{uniform_n:=Uniform, max:=Max}, _}) - when 0 < N, N =< Max -> - Uniform(N, State); -uniform_s(N, State0 = {#{uniform:=Uniform}, _}) - when is_integer(N), 0 < N -> - {F, State} = Uniform(State0), - {trunc(F * N) + 1, State}. +uniform_s(N, State = {#{uniform_n:=UniformN}, _}) + when is_integer(N), 1 =< N -> + UniformN(N, State); +uniform_s(N, {#{bits:=Bits, next:=Next} = Alg, R0}) + when is_integer(N), 1 =< N -> + {V, R1} = Next(R0), + MaxMinusN = ?BIT(Bits) - N, + ?uniform_range(N, Alg, R1, V, MaxMinusN, I); +uniform_s(N, {#{max:=Max, next:=Next} = Alg, R0}) + when is_integer(N), 1 =< N -> + %% Old broken algorithm with skewed probability + %% and gap in ranges > Max + {V, R1} = Next(R0), + if + N =< Max -> + {(V rem N) + 1, {Alg, R1}}; + true -> + F = V / (Max + 1), + {trunc(F * N) + 1, {Alg, R1}} + end. %% jump/1: given a state, jump/1 %% returns a new state which is equivalent to that @@ -179,7 +335,10 @@ uniform_s(N, State0 = {#{uniform:=Uniform}, _}) -spec jump(state()) -> NewState :: state(). jump(State = {#{jump:=Jump}, _}) -> - Jump(State). + Jump(State); +jump({#{}, _}) -> + erlang:error(not_implemented). + %% jump/0: read the internal state and %% apply the jump function for the state as in jump/1 @@ -187,7 +346,6 @@ jump(State = {#{jump:=Jump}, _}) -> %% then returns the new value. -spec jump() -> NewState :: state(). - jump() -> seed_put(jump(seed_get())). @@ -200,6 +358,13 @@ normal() -> _ = seed_put(Seed), X. +%% normal/2: returns a random float with N(μ, σ²) normal distribution +%% updating the state in the process dictionary. + +-spec normal(Mean :: number(), Variance :: number()) -> float(). +normal(Mean, Variance) -> + Mean + (math:sqrt(Variance) * normal()). + %% normal_s/1: returns a random float with standard normal distribution %% The Ziggurat Method for generating random variables - Marsaglia and Tsang %% Paper and reference code: http://www.jstatsoft.org/v05/i08/ @@ -207,7 +372,7 @@ normal() -> -spec normal_s(State :: state()) -> {float(), NewState :: state()}. normal_s(State0) -> {Sign, R, State} = get_52(State0), - Idx = R band 16#FF, + Idx = ?MASK(8, R), Idx1 = Idx+1, {Ki, Wi} = normal_kiwi(Idx1), X = R * Wi, @@ -220,18 +385,15 @@ normal_s(State0) -> false -> normal_s(Idx, Sign, -X, State) end. -%% ===================================================================== -%% Internal functions +%% normal_s/3: returns a random float with normal N(μ, σ²) distribution --define(UINT21MASK, 16#00000000001fffff). --define(UINT32MASK, 16#00000000ffffffff). --define(UINT33MASK, 16#00000001ffffffff). --define(UINT39MASK, 16#0000007fffffffff). --define(UINT58MASK, 16#03ffffffffffffff). --define(UINT64MASK, 16#ffffffffffffffff). +-spec normal_s(Mean :: number(), Variance :: number(), state()) -> {float(), NewS :: state()}. +normal_s(Mean, Variance, State0) when Variance > 0 -> + {X, State} = normal_s(State0), + {Mean + (math:sqrt(Variance) * X), State}. --type uint64() :: 0..16#ffffffffffffffff. --type uint58() :: 0..16#03ffffffffffffff. +%% ===================================================================== +%% Internal functions -spec seed_put(state()) -> state(). seed_put(Seed) -> @@ -246,20 +408,30 @@ seed_get() -> %% Setup alg record mk_alg(exs64) -> - {#{type=>exs64, max=>?UINT64MASK, next=>fun exs64_next/1, - uniform=>fun exs64_uniform/1, uniform_n=>fun exs64_uniform/2, - jump=>fun exs64_jump/1}, + {#{type=>exs64, max=>?MASK(64), next=>fun exs64_next/1}, fun exs64_seed/1}; mk_alg(exsplus) -> - {#{type=>exsplus, max=>?UINT58MASK, next=>fun exsplus_next/1, - uniform=>fun exsplus_uniform/1, uniform_n=>fun exsplus_uniform/2, + {#{type=>exsplus, max=>?MASK(58), next=>fun exsplus_next/1, + jump=>fun exsplus_jump/1}, + fun exsplus_seed/1}; +mk_alg(exsp) -> + {#{type=>exsp, bits=>58, weak_low_bits=>1, next=>fun exsplus_next/1, + uniform=>fun exsp_uniform/1, uniform_n=>fun exsp_uniform/2, jump=>fun exsplus_jump/1}, fun exsplus_seed/1}; mk_alg(exs1024) -> - {#{type=>exs1024, max=>?UINT64MASK, next=>fun exs1024_next/1, - uniform=>fun exs1024_uniform/1, uniform_n=>fun exs1024_uniform/2, + {#{type=>exs1024, max=>?MASK(64), next=>fun exs1024_next/1, jump=>fun exs1024_jump/1}, - fun exs1024_seed/1}. + fun exs1024_seed/1}; +mk_alg(exs1024s) -> + {#{type=>exs1024s, bits=>64, weak_low_bits=>3, next=>fun exs1024_next/1, + jump=>fun exs1024_jump/1}, + fun exs1024_seed/1}; +mk_alg(exrop) -> + {#{type=>exrop, bits=>58, weak_low_bits=>1, next=>fun exrop_next/1, + uniform=>fun exrop_uniform/1, uniform_n=>fun exrop_uniform/2, + jump=>fun exrop_jump/1}, + fun exrop_seed/1}. %% ===================================================================== %% exs64 PRNG: Xorshift64* @@ -270,29 +442,18 @@ mk_alg(exs1024) -> -opaque exs64_state() :: uint64(). exs64_seed({A1, A2, A3}) -> - {V1, _} = exs64_next(((A1 band ?UINT32MASK) * 4294967197 + 1)), - {V2, _} = exs64_next(((A2 band ?UINT32MASK) * 4294967231 + 1)), - {V3, _} = exs64_next(((A3 band ?UINT32MASK) * 4294967279 + 1)), - ((V1 * V2 * V3) rem (?UINT64MASK - 1)) + 1. + {V1, _} = exs64_next((?MASK(32, A1) * 4294967197 + 1)), + {V2, _} = exs64_next((?MASK(32, A2) * 4294967231 + 1)), + {V3, _} = exs64_next((?MASK(32, A3) * 4294967279 + 1)), + ((V1 * V2 * V3) rem (?MASK(64) - 1)) + 1. %% Advance xorshift64* state for one step and generate 64bit unsigned integer -spec exs64_next(exs64_state()) -> {uint64(), exs64_state()}. exs64_next(R) -> R1 = R bxor (R bsr 12), - R2 = R1 bxor ((R1 band ?UINT39MASK) bsl 25), + R2 = R1 bxor ?BSL(64, R1, 25), R3 = R2 bxor (R2 bsr 27), - {(R3 * 2685821657736338717) band ?UINT64MASK, R3}. - -exs64_uniform({Alg, R0}) -> - {V, R1} = exs64_next(R0), - {V / 18446744073709551616, {Alg, R1}}. - -exs64_uniform(Max, {Alg, R}) -> - {V, R1} = exs64_next(R), - {(V rem Max) + 1, {Alg, R1}}. - -exs64_jump(_) -> - erlang:error(not_implemented). + {?MASK(64, R3 * 2685821657736338717), R3}. %% ===================================================================== %% exsplus PRNG: Xorshift116+ @@ -307,10 +468,12 @@ exs64_jump(_) -> -dialyzer({no_improper_lists, exsplus_seed/1}). exsplus_seed({A1, A2, A3}) -> - {_, R1} = exsplus_next([(((A1 * 4294967197) + 1) band ?UINT58MASK)| - (((A2 * 4294967231) + 1) band ?UINT58MASK)]), - {_, R2} = exsplus_next([(((A3 * 4294967279) + 1) band ?UINT58MASK)| - tl(R1)]), + {_, R1} = exsplus_next( + [?MASK(58, (A1 * 4294967197) + 1)| + ?MASK(58, (A2 * 4294967231) + 1)]), + {_, R2} = exsplus_next( + [?MASK(58, (A3 * 4294967279) + 1)| + tl(R1)]), R2. -dialyzer({no_improper_lists, exsplus_next/1}). @@ -319,17 +482,22 @@ exsplus_seed({A1, A2, A3}) -> -spec exsplus_next(exsplus_state()) -> {uint58(), exsplus_state()}. exsplus_next([S1|S0]) -> %% Note: members s0 and s1 are swapped here - S11 = (S1 bxor (S1 bsl 24)) band ?UINT58MASK, + S11 = S1 bxor ?BSL(58, S1, 24), S12 = S11 bxor S0 bxor (S11 bsr 11) bxor (S0 bsr 41), - {(S0 + S12) band ?UINT58MASK, [S0|S12]}. + {?MASK(58, S0 + S12), [S0|S12]}. + -exsplus_uniform({Alg, R0}) -> +exsp_uniform({Alg, R0}) -> {I, R1} = exsplus_next(R0), - {I / (?UINT58MASK+1), {Alg, R1}}. + %% Waste the lowest bit since it is of lower + %% randomness quality than the others + {(I bsr (58-53)) * ?TWO_POW_MINUS53, {Alg, R1}}. -exsplus_uniform(Max, {Alg, R}) -> +exsp_uniform(Range, {Alg, R}) -> {V, R1} = exsplus_next(R), - {(V rem Max) + 1, {Alg, R1}}. + MaxMinusRange = ?BIT(58) - Range, + ?uniform_range(Range, Alg, R1, V, MaxMinusRange, I). + %% This is the jump function for the exsplus generator, equivalent %% to 2^64 calls to next/1; it can be used to generate 2^52 @@ -357,7 +525,7 @@ exsplus_jump(S, AS, _, 0) -> {S, AS}; exsplus_jump(S, [AS0|AS1], J, N) -> {_, NS} = exsplus_next(S), - case (J band 1) of + case ?MASK(1, J) of 1 -> [S0|S1] = S, exsplus_jump(NS, [(AS0 bxor S0)|(AS1 bxor S1)], J bsr 1, N-1); @@ -374,9 +542,9 @@ exsplus_jump(S, [AS0|AS1], J, N) -> -opaque exs1024_state() :: {list(uint64()), list(uint64())}. exs1024_seed({A1, A2, A3}) -> - B1 = (((A1 band ?UINT21MASK) + 1) * 2097131) band ?UINT21MASK, - B2 = (((A2 band ?UINT21MASK) + 1) * 2097133) band ?UINT21MASK, - B3 = (((A3 band ?UINT21MASK) + 1) * 2097143) band ?UINT21MASK, + B1 = ?MASK(21, (?MASK(21, A1) + 1) * 2097131), + B2 = ?MASK(21, (?MASK(21, A2) + 1) * 2097133), + B3 = ?MASK(21, (?MASK(21, A3) + 1) * 2097143), {exs1024_gen1024((B1 bsl 43) bor (B2 bsl 22) bor (B3 bsl 1) bor 1), []}. @@ -399,11 +567,11 @@ exs1024_gen1024(N, R, L) -> %% X: random number output -spec exs1024_calc(uint64(), uint64()) -> {uint64(), uint64()}. exs1024_calc(S0, S1) -> - S11 = S1 bxor ((S1 band ?UINT33MASK) bsl 31), + S11 = S1 bxor ?BSL(64, S1, 31), S12 = S11 bxor (S11 bsr 11), S01 = S0 bxor (S0 bsr 30), NS1 = S01 bxor S12, - {(NS1 * 1181783497276652981) band ?UINT64MASK, NS1}. + {?MASK(64, NS1 * 1181783497276652981), NS1}. %% Advance xorshift1024* state for one step and generate 64bit unsigned integer -spec exs1024_next(exs1024_state()) -> {uint64(), exs1024_state()}. @@ -414,13 +582,6 @@ exs1024_next({[H], RL}) -> NL = [H|lists:reverse(RL)], exs1024_next({NL, []}). -exs1024_uniform({Alg, R0}) -> - {V, R1} = exs1024_next(R0), - {V / 18446744073709551616, {Alg, R1}}. - -exs1024_uniform(Max, {Alg, R}) -> - {V, R1} = exs1024_next(R), - {(V rem Max) + 1, {Alg, R1}}. %% This is the jump function for the exs1024 generator, equivalent %% to 2^512 calls to next(); it can be used to generate 2^512 @@ -467,7 +628,7 @@ exs1024_jump(S, AS, [H|T], _, 0, TN) -> exs1024_jump(S, AS, T, H, ?JUMPELEMLEN, TN); exs1024_jump({L, RL}, AS, JL, J, N, TN) -> {_, NS} = exs1024_next({L, RL}), - case (J band 1) of + case ?MASK(1, J) of 1 -> AS2 = lists:zipwith(fun(X, Y) -> X bxor Y end, AS, L ++ lists:reverse(RL)), @@ -477,15 +638,149 @@ exs1024_jump({L, RL}, AS, JL, J, N, TN) -> end. %% ===================================================================== +%% exrop PRNG: Xoroshiro116+ +%% +%% Reference URL: http://xorshift.di.unimi.it/ +%% +%% 58 bits fits into an immediate on 64bits Erlang and is thus much faster. +%% In fact, an immediate number is 60 bits signed in Erlang so you can +%% add two positive 58 bit numbers and get a 59 bit number that still is +%% a positive immediate, which is a property we utilize here... +%% +%% Modification of the original Xororhiro128+ algorithm to 116 bits +%% by Sebastiano Vigna. A lot of thanks for his help and work. +%% ===================================================================== +%% (a, b, c) = (24, 2, 35) +%% JUMP Polynomial = 0x9863200f83fcd4a11293241fcb12a (116 bit) +%% +%% From http://xoroshiro.di.unimi.it/xoroshiro116plus.c: +%% --------------------------------------------------------------------- +%% /* Written in 2017 by Sebastiano Vigna ([email protected]). +%% +%% To the extent possible under law, the author has dedicated all copyright +%% and related and neighboring rights to this software to the public domain +%% worldwide. This software is distributed without any warranty. +%% +%% See <http://creativecommons.org/publicdomain/zero/1.0/>. */ +%% +%% #include <stdint.h> +%% +%% #define UINT58MASK (uint64_t)((UINT64_C(1) << 58) - 1) +%% +%% uint64_t s[2]; +%% +%% static inline uint64_t rotl58(const uint64_t x, int k) { +%% return (x << k) & UINT58MASK | (x >> (58 - k)); +%% } +%% +%% uint64_t next(void) { +%% uint64_t s1 = s[1]; +%% const uint64_t s0 = s[0]; +%% const uint64_t result = (s0 + s1) & UINT58MASK; +%% +%% s1 ^= s0; +%% s[0] = rotl58(s0, 24) ^ s1 ^ ((s1 << 2) & UINT58MASK); // a, b +%% s[1] = rotl58(s1, 35); // c +%% return result; +%% } +%% +%% void jump(void) { +%% static const uint64_t JUMP[] = +%% { 0x4a11293241fcb12a, 0x0009863200f83fcd }; +%% +%% uint64_t s0 = 0; +%% uint64_t s1 = 0; +%% for(int i = 0; i < sizeof JUMP / sizeof *JUMP; i++) +%% for(int b = 0; b < 64; b++) { +%% if (JUMP[i] & UINT64_C(1) << b) { +%% s0 ^= s[0]; +%% s1 ^= s[1]; +%% } +%% next(); +%% } +%% s[0] = s0; +%% s[1] = s1; +%% } + +-opaque exrop_state() :: nonempty_improper_list(uint58(), uint58()). + +-dialyzer({no_improper_lists, exrop_seed/1}). +exrop_seed({A1, A2, A3}) -> + [_|S1] = + exrop_next_s( + ?MASK(58, (A1 * 4294967197) + 1), + ?MASK(58, (A2 * 4294967231) + 1)), + exrop_next_s(?MASK(58, (A3 * 4294967279) + 1), S1). + +-dialyzer({no_improper_lists, exrop_next_s/2}). +%% Advance xoroshiro116+ state one step +%% [a, b, c] = [24, 2, 35] +-define( + exrop_next_s(S0, S1, S1_a), + begin + S1_a = S1 bxor S0, + [?ROTL(58, S0, 24) bxor S1_a bxor ?BSL(58, S1_a, 2)| % a, b + ?ROTL(58, S1_a, 35)] % c + end). +exrop_next_s(S0, S1) -> + ?exrop_next_s(S0, S1, S1_a). + +-dialyzer({no_improper_lists, exrop_next/1}). +%% Advance xoroshiro116+ state one step, generate 58 bit unsigned integer, +%% and waste the lowest bit since it is of lower randomness quality +exrop_next([S0|S1]) -> + {?MASK(58, S0 + S1), ?exrop_next_s(S0, S1, S1_a)}. + +exrop_uniform({Alg, R}) -> + {V, R1} = exrop_next(R), + %% Waste the lowest bit since it is of lower + %% randomness quality than the others + {(V bsr (58-53)) * ?TWO_POW_MINUS53, {Alg, R1}}. + +exrop_uniform(Range, {Alg, R}) -> + {V, R1} = exrop_next(R), + MaxMinusRange = ?BIT(58) - Range, + ?uniform_range(Range, Alg, R1, V, MaxMinusRange, I). + +%% Split a 116 bit constant into two '1'++58 bit words, +%% the top '1' marks the top of the word +-define( + JUMP_116(Jump), + [?BIT(58) bor ?MASK(58, (Jump)),?BIT(58) bor ((Jump) bsr 58)]). +%% +exrop_jump({Alg,S}) -> + [J|Js] = ?JUMP_116(16#9863200f83fcd4a11293241fcb12a), + {Alg, exrop_jump(S, 0, 0, J, Js)}. +%% +-dialyzer({no_improper_lists, exrop_jump/5}). +exrop_jump(_S, S0, S1, 1, []) -> % End of jump constant + [S0|S1]; +exrop_jump(S, S0, S1, 1, [J|Js]) -> % End of the word + exrop_jump(S, S0, S1, J, Js); +exrop_jump([S__0|S__1] = _S, S0, S1, J, Js) -> + case ?MASK(1, J) of + 1 -> + NewS = exrop_next_s(S__0, S__1), + exrop_jump(NewS, S0 bxor S__0, S1 bxor S__1, J bsr 1, Js); + 0 -> + NewS = exrop_next_s(S__0, S__1), + exrop_jump(NewS, S0, S1, J bsr 1, Js) + end. + +%% ===================================================================== %% Ziggurat cont %% ===================================================================== -define(NOR_R, 3.6541528853610087963519472518). -define(NOR_INV_R, 1/?NOR_R). %% return a {sign, Random51bits, State} +get_52({Alg=#{bits:=Bits, next:=Next}, S0}) -> + %% Use the high bits + {Int,S1} = Next(S0), + {?BIT(Bits - 51 - 1) band Int, Int bsr (Bits - 51), {Alg, S1}}; get_52({Alg=#{next:=Next}, S0}) -> {Int,S1} = Next(S0), - {((1 bsl 51) band Int), Int band ((1 bsl 51)-1), {Alg, S1}}. + {?BIT(51) band Int, ?MASK(51, Int), {Alg, S1}}. %% Slow path normal_s(0, Sign, X0, State0) -> diff --git a/lib/stdlib/test/io_proto_SUITE.erl b/lib/stdlib/test/io_proto_SUITE.erl index 4cc4e3292c..b795cb0b61 100644 --- a/lib/stdlib/test/io_proto_SUITE.erl +++ b/lib/stdlib/test/io_proto_SUITE.erl @@ -18,7 +18,6 @@ %% %CopyrightEnd% %% -module(io_proto_SUITE). --compile(r12). -export([all/0, suite/0,groups/0,init_per_suite/1, end_per_suite/1, init_per_group/2,end_per_group/2]). diff --git a/lib/stdlib/test/rand_SUITE.erl b/lib/stdlib/test/rand_SUITE.erl index 098eefeb61..36bc283aec 100644 --- a/lib/stdlib/test/rand_SUITE.erl +++ b/lib/stdlib/test/rand_SUITE.erl @@ -1,7 +1,7 @@ %% %% %CopyrightBegin% %% -%% Copyright Ericsson AB 2000-2016. All Rights Reserved. +%% Copyright Ericsson AB 2000-2017. All Rights Reserved. %% %% Licensed under the Apache License, Version 2.0 (the "License"); %% you may not use this file except in compliance with the License. @@ -27,6 +27,7 @@ -export([interval_int/1, interval_float/1, seed/1, api_eq/1, reference/1, basic_stats_uniform_1/1, basic_stats_uniform_2/1, + basic_stats_standard_normal/1, basic_stats_normal/1, plugin/1, measure/1, reference_jump_state/1, reference_jump_procdict/1]). @@ -52,7 +53,8 @@ all() -> groups() -> [{basic_stats, [parallel], - [basic_stats_uniform_1, basic_stats_uniform_2, basic_stats_normal]}, + [basic_stats_uniform_1, basic_stats_uniform_2, + basic_stats_standard_normal, basic_stats_normal]}, {reference_jump, [parallel], [reference_jump_state, reference_jump_procdict]}]. @@ -66,18 +68,19 @@ group(reference_jump) -> %% A simple helper to test without test_server during dev test() -> Tests = all(), - lists:foreach(fun(Test) -> - try - ok = ?MODULE:Test([]), - io:format("~p: ok~n", [Test]) - catch _:Reason -> - io:format("Failed: ~p: ~p ~p~n", - [Test, Reason, erlang:get_stacktrace()]) - end - end, Tests). + lists:foreach( + fun (Test) -> + try + ok = ?MODULE:Test([]), + io:format("~p: ok~n", [Test]) + catch _:Reason -> + io:format("Failed: ~p: ~p ~p~n", + [Test, Reason, erlang:get_stacktrace()]) + end + end, Tests). algs() -> - [exs64, exsplus, exs1024]. + [exs64, exsplus, exsp, exrop, exs1024, exs1024s]. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% @@ -226,10 +229,10 @@ interval_float_1(0) -> ok; interval_float_1(N) -> X = rand:uniform(), if - 0.0 < X, X < 1.0 -> + 0.0 =< X, X < 1.0 -> ok; true -> - io:format("X=~p 0<~p<1.0~n", [X,X]), + io:format("X=~p 0=<~p<1.0~n", [X,X]), exit({X, rand:export_seed()}) end, interval_float_1(N-1). @@ -246,6 +249,8 @@ reference_1(Alg) -> Testval = gen(Alg), case Refval =:= Testval of true -> ok; + false when Refval =:= not_implemented -> + exit({not_implemented,Alg}); false -> io:format("Failed: ~p~n",[Alg]), io:format("Length ~p ~p~n",[length(Refval), length(Testval)]), @@ -254,25 +259,29 @@ reference_1(Alg) -> end. gen(Algo) -> - Seed = case Algo of - exsplus -> %% Printed with orig 'C' code and this seed - rand:seed_s({exsplus, [12345678|12345678]}); - exs64 -> %% Printed with orig 'C' code and this seed - rand:seed_s({exs64, 12345678}); - exs1024 -> %% Printed with orig 'C' code and this seed - rand:seed_s({exs1024, {lists:duplicate(16, 12345678), []}}); - _ -> - rand:seed(Algo, {100, 200, 300}) - end, - gen(?LOOP, Seed, []). - -gen(N, State0 = {#{max:=Max}, _}, Acc) when N > 0 -> + State = + case Algo of + exs64 -> %% Printed with orig 'C' code and this seed + rand:seed_s({exs64, 12345678}); + _ when Algo =:= exsplus; Algo =:= exsp; Algo =:= exrop -> + %% Printed with orig 'C' code and this seed + rand:seed_s({Algo, [12345678|12345678]}); + _ when Algo =:= exs1024; Algo =:= exs1024s -> + %% Printed with orig 'C' code and this seed + rand:seed_s({Algo, {lists:duplicate(16, 12345678), []}}); + _ -> + rand:seed(Algo, {100, 200, 300}) + end, + Max = range(State), + gen(?LOOP, State, Max, []). + +gen(N, State0, Max, Acc) when N > 0 -> {Random, State} = rand:uniform_s(Max, State0), case N rem (?LOOP div 100) of - 0 -> gen(N-1, State, [Random|Acc]); - _ -> gen(N-1, State, Acc) + 0 -> gen(N-1, State, Max, [Random|Acc]); + _ -> gen(N-1, State, Max, Acc) end; -gen(_, _, Acc) -> lists:reverse(Acc). +gen(_, _, _, Acc) -> lists:reverse(Acc). %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% This just tests the basics so we have not made any serious errors @@ -294,12 +303,35 @@ basic_stats_uniform_2(Config) when is_list(Config) -> || Alg <- algs()], ok. -basic_stats_normal(Config) when is_list(Config) -> +basic_stats_standard_normal(Config) when is_list(Config) -> ct:timetrap({minutes,6}), %% valgrind needs a lot of time - io:format("Testing normal~n",[]), - [basic_normal_1(?LOOP, rand:seed_s(Alg), 0, 0) || Alg <- algs()], + io:format("Testing standard normal~n",[]), + IntendedMean = 0, + IntendedVariance = 1, + [basic_normal_1(?LOOP, IntendedMean, IntendedVariance, + rand:seed_s(Alg), 0, 0) + || Alg <- algs()], ok. +basic_stats_normal(Config) when is_list(Config) -> + IntendedMeans = [-1.0e6, -50, -math:pi(), -math:exp(-1), + 0.12345678, math:exp(1), 100, 1.0e6], + IntendedVariances = [1.0e-6, math:exp(-1), 1, math:pi(), 1.0e6], + IntendedMeanVariancePairs = + [{Mean, Variance} || Mean <- IntendedMeans, + Variance <- IntendedVariances], + + ct:timetrap({minutes, 6 * length(IntendedMeanVariancePairs)}), %% valgrind needs a lot of time + lists:foreach( + fun ({IntendedMean, IntendedVariance}) -> + io:format("Testing normal(~.2f, ~.2f)~n", + [float(IntendedMean), float(IntendedVariance)]), + [basic_normal_1(?LOOP, IntendedMean, IntendedVariance, + rand:seed_s(Alg), 0, 0) + || Alg <- algs()] + end, + IntendedMeanVariancePairs). + basic_uniform_1(N, S0, Sum, A0) when N > 0 -> {X,S} = rand:uniform_s(S0), I = trunc(X*100), @@ -307,11 +339,11 @@ basic_uniform_1(N, S0, Sum, A0) when N > 0 -> basic_uniform_1(N-1, S, Sum+X, A); basic_uniform_1(0, {#{type:=Alg}, _}, Sum, A) -> AverN = Sum / ?LOOP, - io:format("~.10w: Average: ~.4f~n", [Alg, AverN]), + io:format("~.12w: Average: ~.4f~n", [Alg, AverN]), Counters = array:to_list(A), Min = lists:min(Counters), Max = lists:max(Counters), - io:format("~.10w: Min: ~p Max: ~p~n", [Alg, Min, Max]), + io:format("~.12w: Min: ~p Max: ~p~n", [Alg, Min, Max]), %% Verify that the basic statistics are ok %% be gentle we don't want to see to many failing tests @@ -326,11 +358,11 @@ basic_uniform_2(N, S0, Sum, A0) when N > 0 -> basic_uniform_2(N-1, S, Sum+X, A); basic_uniform_2(0, {#{type:=Alg}, _}, Sum, A) -> AverN = Sum / ?LOOP, - io:format("~.10w: Average: ~.4f~n", [Alg, AverN]), + io:format("~.12w: Average: ~.4f~n", [Alg, AverN]), Counters = tl(array:to_list(A)), Min = lists:min(Counters), Max = lists:max(Counters), - io:format("~.10w: Min: ~p Max: ~p~n", [Alg, Min, Max]), + io:format("~.12w: Min: ~p Max: ~p~n", [Alg, Min, Max]), %% Verify that the basic statistics are ok %% be gentle we don't want to see to many failing tests @@ -339,19 +371,33 @@ basic_uniform_2(0, {#{type:=Alg}, _}, Sum, A) -> abs(?LOOP div 100 - Max) < 1000 orelse ct:fail({max, Alg, Max}), ok. -basic_normal_1(N, S0, Sum, Sq) when N > 0 -> - {X,S} = rand:normal_s(S0), - basic_normal_1(N-1, S, X+Sum, X*X+Sq); -basic_normal_1(0, {#{type:=Alg}, _}, Sum, SumSq) -> - Mean = Sum / ?LOOP, - StdDev = math:sqrt((SumSq - (Sum*Sum/?LOOP))/(?LOOP - 1)), - io:format("~.10w: Average: ~7.4f StdDev ~6.4f~n", [Alg, Mean, StdDev]), +basic_normal_1(N, IntendedMean, IntendedVariance, S0, StandardSum, StandardSq) when N > 0 -> + {X,S} = normal_s(IntendedMean, IntendedVariance, S0), + % We now shape X into a standard normal distribution (in case it wasn't already) + % in order to minimise the accumulated error on Sum / SumSq; + % otherwise said error would prevent us of making a fair judgment on + % the overall distribution when targeting large means and variances. + StandardX = (X - IntendedMean) / math:sqrt(IntendedVariance), + basic_normal_1(N-1, IntendedMean, IntendedVariance, S, + StandardX+StandardSum, StandardX*StandardX+StandardSq); +basic_normal_1(0, _IntendedMean, _IntendedVariance, {#{type:=Alg}, _}, StandardSum, StandardSumSq) -> + StandardMean = StandardSum / ?LOOP, + StandardVariance = (StandardSumSq - (StandardSum*StandardSum/?LOOP))/(?LOOP - 1), + StandardStdDev = math:sqrt(StandardVariance), + io:format("~.12w: Standardised Average: ~7.4f, Standardised StdDev ~6.4f~n", + [Alg, StandardMean, StandardStdDev]), %% Verify that the basic statistics are ok %% be gentle we don't want to see to many failing tests - abs(Mean) < 0.005 orelse ct:fail({average, Alg, Mean}), - abs(StdDev - 1.0) < 0.005 orelse ct:fail({stddev, Alg, StdDev}), + abs(StandardMean) < 0.005 orelse ct:fail({average, Alg, StandardMean}), + abs(StandardStdDev - 1.0) < 0.005 orelse ct:fail({stddev, Alg, StandardStdDev}), ok. +normal_s(Mean, Variance, State0) when Mean == 0, Variance == 1 -> + % Make sure we're also testing the standard normal interface + rand:normal_s(State0); +normal_s(Mean, Variance, State0) -> + rand:normal_s(Mean, Variance, State0). + %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% Test that the user can write algorithms. @@ -365,7 +411,7 @@ plugin(Config) when is_list(Config) -> {V2, S2} = rand:uniform_s(S1), true = is_float(V2), S2 - end, crypto_seed(), lists:seq(1, 200)), + end, crypto64_seed(), lists:seq(1, 200)), ok catch error:low_entropy -> @@ -375,86 +421,220 @@ plugin(Config) when is_list(Config) -> end. %% Test implementation -crypto_seed() -> - {#{type=>crypto, - max=>(1 bsl 64)-1, - next=>fun crypto_next/1, - uniform=>fun crypto_uniform/1, - uniform_n=>fun crypto_uniform_n/2}, +crypto64_seed() -> + {#{type=>crypto64, + bits=>64, + next=>fun crypto64_next/1, + uniform=>fun crypto64_uniform/1, + uniform_n=>fun crypto64_uniform_n/2}, <<>>}. %% Be fair and create bignums i.e. 64bits otherwise use 58bits -crypto_next(<<Num:64, Bin/binary>>) -> +crypto64_next(<<Num:64, Bin/binary>>) -> {Num, Bin}; -crypto_next(_) -> - crypto_next(crypto:strong_rand_bytes((64 div 8)*100)). +crypto64_next(_) -> + crypto64_next(crypto:strong_rand_bytes((64 div 8)*100)). -crypto_uniform({Api, Data0}) -> - {Int, Data} = crypto_next(Data0), +crypto64_uniform({Api, Data0}) -> + {Int, Data} = crypto64_next(Data0), {Int / (1 bsl 64), {Api, Data}}. -crypto_uniform_n(N, {Api, Data0}) when N < (1 bsl 64) -> - {Int, Data} = crypto_next(Data0), +crypto64_uniform_n(N, {Api, Data0}) when N < (1 bsl 64) -> + {Int, Data} = crypto64_next(Data0), {(Int rem N)+1, {Api, Data}}; -crypto_uniform_n(N, State0) -> - {F,State} = crypto_uniform(State0), +crypto64_uniform_n(N, State0) -> + {F,State} = crypto64_uniform(State0), {trunc(F * N) + 1, State}. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% Not a test but measures the time characteristics of the different algorithms -measure(Suite) when is_atom(Suite) -> []; -measure(_Config) -> - ct:timetrap({minutes,15}), %% valgrind needs a lot of time +measure(Config) -> + ct:timetrap({minutes,30}), %% valgrind needs a lot of time + case ct:get_timetrap_info() of + {_,{_,1}} -> % No scaling + do_measure(Config); + {_,{_,Scale}} -> + {skip,{will_not_run_in_scaled_time,Scale}} + end. + +do_measure(_Config) -> Algos = try crypto:strong_rand_bytes(1) of - <<_>> -> [crypto64] + <<_>> -> [crypto64, crypto] catch error:low_entropy -> []; error:undef -> [] end ++ algs(), - io:format("RNG uniform integer performance~n",[]), - _ = measure_1(random, fun(State) -> {int, random:uniform_s(10000, State)} end), - _ = [measure_1(Algo, fun(State) -> {int, rand:uniform_s(10000, State)} end) || Algo <- Algos], - io:format("RNG uniform float performance~n",[]), - _ = measure_1(random, fun(State) -> {uniform, random:uniform_s(State)} end), - _ = [measure_1(Algo, fun(State) -> {uniform, rand:uniform_s(State)} end) || Algo <- Algos], - io:format("RNG normal float performance~n",[]), - io:format("~.10w: not implemented (too few bits)~n", [random]), - _ = [measure_1(Algo, fun(State) -> {normal, rand:normal_s(State)} end) || Algo <- Algos], + %% + ct:pal("RNG uniform integer performance~n",[]), + TMark1 = + measure_1( + random, + fun (_) -> 10000 end, + undefined, + fun (Range, State) -> + {int, random:uniform_s(Range, State)} + end), + _ = + [measure_1( + Algo, + fun (_) -> 10000 end, + TMark1, + fun (Range, State) -> + {int, rand:uniform_s(Range, State)} + end) || Algo <- Algos], + %% + ct:pal("~nRNG uniform integer 2^(N-1) performance~n",[]), + RangeTwoPowFun = fun (State) -> quart_range(State) bsl 1 end, + TMark2 = + measure_1( + random, + RangeTwoPowFun, + undefined, + fun (Range, State) -> + {int, random:uniform_s(Range, State)} + end), + _ = + [measure_1( + Algo, + RangeTwoPowFun, + TMark2, + fun (Range, State) -> + {int, rand:uniform_s(Range, State)} + end) || Algo <- Algos], + %% + ct:pal("~nRNG uniform integer 3*2^(N-2)+1 performance~n",[]), + RangeLargeFun = fun (State) -> 3 * quart_range(State) + 1 end, + TMark3 = + measure_1( + random, + RangeLargeFun, + undefined, + fun (Range, State) -> + {int, random:uniform_s(Range, State)} + end), + _ = + [measure_1( + Algo, + RangeLargeFun, + TMark3, + fun (Range, State) -> + {int, rand:uniform_s(Range, State)} + end) || Algo <- Algos], + %% + ct:pal("~nRNG uniform integer 2^128 performance~n",[]), + TMark4 = + measure_1( + random, + fun (_) -> 1 bsl 128 end, + undefined, + fun (Range, State) -> + {int, random:uniform_s(Range, State)} + end), + _ = + [measure_1( + Algo, + fun (_) -> 1 bsl 128 end, + TMark4, + fun (Range, State) -> + {int, rand:uniform_s(Range, State)} + end) || Algo <- Algos], + %% + ct:pal("~nRNG uniform integer 2^128 + 1 performance~n",[]), + TMark5 = + measure_1( + random, + fun (_) -> (1 bsl 128) + 1 end, + undefined, + fun (Range, State) -> + {int, random:uniform_s(Range, State)} + end), + _ = + [measure_1( + Algo, + fun (_) -> (1 bsl 128) + 1 end, + TMark5, + fun (Range, State) -> + {int, rand:uniform_s(Range, State)} + end) || Algo <- Algos], + %% + ct:pal("~nRNG uniform float performance~n",[]), + TMark6 = + measure_1( + random, + fun (_) -> 0 end, + undefined, + fun (_, State) -> + {uniform, random:uniform_s(State)} + end), + _ = + [measure_1( + Algo, + fun (_) -> 0 end, + TMark6, + fun (_, State) -> + {uniform, rand:uniform_s(State)} + end) || Algo <- Algos], + %% + ct:pal("~nRNG normal float performance~n",[]), + io:format("~.12w: not implemented (too few bits)~n", [random]), + _ = [measure_1( + Algo, + fun (_) -> 0 end, + TMark6, + fun (_, State) -> + {normal, rand:normal_s(State)} + end) || Algo <- Algos], ok. -measure_1(Algo, Gen) -> +measure_1(Algo, RangeFun, TMark, Gen) -> Parent = self(), - Seed = fun(crypto64) -> crypto_seed(); - (random) -> random:seed(os:timestamp()), get(random_seed); - (Alg) -> rand:seed_s(Alg) - end, - - Pid = spawn_link(fun() -> - Fun = fun() -> measure_2(?LOOP, Seed(Algo), Gen) end, - {Time, ok} = timer:tc(Fun), - io:format("~.10w: ~pµs~n", [Algo, Time]), - Parent ! {self(), ok}, - normal - end), + Seed = + case Algo of + crypto64 -> + crypto64_seed(); + crypto -> + crypto:rand_seed_s(); + random -> + random:seed(os:timestamp()), get(random_seed); + _ -> + rand:seed_s(Algo) + end, + Range = RangeFun(Seed), + Pid = spawn_link( + fun() -> + Fun = fun() -> measure_2(?LOOP, Range, Seed, Gen) end, + {Time, ok} = timer:tc(Fun), + Percent = + case TMark of + undefined -> 100; + _ -> (Time * 100 + 50) div TMark + end, + io:format( + "~.12w: ~p ns ~p% [16#~.16b]~n", + [Algo, (Time * 1000 + 500) div ?LOOP, Percent, Range]), + Parent ! {self(), Time}, + normal + end), receive {Pid, Msg} -> Msg end. -measure_2(N, State0, Fun) when N > 0 -> - case Fun(State0) of +measure_2(N, Range, State0, Fun) when N > 0 -> + case Fun(Range, State0) of {int, {Random, State}} - when is_integer(Random), Random >= 1, Random =< 100000 -> - measure_2(N-1, State, Fun); - {uniform, {Random, State}} when is_float(Random), Random > 0, Random < 1 -> - measure_2(N-1, State, Fun); + when is_integer(Random), Random >= 1, Random =< Range -> + measure_2(N-1, Range, State, Fun); + {uniform, {Random, State}} + when is_float(Random), 0.0 =< Random, Random < 1.0 -> + measure_2(N-1, Range, State, Fun); {normal, {Random, State}} when is_float(Random) -> - measure_2(N-1, State, Fun); + measure_2(N-1, Range, State, Fun); Res -> exit({error, Res, State0}) end; -measure_2(0, _, _) -> ok. +measure_2(0, _, _, _) -> ok. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %% The jump sequence tests has two parts @@ -479,36 +659,43 @@ reference_jump_1(Alg) -> io:format("Failed: ~p~n",[Alg]), io:format("Length ~p ~p~n",[length(Refval), length(Testval)]), io:format("Head ~p ~p~n",[hd(Refval), hd(Testval)]), + io:format("Vals ~p ~p~n",[Refval, Testval]), exit(wrong_value) end. gen_jump_1(Algo) -> - Seed = case Algo of - exsplus -> %% Printed with orig 'C' code and this seed - rand:seed_s({exsplus, [12345678|12345678]}); - exs1024 -> %% Printed with orig 'C' code and this seed - rand:seed_s({exs1024, {lists:duplicate(16, 12345678), []}}); - exs64 -> %% Test exception of not_implemented notice - try rand:jump(rand:seed_s(exs64)) - catch - error:not_implemented -> not_implemented - end; - _ -> % unimplemented - not_implemented - end, - case Seed of + State = + case Algo of + exs64 -> %% Test exception of not_implemented notice + try rand:jump(rand:seed_s(exs64)) + catch + error:not_implemented -> not_implemented + end; + _ when Algo =:= exsplus; Algo =:= exsp; Algo =:= exrop -> + %% Printed with orig 'C' code and this seed + rand:seed_s({Algo, [12345678|12345678]}); + _ when Algo =:= exs1024; Algo =:= exs1024s -> + %% Printed with orig 'C' code and this seed + rand:seed_s({Algo, {lists:duplicate(16, 12345678), []}}); + _ -> % unimplemented + not_implemented + end, + case State of not_implemented -> [not_implemented]; - S -> gen_jump_1(?LOOP_JUMP, S, []) + _ -> + Max = range(State), + gen_jump_1(?LOOP_JUMP, State, Max, []) end. -gen_jump_1(N, State0 = {#{max:=Max}, _}, Acc) when N > 0 -> +gen_jump_1(N, State0, Max, Acc) when N > 0 -> {_, State1} = rand:uniform_s(Max, State0), {Random, State2} = rand:uniform_s(Max, rand:jump(State1)), case N rem (?LOOP_JUMP div 100) of - 0 -> gen_jump_1(N-1, State2, [Random|Acc]); - _ -> gen_jump_1(N-1, State2, Acc) + 0 -> gen_jump_1(N-1, State2, Max, [Random|Acc]); + _ -> gen_jump_1(N-1, State2, Max, Acc) end; -gen_jump_1(_, _, Acc) -> lists:reverse(Acc). +gen_jump_1(_, _, _, Acc) -> lists:reverse(Acc). + %% Check if each algorithm generates the proper jump sequence %% with the internal state in the process dictionary. @@ -530,25 +717,26 @@ reference_jump_0(Alg) -> gen_jump_0(Algo) -> Seed = case Algo of - exsplus -> %% Printed with orig 'C' code and this seed - rand:seed({exsplus, [12345678|12345678]}); - exs1024 -> %% Printed with orig 'C' code and this seed - rand:seed({exs1024, {lists:duplicate(16, 12345678), []}}); exs64 -> %% Test exception of not_implemented notice - try - _ = rand:seed(exs64), - rand:jump() - catch - error:not_implemented -> not_implemented - end; + try + _ = rand:seed(exs64), + rand:jump() + catch + error:not_implemented -> not_implemented + end; + _ when Algo =:= exsplus; Algo =:= exsp; Algo =:= exrop -> + %% Printed with orig 'C' code and this seed + rand:seed({Algo, [12345678|12345678]}); + _ when Algo =:= exs1024; Algo =:= exs1024s -> + %% Printed with orig 'C' code and this seed + rand:seed({Algo, {lists:duplicate(16, 12345678), []}}); _ -> % unimplemented not_implemented end, case Seed of not_implemented -> [not_implemented]; - S -> - {Seedmap=#{}, _} = S, - Max = maps:get(max, Seedmap), + _ -> + Max = range(Seed), gen_jump_0(?LOOP_JUMP, Max, []) end. @@ -643,9 +831,77 @@ reference_val(exsplus) -> 16#6c6145ffa1169d,16#18ec2c393d45359,16#1f1a5f256e7130c,16#131cc2f49b8004f, 16#36f715a249f4ec2,16#1c27629826c50d3,16#914d9a6648726a,16#27f5bf5ce2301e8, 16#3dd493b8012970f,16#be13bed1e00e5c,16#ceef033b74ae10,16#3da38c6a50abe03, - 16#15cbd1a421c7a8c,16#22794e3ec6ef3b1,16#26154d26e7ea99f,16#3a66681359a6ab6]. + 16#15cbd1a421c7a8c,16#22794e3ec6ef3b1,16#26154d26e7ea99f,16#3a66681359a6ab6]; + +reference_val(exsp) -> + reference_val(exsplus); +reference_val(exs1024s) -> + reference_val(exs1024); +reference_val(exrop) -> +%% #include <stdint.h> +%% #include <stdio.h> +%% +%% uint64_t s[2]; +%% uint64_t next(void); +%% /* Xoroshiro116+ PRNG here */ +%% +%% int main(char *argv[]) { +%% int n; +%% uint64_t r; +%% s[0] = 12345678; +%% s[1] = 12345678; +%% +%% for (n = 1000000; n > 0; n--) { +%% r = next(); +%% if ((n % 10000) == 0) { +%% printf("%llu,", (unsigned long long) (r + 1)); +%% } +%% } +%% printf("\n"); +%% } + [24691357,29089185972758626,135434857127264790, + 277209758236304485,101045429972817342, + 241950202080388093,283018380268425711,268233672110762489, + 173241488791227202,245038518481669421, + 253627577363613736,234979870724373477,115607127954560275, + 96445882796968228,166106849348423677, + 83614184550774836,109634510785746957,68415533259662436, + 12078288820568786,246413981014863011, + 96953486962147513,138629231038332640,206078430370986460, + 11002780552565714,238837272913629203, + 60272901610411077,148828243883348685,203140738399788939, + 131001610760610046,30717739120305678, + 262903815608472425,31891125663924935,107252017522511256, + 241577109487224033,263801934853180827, + 155517416581881714,223609336630639997,112175917931581716, + 16523497284706825,201453767973653420, + 35912153101632769,211525452750005043,96678037860996922, + 70962216125870068,107383886372877124, + 223441708670831233,247351119445661499,233235283318278995, + 280646255087307741,232948506631162445, + %% + 117394974124526779,55395923845250321,274512622756597759, + 31754154862553492,222645458401498438, + 161643932692872858,11771755227312868,93933211280589745, + 92242631276348831,197206910466548143, + 150370169849735808,229903773212075765,264650708561842793, + 30318996509793571,158249985447105184, + 220423733894955738,62892844479829080,112941952955911674, + 203157000073363030,54175707830615686, + 50121351829191185,115891831802446962,62298417197154985, + 6569598473421167,69822368618978464, + 176271134892968134,160793729023716344,271997399244980560, + 59100661824817999,150500611720118722, + 23707133151561128,25156834940231911,257788052162304719, + 176517852966055005,247173855600850875, + 83440973524473396,94711136045581604,154881198769946042, + 236537934330658377,152283781345006019, + 250789092615679985,78848633178610658,72059442721196128, + 98223942961505519,191144652663779840, + 102425686803727694,89058927716079076,80721467542933080, + 8462479817391645,2774921106204163]. -%%% +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% reference_jump_val(exsplus) -> [82445318862816932, 145810727464480743, 16514517716894509, 247642377064868650, @@ -701,4 +957,93 @@ reference_jump_val(exs1024) -> 17936751184378118743, 4224632875737239207, 15888641556987476199, 9586888813112229805, 9476861567287505094, 14909536929239540332, 17996844556292992842, 2699310519182298856]; -reference_jump_val(exs64) -> [not_implemented]. +reference_jump_val(exsp) -> + reference_jump_val(exsplus); +reference_jump_val(exs1024s) -> + reference_jump_val(exs1024); +reference_jump_val(exs64) -> [not_implemented]; +reference_jump_val(exrop) -> +%% #include <stdint.h> +%% #include <stdio.h> +%% +%% uint64_t s[2]; +%% uint64_t next(void); +%% /* Xoroshiro116+ PRNG here */ +%% +%% int main(char *argv[]) { +%% int n; +%% uint64_t r; +%% s[0] = 12345678; +%% s[1] = 12345678; + +%% for (n = 1000; n > 0; n--) { +%% next(); +%% jump(); +%% r = next(); +%% if ((n % 10) == 0) { +%% printf("%llu,", (unsigned long long) (r + 1)); +%% } +%% } +%% printf("\n"); +%% } + [60301713907476001,135397949584721850,4148159712710727, + 110297784509908316,18753463199438866, + 106699913259182846,2414728156662676,237591345910610406, + 48519427605486503,38071665570452612, + 235484041375354592,45428997361037927,112352324717959775, + 226084403445232507,270797890380258829, + 160587966336947922,80453153271416820,222758573634013699, + 195715386237881435,240975253876429810, + 93387593470886224,23845439014202236,235376123357642262, + 22286175195310374,239068556844083490, + 120126027410954482,250690865061862527,113265144383673111, + 57986825640269127,206087920253971490, + 265971029949338955,40654558754415167,185972161822891882, + 72224917962819036,116613804322063968, + 129103518989198416,236110607653724474,98446977363728314, + 122264213760984600,55635665885245081, + 42625530794327559,288031254029912894,81654312180555835, + 261800844953573559,144734008151358432, + 77095621402920587,286730580569820386,274596992060316466, + 97977034409404188,5517946553518132, + %% + 56460292644964432,252118572460428657,38694442746260303, + 165653145330192194,136968555571402812, + 64905200201714082,257386366768713186,22702362175273017, + 208480936480037395,152926769756967697, + 256751159334239189,130982960476845557,21613531985982870, + 87016962652282927,130446710536726404, + 188769410109327420,282891129440391928,251807515151187951, + 262029034126352975,30694713572208714, + 46430187445005589,176983177204884508,144190360369444480, + 14245137612606100,126045457407279122, + 169277107135012393,42599413368851184,130940158341360014, + 113412693367677211,119353175256553456, + 96339829771832349,17378172025472134,110141940813943768, + 253735613682893347,234964721082540068, + 85668779779185140,164542570671430062,18205512302089755, + 282380693509970845,190996054681051049, + 250227633882474729,171181147785250210,55437891969696407, + 241227318715885854,77323084015890802, + 1663590009695191,234064400749487599,222983191707424780, + 254956809144783896,203898972156838252]. + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% + +%% The old algorithms used a range 2^N - 1 for their reference val +%% tests, which was incorrect but works as long as you do not draw +%% the value 2^N, which is very unlikely. It was not possible +%% to simply correct the range to 2^N due to another incorrectness +%% in that the old algorithms changed to using the broken +%% (multiply a float approach with too few bits) approach for +%% ranges >= 2^N. This function digs out the range to use +%% for the reference tests for old and new algorithms. +range({#{bits:=Bits}, _}) -> 1 bsl Bits; +range({#{max:=Max}, _}) -> Max; %% Old incorrect range +range({_, _, _}) -> 51. % random + + +quart_range({#{bits:=Bits}, _}) -> 1 bsl (Bits - 2); +quart_range({#{max:=Max}, _}) -> (Max bsr 2) + 1; +quart_range({#{}, _}) -> 1 bsl 62; % crypto +quart_range({_, _, _}) -> 1 bsl 49. % random |