diff options
author | Dan Gudmundsson <[email protected]> | 2016-11-17 09:51:15 +0100 |
---|---|---|
committer | Dan Gudmundsson <[email protected]> | 2016-11-17 09:51:15 +0100 |
commit | 1c000086275c06596ad402081be12ef95db6ea40 (patch) | |
tree | 8c2f2bc3548aa90115d38784b40a324dc80a076e /lib | |
parent | a1a4aa8c66d10cbfb22b7221b2f61f302efae47f (diff) | |
parent | ff568b5e818d04048009926a7fa2ea537d2e656d (diff) | |
download | otp-1c000086275c06596ad402081be12ef95db6ea40.tar.gz otp-1c000086275c06596ad402081be12ef95db6ea40.tar.bz2 otp-1c000086275c06596ad402081be12ef95db6ea40.zip |
Merge branch 'jj1bdx/stdlib/rand-jump/PR-1235/OTP-14038'
* jj1bdx/stdlib/rand-jump/PR-1235/OTP-14038:
Add jump functions to rand module
Diffstat (limited to 'lib')
-rw-r--r-- | lib/stdlib/doc/src/rand.xml | 35 | ||||
-rw-r--r-- | lib/stdlib/src/rand.erl | 146 | ||||
-rw-r--r-- | lib/stdlib/test/rand_SUITE.erl | 182 |
3 files changed, 345 insertions, 18 deletions
diff --git a/lib/stdlib/doc/src/rand.xml b/lib/stdlib/doc/src/rand.xml index 1dcc3de000..1364a3277b 100644 --- a/lib/stdlib/doc/src/rand.xml +++ b/lib/stdlib/doc/src/rand.xml @@ -41,6 +41,11 @@ Sebastiano Vigna</url>. The normal distribution algorithm uses the <url href="http://www.jstatsoft.org/v05/i08">Ziggurat Method by Marsaglia and Tsang</url>.</p> + <p>For some algorithms, jump functions are provided for generating + non-overlapping sequences for parallel computations. + The jump functions perform calculations + equivalent to perform a large number of repeated calls + for calculating new states. </p> <p>The following algorithms are provided:</p> @@ -48,14 +53,17 @@ <tag><c>exsplus</c></tag> <item> <p>Xorshift116+, 58 bits precision and period of 2^116-1</p> + <p>Jump function: equivalent to 2^64 calls</p> </item> <tag><c>exs64</c></tag> <item> <p>Xorshift64*, 64 bits precision and a period of 2^64-1</p> + <p>Jump function: not available</p> </item> <tag><c>exs1024</c></tag> <item> <p>Xorshift1024*, 64 bits precision and a period of 2^1024-1</p> + <p>Jump function: equivalent to 2^512 calls</p> </item> </taglist> @@ -156,6 +164,33 @@ S0 = rand:seed_s(exsplus), </func> <func> + <name name="jump" arity="0"/> + <fsummary>Return the seed after performing jump calculation + to the state in the process dictionary.</fsummary> + <desc><marker id="jump-0" /> + <p>Returns the state + after performing jump calculation + to the state in the process dictionary.</p> + <p>This function generates a <c>not_implemented</c> error exception + when the jump function is not implemented for + the algorithm specified in the state + in the process dictionary.</p> + </desc> + </func> + + <func> + <name name="jump" arity="1"/> + <fsummary>Return the seed after performing jump calculation.</fsummary> + <desc><marker id="jump-1" /> + <p>Returns the state after performing jump calculation + to the given state. </p> + <p>This function generates a <c>not_implemented</c> error exception + when the jump function is not implemented for + the algorithm specified in the state.</p> + </desc> + </func> + + <func> <name name="normal" arity="0"/> <fsummary>Return a standard normal distributed random float.</fsummary> <desc> diff --git a/lib/stdlib/src/rand.erl b/lib/stdlib/src/rand.erl index 93409d95df..3b1767e731 100644 --- a/lib/stdlib/src/rand.erl +++ b/lib/stdlib/src/rand.erl @@ -19,7 +19,7 @@ %% %% ===================================================================== %% Multiple PRNG module for Erlang/OTP -%% Copyright (c) 2015 Kenji Rikitake +%% Copyright (c) 2015-2016 Kenji Rikitake %% ===================================================================== -module(rand). @@ -27,11 +27,14 @@ -export([seed_s/1, seed_s/2, seed/1, seed/2, export_seed/0, export_seed_s/1, uniform/0, uniform/1, uniform_s/1, uniform_s/2, + jump/0, jump/1, normal/0, normal_s/1 ]). -compile({inline, [exs64_next/1, exsplus_next/1, + exsplus_jump/1, exs1024_next/1, exs1024_calc/2, + exs1024_jump/1, get_52/1, normal_kiwi/1]}). -define(DEFAULT_ALG_HANDLER, exsplus). @@ -48,7 +51,8 @@ max := integer(), next := fun(), uniform := fun(), - uniform_n := fun()}. + uniform_n := fun(), + jump := fun()}. %% Internal state -opaque state() :: {alg_handler(), alg_seed()}. @@ -79,9 +83,7 @@ export_seed_s({#{type:=Alg}, Seed}) -> {Alg, Seed}. -spec seed(AlgOrExpState::alg() | export_state()) -> state(). seed(Alg) -> - R = seed_s(Alg), - _ = seed_put(R), - R. + seed_put(seed_s(Alg)). -spec seed_s(AlgOrExpState::alg() | export_state()) -> state(). seed_s(Alg) when is_atom(Alg) -> @@ -97,9 +99,7 @@ seed_s({Alg0, Seed}) -> -spec seed(Alg :: alg(), {integer(), integer(), integer()}) -> state(). seed(Alg0, S0) -> - State = seed_s(Alg0, S0), - _ = seed_put(State), - State. + seed_put(seed_s(Alg0, S0)). -spec seed_s(Alg :: alg(), {integer(), integer(), integer()}) -> state(). seed_s(Alg0, S0 = {_, _, _}) -> @@ -150,6 +150,25 @@ uniform_s(N, State0 = {#{uniform:=Uniform}, _}) {F, State} = Uniform(State0), {trunc(F * N) + 1, State}. +%% jump/1: given a state, jump/1 +%% returns a new state which is equivalent to that +%% after a large number of call defined for each algorithm. +%% The large number is algorithm dependent. + +-spec jump(state()) -> {NewS :: state()}. +jump(State = {#{jump:=Jump}, _}) -> + Jump(State). + +%% jump/0: read the internal state and +%% apply the jump function for the state as in jump/1 +%% and write back the new value to the internal state, +%% then returns the new value. + +-spec jump() -> {NewS :: state()}. + +jump() -> + seed_put(jump(seed_get())). + %% normal/0: returns a random float with standard normal distribution %% updating the state in the process dictionary. @@ -192,9 +211,10 @@ normal_s(State0) -> -type uint64() :: 0..16#ffffffffffffffff. -type uint58() :: 0..16#03ffffffffffffff. --spec seed_put(state()) -> undefined | state(). +-spec seed_put(state()) -> state(). seed_put(Seed) -> - put(?SEED_DICT, Seed). + put(?SEED_DICT, Seed), + Seed. seed_get() -> case get(?SEED_DICT) of @@ -205,15 +225,18 @@ seed_get() -> %% Setup alg record mk_alg(exs64) -> {#{type=>exs64, max=>?UINT64MASK, next=>fun exs64_next/1, - uniform=>fun exs64_uniform/1, uniform_n=>fun exs64_uniform/2}, + uniform=>fun exs64_uniform/1, uniform_n=>fun exs64_uniform/2, + jump=>fun exs64_jump/1}, fun exs64_seed/1}; mk_alg(exsplus) -> {#{type=>exsplus, max=>?UINT58MASK, next=>fun exsplus_next/1, - uniform=>fun exsplus_uniform/1, uniform_n=>fun exsplus_uniform/2}, + uniform=>fun exsplus_uniform/1, uniform_n=>fun exsplus_uniform/2, + jump=>fun exsplus_jump/1}, fun exsplus_seed/1}; mk_alg(exs1024) -> {#{type=>exs1024, max=>?UINT64MASK, next=>fun exs1024_next/1, - uniform=>fun exs1024_uniform/1, uniform_n=>fun exs1024_uniform/2}, + uniform=>fun exs1024_uniform/1, uniform_n=>fun exs1024_uniform/2, + jump=>fun exs1024_jump/1}, fun exs1024_seed/1}. %% ===================================================================== @@ -246,6 +269,9 @@ exs64_uniform(Max, {Alg, R}) -> {V, R1} = exs64_next(R), {(V rem Max) + 1, {Alg, R1}}. +exs64_jump(_) -> + erlang:error(not_implemented). + %% ===================================================================== %% exsplus PRNG: Xorshift116+ %% Algorithm by Sebastiano Vigna @@ -283,6 +309,42 @@ exsplus_uniform(Max, {Alg, R}) -> {V, R1} = exsplus_next(R), {(V rem Max) + 1, {Alg, R1}}. +%% This is the jump function for the exsplus generator, equivalent +%% to 2^64 calls to next/1; it can be used to generate 2^52 +%% non-overlapping subsequences for parallel computations. +%% Note: the jump function takes 116 times of the execution time of +%% next/1. + +%% -define(JUMPCONST, 16#000d174a83e17de2302f8ea6bc32c797). +%% split into 58-bit chunks +%% and two iterative executions + +-define(JUMPCONST1, 16#02f8ea6bc32c797). +-define(JUMPCONST2, 16#345d2a0f85f788c). +-define(JUMPELEMLEN, 58). + +-spec exsplus_jump(exsplus_state()) -> exsplus_state(). + +exsplus_jump({Alg, S}) -> + {S1, AS1} = exsplus_jump(S, [0|0], ?JUMPCONST1, ?JUMPELEMLEN), + {_, AS2} = exsplus_jump(S1, AS1, ?JUMPCONST2, ?JUMPELEMLEN), + {Alg, AS2}. + +-spec exsplus_jump(state(), state(), pos_integer(), pos_integer()) -> + {state(), state()}. + +exsplus_jump(S, AS, _, 0) -> + {S, AS}; +exsplus_jump(S, [AS0|AS1], J, N) -> + {_, NS} = exsplus_next(S), + case (J band 1) of + 1 -> + [S0|S1] = S, + exsplus_jump(NS, [(AS0 bxor S0)|(AS1 bxor S1)], J bsr 1, N-1); + 0 -> + exsplus_jump(NS, [AS0|AS1], J bsr 1, N-1) + end. + %% ===================================================================== %% exs1024 PRNG: Xorshift1024* %% Algorithm by Sebastiano Vigna @@ -340,6 +402,64 @@ exs1024_uniform(Max, {Alg, R}) -> {V, R1} = exs1024_next(R), {(V rem Max) + 1, {Alg, R1}}. +%% This is the jump function for the exs1024 generator, equivalent +%% to 2^512 calls to next(); it can be used to generate 2^512 +%% non-overlapping subsequences for parallel computations. +%% Note: the jump function takes ~2000 times of the execution time of +%% next/1. + +%% Jump constant here split into 58 bits for speed +-define(JUMPCONSTHEAD, 16#00242f96eca9c41d). +-define(JUMPCONSTTAIL, + [16#0196e1ddbe5a1561, + 16#0239f070b5837a3c, + 16#03f393cc68796cd2, + 16#0248316f404489af, + 16#039a30088bffbac2, + 16#02fea70dc2d9891f, + 16#032ae0d9644caec4, + 16#0313aac17d8efa43, + 16#02f132e055642626, + 16#01ee975283d71c93, + 16#00552321b06f5501, + 16#00c41d10a1e6a569, + 16#019158ecf8aa1e44, + 16#004e9fc949d0b5fc, + 16#0363da172811fdda, + 16#030e38c3b99181f2, + 16#0000000a118038fc]). +-define(JUMPTOTALLEN, 1024). +-define(RINGLEN, 16). + +-spec exs1024_jump(state()) -> state(). + +exs1024_jump({Alg, {L, RL}}) -> + P = length(RL), + AS = exs1024_jump({L, RL}, + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], + ?JUMPCONSTTAIL, ?JUMPCONSTHEAD, ?JUMPELEMLEN, ?JUMPTOTALLEN), + {ASL, ASR} = lists:split(?RINGLEN - P, AS), + {Alg, {ASL, lists:reverse(ASR)}}. + +-spec exs1024_jump(state(), list(non_neg_integer()), + list(non_neg_integer()), non_neg_integer(), + non_neg_integer(), non_neg_integer()) -> list(non_neg_integer()). + +exs1024_jump(_, AS, _, _, _, 0) -> + AS; +exs1024_jump(S, AS, [H|T], _, 0, TN) -> + exs1024_jump(S, AS, T, H, ?JUMPELEMLEN, TN); +exs1024_jump({L, RL}, AS, JL, J, N, TN) -> + {_, NS} = exs1024_next({L, RL}), + case (J band 1) of + 1 -> + AS2 = lists:zipwith(fun(X, Y) -> X bxor Y end, + AS, L ++ lists:reverse(RL)), + exs1024_jump(NS, AS2, JL, J bsr 1, N-1, TN-1); + 0 -> + exs1024_jump(NS, AS, JL, J bsr 1, N-1, TN-1) + end. + %% ===================================================================== %% Ziggurat cont %% ===================================================================== diff --git a/lib/stdlib/test/rand_SUITE.erl b/lib/stdlib/test/rand_SUITE.erl index 02b7cb10c2..8e7ac223a7 100644 --- a/lib/stdlib/test/rand_SUITE.erl +++ b/lib/stdlib/test/rand_SUITE.erl @@ -28,7 +28,8 @@ api_eq/1, reference/1, basic_stats_uniform_1/1, basic_stats_uniform_2/1, basic_stats_normal/1, - plugin/1, measure/1]). + plugin/1, measure/1, + reference_jump_state/1, reference_jump_procdict/1]). -export([test/0, gen/1]). @@ -45,14 +46,21 @@ all() -> api_eq, reference, {group, basic_stats}, - plugin, measure]. + plugin, measure, + {group, reference_jump} + ]. groups() -> [{basic_stats, [parallel], - [basic_stats_uniform_1, basic_stats_uniform_2, basic_stats_normal]}]. + [basic_stats_uniform_1, basic_stats_uniform_2, basic_stats_normal]}, + {reference_jump, [parallel], + [reference_jump_state, reference_jump_procdict]}]. group(basic_stats) -> %% valgrind needs a lot of time + [{timetrap,{minutes,10}}]; +group(reference_jump) -> + %% valgrind needs a lot of time [{timetrap,{minutes,10}}]. %% A simple helper to test without test_server during dev @@ -228,7 +236,7 @@ interval_float_1(N) -> %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% -%% Check if exs64 algorithm generates the proper sequence. +%% Check if each algorithm generates the proper sequence. reference(Config) when is_list(Config) -> [reference_1(Alg) || Alg <- algs()], ok. @@ -242,7 +250,7 @@ reference_1(Alg) -> io:format("Failed: ~p~n",[Alg]), io:format("Length ~p ~p~n",[length(Refval), length(Testval)]), io:format("Head ~p ~p~n",[hd(Refval), hd(Testval)]), - ok + exit(wrong_value) end. gen(Algo) -> @@ -434,6 +442,112 @@ measure_2(N, State0, Fun) when N > 0 -> measure_2(0, _, _) -> ok. %%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% +%% The jump sequence tests has two parts +%% for those with the functional API (jump/1) +%% and for those with the internal state +%% in process dictionary (jump/0). + +-define(LOOP_JUMP, (?LOOP div 1000)). + +%% Check if each algorithm generates the proper jump sequence +%% with the functional API. +reference_jump_state(Config) when is_list(Config) -> + [reference_jump_1(Alg) || Alg <- algs()], + ok. + +reference_jump_1(Alg) -> + Refval = reference_jump_val(Alg), + Testval = gen_jump_1(Alg), + case Refval =:= Testval of + true -> ok; + false -> + io:format("Failed: ~p~n",[Alg]), + io:format("Length ~p ~p~n",[length(Refval), length(Testval)]), + io:format("Head ~p ~p~n",[hd(Refval), hd(Testval)]), + exit(wrong_value) + end. + +gen_jump_1(Algo) -> + Seed = case Algo of + exsplus -> %% Printed with orig 'C' code and this seed + rand:seed_s({exsplus, [12345678|12345678]}); + exs1024 -> %% Printed with orig 'C' code and this seed + rand:seed_s({exs1024, {lists:duplicate(16, 12345678), []}}); + exs64 -> %% Test exception of not_implemented notice + try rand:jump(rand:seed_s(exs64)) + catch + error:not_implemented -> not_implemented + end; + _ -> % unimplemented + not_implemented + end, + case Seed of + not_implemented -> [not_implemented]; + S -> gen_jump_1(?LOOP_JUMP, S, []) + end. + +gen_jump_1(N, State0 = {#{max:=Max}, _}, Acc) when N > 0 -> + {_, State1} = rand:uniform_s(Max, State0), + {Random, State2} = rand:uniform_s(Max, rand:jump(State1)), + case N rem (?LOOP_JUMP div 100) of + 0 -> gen_jump_1(N-1, State2, [Random|Acc]); + _ -> gen_jump_1(N-1, State2, Acc) + end; +gen_jump_1(_, _, Acc) -> lists:reverse(Acc). + +%% Check if each algorithm generates the proper jump sequence +%% with the internal state in the process dictionary. +reference_jump_procdict(Config) when is_list(Config) -> + [reference_jump_0(Alg) || Alg <- algs()], + ok. + +reference_jump_0(Alg) -> + Refval = reference_jump_val(Alg), + Testval = gen_jump_0(Alg), + case Refval =:= Testval of + true -> ok; + false -> + io:format("Failed: ~p~n",[Alg]), + io:format("Length ~p ~p~n",[length(Refval), length(Testval)]), + io:format("Head ~p ~p~n",[hd(Refval), hd(Testval)]), + exit(wrong_value) + end. + +gen_jump_0(Algo) -> + Seed = case Algo of + exsplus -> %% Printed with orig 'C' code and this seed + rand:seed({exsplus, [12345678|12345678]}); + exs1024 -> %% Printed with orig 'C' code and this seed + rand:seed({exs1024, {lists:duplicate(16, 12345678), []}}); + exs64 -> %% Test exception of not_implemented notice + try + _ = rand:seed(exs64), + rand:jump() + catch + error:not_implemented -> not_implemented + end; + _ -> % unimplemented + not_implemented + end, + case Seed of + not_implemented -> [not_implemented]; + S -> + {Seedmap=#{}, _} = S, + Max = maps:get(max, Seedmap), + gen_jump_0(?LOOP_JUMP, Max, []) + end. + +gen_jump_0(N, Max, Acc) when N > 0 -> + _ = rand:uniform(Max), + _ = rand:jump(), + Random = rand:uniform(Max), + case N rem (?LOOP_JUMP div 100) of + 0 -> gen_jump_0(N-1, Max, [Random|Acc]); + _ -> gen_jump_0(N-1, Max, Acc) + end; +gen_jump_0(_, _, Acc) -> lists:reverse(Acc). + +%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%% %%% Data reference_val(exs64) -> [16#3737ad0c703ff6c3,16#3868a78fe71adbbd,16#1f01b62b4338b605,16#50876a917437965f, @@ -515,3 +629,61 @@ reference_val(exsplus) -> 16#36f715a249f4ec2,16#1c27629826c50d3,16#914d9a6648726a,16#27f5bf5ce2301e8, 16#3dd493b8012970f,16#be13bed1e00e5c,16#ceef033b74ae10,16#3da38c6a50abe03, 16#15cbd1a421c7a8c,16#22794e3ec6ef3b1,16#26154d26e7ea99f,16#3a66681359a6ab6]. + +%%% + +reference_jump_val(exsplus) -> + [82445318862816932, 145810727464480743, 16514517716894509, 247642377064868650, + 162385642339156908, 251810707075252101, 82288275771998924, 234412731596926322, + 49960883129071044, 200690077681656596, 213743196668671647, 131182800982967108, + 144200072021941728, 263557425008503277, 194858522616874272, 185869394820993172, + 80384502675241453, 262654144824057588, 90033295011291362, 4494510449302659, + 226005372746479588, 116780561309220553, 47048528594475843, 39168929349768743, + 139615163424415552, 55330632656603925, 237575574720486569, 102381140288455025, + 18452933910354323, 150248612130579752, 269358096791922740, 61313433522002187, + 160327361842676597, 185187983548528938, 57378981505594193, 167510799293984067, + 105117045862954303, 176126685946302943, 123590876906828803, 69185336947273487, + 9098689247665808, 49906154674145057, 131575138412788650, 161843880211677185, + 30743946051071186, 187578920583823612, 45008401528636978, 122454158686456658, + 111195992644229524, 17962783958752862, 13579507636941108, 130137843317798663, + 144202635170576832, 132539563255093922, 159785575703967124, 187241848364816640, + 183044737781926478, 12921559769912263, 83553932242922001, 96698298841984688, + 281664320227537824, 224233030818578263, 77812932110318774, 169729351013291728, + 164475402723178734, 242780633011249051, 51095111179609125, 19249189591963554, + 221412426221439180, 265700202856282653, 265342254311932308, 241218503498385511, + 255400887248486575, 212083616929812076, 227947034485840579, 268261881651571692, + 104846262373404908, 49690734329496661, 213259196633566308, 186966479726202436, + 282157378232384574, 11272948584603747, 166540426999573480, 50628164001018755, + 65235580992800860, 230664399047956956, 64575592354687978, 40519393736078511, + 108341851194332747, 115426411532008961, 120656817002338193, 234537867870809797, + 12504080415362731, 45083100453836317, 270968267812126657, 93505647407734103, + 252852934678537969, 258758309277167202, 74250882143432077, 141629095984552833]; + +reference_jump_val(exs1024) -> + [2655961906500790629, 17003395417078685063, 10466831598958356428, 7603399148503548021, + 1650550950190587188, 12294992315080723704, 15743995773860389219, 5492181000145247327, + 14118165228742583601, 1024386975263610703, 10124872895886669513, 6445624517813169301, + 6238575554686562601, 14108646153524288915, 11804141635807832816, 8421575378006186238, + 6354993374304550369, 838493020029548163, 14759355804308819469, 12212491527912522022, + 16943204735100571602, 198964074252287588, 7325922870779721649, 15853102065526570574, + 16294058349151823341, 6153379962047409781, 15874031679495957261, 17299265255608442340, + 984658421210027171, 17408042033939375278, 3326465916992232353, 5222817718770538733, + 13262385796795170510, 15648751121811336061, 6718721549566546451, 7353765235619801875, + 16110995049882478788, 14559143407227563441, 4189805181268804683, 10938587948346538224, + 1635025506014383478, 12619562911869525411, 17469465615861488695, 125252234176411528, + 2004192558503448853, 13175467866790974840, 17712272336167363518, 1710549840100880318, + 17486892343528340916, 5337910082227550967, 8333082060923612691, 6284787745504163856, + 8072221024586708290, 6077032673910717705, 11495200863352251610, 11722792537523099594, + 14642059504258647996, 8595733246938141113, 17223366528010341891, 17447739753327015776, + 6149800490736735996, 11155866914574313276, 7123864553063709909, 15982886296520662323, + 5775920250955521517, 8624640108274906072, 8652974210855988961, 8715770416136907275, + 11841689528820039868, 10991309078149220415, 11758038663970841716, 7308750055935299261, + 15939068400245256963, 6920341533033919644, 8017706063646646166, 15814376391419160498, + 13529376573221932937, 16749061963269842448, 14639730709921425830, 3265850480169354066, + 4569394597532719321, 16594515239012200038, 13372824240764466517, 16892840440503406128, + 11260004846380394643, 2441660009097834955, 10566922722880085440, 11463315545387550692, + 5252492021914937692, 10404636333478845345, 11109538423683960387, 5525267334484537655, + 17936751184378118743, 4224632875737239207, 15888641556987476199, 9586888813112229805, + 9476861567287505094, 14909536929239540332, 17996844556292992842, 2699310519182298856]; + +reference_jump_val(exs64) -> [not_implemented]. |