aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorGuilherme Andrade <[email protected]>2017-03-18 19:48:23 +0000
committerGuilherme Andrade <[email protected]>2017-04-20 21:30:53 +0100
commitdff85f3d6fdb4b3453d863bf9208073564a9fcf2 (patch)
tree57bd88d6776d40a89ae975e7f3cb569d783c2154
parent2cc47f22e4cb775422a6a7fb1b94836e7cf51143 (diff)
downloadotp-dff85f3d6fdb4b3453d863bf9208073564a9fcf2.tar.gz
otp-dff85f3d6fdb4b3453d863bf9208073564a9fcf2.tar.bz2
otp-dff85f3d6fdb4b3453d863bf9208073564a9fcf2.zip
rand: Support arbitrary normal distributions
-rw-r--r--lib/stdlib/doc/src/rand.xml23
-rw-r--r--lib/stdlib/src/rand.erl16
-rw-r--r--lib/stdlib/test/rand_SUITE.erl65
3 files changed, 90 insertions, 14 deletions
diff --git a/lib/stdlib/doc/src/rand.xml b/lib/stdlib/doc/src/rand.xml
index 8745e16908..3a5d18dc35 100644
--- a/lib/stdlib/doc/src/rand.xml
+++ b/lib/stdlib/doc/src/rand.xml
@@ -119,6 +119,11 @@ S0 = rand:seed_s(exsplus),
<pre>
{SND0, S2} = rand:normal_s(S1),</pre>
+ <p>Create a normal deviate with mean -3 and variance 0.5:</p>
+
+ <pre>
+{ND0, S3} = rand:normal_s(-3, 0.5, S2),</pre>
+
<note>
<p>This random number generator is not cryptographically
strong. If a strong cryptographic random number generator is
@@ -201,6 +206,15 @@ S0 = rand:seed_s(exsplus),
</func>
<func>
+ <name name="normal" arity="2"/>
+ <fsummary>Return a normal distributed random float.</fsummary>
+ <desc>
+ <p>Returns a normal N(Mean, Variance) deviate float
+ and updates the state in the process dictionary.</p>
+ </desc>
+ </func>
+
+ <func>
<name name="normal_s" arity="1"/>
<fsummary>Return a standard normal distributed random float.</fsummary>
<desc>
@@ -211,6 +225,15 @@ S0 = rand:seed_s(exsplus),
</func>
<func>
+ <name name="normal_s" arity="3"/>
+ <fsummary>Return a normal distributed random float.</fsummary>
+ <desc>
+ <p>Returns, for a specified state, a normal N(Mean, Variance)
+ deviate float and a new state.</p>
+ </desc>
+ </func>
+
+ <func>
<name name="seed" arity="1"/>
<fsummary>Seed random number generator.</fsummary>
<desc>
diff --git a/lib/stdlib/src/rand.erl b/lib/stdlib/src/rand.erl
index 1f457b9e0e..cef83da287 100644
--- a/lib/stdlib/src/rand.erl
+++ b/lib/stdlib/src/rand.erl
@@ -28,7 +28,7 @@
export_seed/0, export_seed_s/1,
uniform/0, uniform/1, uniform_s/1, uniform_s/2,
jump/0, jump/1,
- normal/0, normal_s/1
+ normal/0, normal/2, normal_s/1, normal_s/3
]).
-compile({inline, [exs64_next/1, exsplus_next/1,
@@ -178,6 +178,13 @@ normal() ->
_ = seed_put(Seed),
X.
+%% normal/2: returns a random float with N(μ, σ²) normal distribution
+%% updating the state in the process dictionary.
+
+-spec normal(Mean :: number(), Variance :: number()) -> float().
+normal(Mean, Variance) ->
+ Mean + (math:sqrt(Variance) * normal()).
+
%% normal_s/1: returns a random float with standard normal distribution
%% The Ziggurat Method for generating random variables - Marsaglia and Tsang
%% Paper and reference code: http://www.jstatsoft.org/v05/i08/
@@ -198,6 +205,13 @@ normal_s(State0) ->
false -> normal_s(Idx, Sign, -X, State)
end.
+%% normal_s/3: returns a random float with normal N(μ, σ²) distribution
+
+-spec normal_s(Mean :: number(), Variance :: number(), state()) -> {float(), NewS :: state()}.
+normal_s(Mean, Variance, State0) when Variance > 0 ->
+ {X, State} = normal_s(State0),
+ {Mean + (math:sqrt(Variance) * X), State}.
+
%% =====================================================================
%% Internal functions
diff --git a/lib/stdlib/test/rand_SUITE.erl b/lib/stdlib/test/rand_SUITE.erl
index fe5eaccda5..cb97d27992 100644
--- a/lib/stdlib/test/rand_SUITE.erl
+++ b/lib/stdlib/test/rand_SUITE.erl
@@ -27,6 +27,7 @@
-export([interval_int/1, interval_float/1, seed/1,
api_eq/1, reference/1,
basic_stats_uniform_1/1, basic_stats_uniform_2/1,
+ basic_stats_standard_normal/1,
basic_stats_normal/1,
plugin/1, measure/1,
reference_jump_state/1, reference_jump_procdict/1]).
@@ -52,7 +53,8 @@ all() ->
groups() ->
[{basic_stats, [parallel],
- [basic_stats_uniform_1, basic_stats_uniform_2, basic_stats_normal]},
+ [basic_stats_uniform_1, basic_stats_uniform_2,
+ basic_stats_standard_normal, basic_stats_normal]},
{reference_jump, [parallel],
[reference_jump_state, reference_jump_procdict]}].
@@ -294,12 +296,35 @@ basic_stats_uniform_2(Config) when is_list(Config) ->
|| Alg <- algs()],
ok.
-basic_stats_normal(Config) when is_list(Config) ->
+basic_stats_standard_normal(Config) when is_list(Config) ->
ct:timetrap({minutes,6}), %% valgrind needs a lot of time
- io:format("Testing normal~n",[]),
- [basic_normal_1(?LOOP, rand:seed_s(Alg), 0, 0) || Alg <- algs()],
+ io:format("Testing standard normal~n",[]),
+ IntendedMean = 0,
+ IntendedVariance = 1,
+ [basic_normal_1(?LOOP, IntendedMean, IntendedVariance,
+ rand:seed_s(Alg), 0, 0)
+ || Alg <- algs()],
ok.
+basic_stats_normal(Config) when is_list(Config) ->
+ IntendedMeans = [-1.0e6, -50, -math:pi(), -math:exp(-1),
+ 0.12345678, math:exp(1), 100, 1.0e6],
+ IntendedVariances = [1.0e-6, math:exp(-1), 1, math:pi(), 1.0e6],
+ IntendedMeanVariancePairs =
+ [{Mean, Variance} || Mean <- IntendedMeans,
+ Variance <- IntendedVariances],
+
+ ct:timetrap({minutes, 6 * length(IntendedMeanVariancePairs)}), %% valgrind needs a lot of time
+ lists:foreach(
+ fun ({IntendedMean, IntendedVariance}) ->
+ io:format("Testing normal(~.2f, ~.2f)~n",
+ [float(IntendedMean), float(IntendedVariance)]),
+ [basic_normal_1(?LOOP, IntendedMean, IntendedVariance,
+ rand:seed_s(Alg), 0, 0)
+ || Alg <- algs()]
+ end,
+ IntendedMeanVariancePairs).
+
basic_uniform_1(N, S0, Sum, A0) when N > 0 ->
{X,S} = rand:uniform_s(S0),
I = trunc(X*100),
@@ -339,19 +364,33 @@ basic_uniform_2(0, {#{type:=Alg}, _}, Sum, A) ->
abs(?LOOP div 100 - Max) < 1000 orelse ct:fail({max, Alg, Max}),
ok.
-basic_normal_1(N, S0, Sum, Sq) when N > 0 ->
- {X,S} = rand:normal_s(S0),
- basic_normal_1(N-1, S, X+Sum, X*X+Sq);
-basic_normal_1(0, {#{type:=Alg}, _}, Sum, SumSq) ->
- Mean = Sum / ?LOOP,
- StdDev = math:sqrt((SumSq - (Sum*Sum/?LOOP))/(?LOOP - 1)),
- io:format("~.10w: Average: ~7.4f StdDev ~6.4f~n", [Alg, Mean, StdDev]),
+basic_normal_1(N, IntendedMean, IntendedVariance, S0, StandardSum, StandardSq) when N > 0 ->
+ {X,S} = normal_s(IntendedMean, IntendedVariance, S0),
+ % We now shape X into a standard normal distribution (in case it wasn't already)
+ % in order to minimise the accumulated error on Sum / SumSq;
+ % otherwise said error would prevent us of making a fair judgment on
+ % the overall distribution when targeting large means and variances.
+ StandardX = (X - IntendedMean) / math:sqrt(IntendedVariance),
+ basic_normal_1(N-1, IntendedMean, IntendedVariance, S,
+ StandardX+StandardSum, StandardX*StandardX+StandardSq);
+basic_normal_1(0, _IntendedMean, _IntendedVariance, {#{type:=Alg}, _}, StandardSum, StandardSumSq) ->
+ StandardMean = StandardSum / ?LOOP,
+ StandardVariance = (StandardSumSq - (StandardSum*StandardSum/?LOOP))/(?LOOP - 1),
+ StandardStdDev = math:sqrt(StandardVariance),
+ io:format("~.10w: Standardised Average: ~7.4f, Standardised StdDev ~6.4f~n",
+ [Alg, StandardMean, StandardStdDev]),
%% Verify that the basic statistics are ok
%% be gentle we don't want to see to many failing tests
- abs(Mean) < 0.005 orelse ct:fail({average, Alg, Mean}),
- abs(StdDev - 1.0) < 0.005 orelse ct:fail({stddev, Alg, StdDev}),
+ abs(StandardMean) < 0.005 orelse ct:fail({average, Alg, StandardMean}),
+ abs(StandardStdDev - 1.0) < 0.005 orelse ct:fail({stddev, Alg, StandardStdDev}),
ok.
+normal_s(Mean, Variance, State0) when Mean == 0, Variance == 1 ->
+ % Make sure we're also testing the standard normal interface
+ rand:normal_s(State0);
+normal_s(Mean, Variance, State0) ->
+ rand:normal_s(Mean, Variance, State0).
+
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
%% Test that the user can write algorithms.