From 0078ebf6c5311edc1a07be71fb7a127a175a60fa Mon Sep 17 00:00:00 2001 From: Ingela Anderton Andin Date: Wed, 15 Aug 2018 16:54:03 +0200 Subject: ssl: Adopt distribution over TLS to use new sender process --- lib/ssl/src/inet_tls_dist.erl | 15 ++- lib/ssl/src/ssl_connection.erl | 231 +++++++++++++++++------------------------ lib/ssl/src/ssl_connection.hrl | 1 + lib/ssl/src/ssl_internal.hrl | 2 +- lib/ssl/src/tls_connection.erl | 33 +++++- lib/ssl/src/tls_sender.erl | 94 +++++++++++++++-- 6 files changed, 221 insertions(+), 155 deletions(-) (limited to 'lib/ssl/src') diff --git a/lib/ssl/src/inet_tls_dist.erl b/lib/ssl/src/inet_tls_dist.erl index aa3d7e3f72..ca059603ae 100644 --- a/lib/ssl/src/inet_tls_dist.erl +++ b/lib/ssl/src/inet_tls_dist.erl @@ -69,14 +69,14 @@ is_node_name(Node) -> %% ------------------------------------------------------------------------- -hs_data_common(#sslsocket{pid = DistCtrl} = SslSocket) -> +hs_data_common(#sslsocket{pid = [_, DistCtrl|_]} = SslSocket) -> #hs_data{ f_send = - fun (Ctrl, Packet) when Ctrl == DistCtrl -> + fun (_Ctrl, Packet) -> f_send(SslSocket, Packet) end, f_recv = - fun (Ctrl, Length, Timeout) when Ctrl == DistCtrl -> + fun (_, Length, Timeout) -> f_recv(SslSocket, Length, Timeout) end, f_setopts_pre_nodeup = @@ -175,8 +175,7 @@ mf_getopts(SslSocket, Opts) -> ssl:getopts(SslSocket, Opts). f_handshake_complete(DistCtrl, Node, DHandle) -> - ssl_connection:handshake_complete(DistCtrl, Node, DHandle). - + tls_sender:dist_handshake_complete(DistCtrl, Node, DHandle). setopts_filter(Opts) -> [Opt || {K,_} = Opt <- Opts, @@ -244,7 +243,7 @@ accept_loop(Driver, Listen, Kernel, Socket) -> trace([{active, false},{packet, 4}|Opts]), net_kernel:connecttime()) of - {ok, #sslsocket{pid = DistCtrl} = SslSocket} -> + {ok, #sslsocket{pid = [_, DistCtrl| _]} = SslSocket} -> trace( Kernel ! {accept, self(), DistCtrl, @@ -404,7 +403,7 @@ gen_accept_connection( do_accept( _Driver, AcceptPid, DistCtrl, MyNode, Allowed, SetupTime, Kernel) -> - SslSocket = ssl_connection:get_sslsocket(DistCtrl), + {ok, SslSocket} = tls_sender:dist_tls_socket(DistCtrl), receive {AcceptPid, controller} -> Timer = dist_util:start_timer(SetupTime), @@ -529,7 +528,7 @@ do_setup_connect(Driver, Kernel, Node, Address, Ip, TcpPort, Version, Type, MyNo [binary, {active, false}, {packet, 4}, Driver:family(), nodelay()] ++ Opts, net_kernel:connecttime()) of - {ok, #sslsocket{pid = DistCtrl} = SslSocket} -> + {ok, #sslsocket{pid = [_, DistCtrl| _]} = SslSocket} -> _ = monitor_pid(DistCtrl), ok = ssl:controlling_process(SslSocket, self()), HSData0 = hs_data_common(SslSocket), diff --git a/lib/ssl/src/ssl_connection.erl b/lib/ssl/src/ssl_connection.erl index d83cc7225f..b17b7fbffe 100644 --- a/lib/ssl/src/ssl_connection.erl +++ b/lib/ssl/src/ssl_connection.erl @@ -64,13 +64,13 @@ %% General gen_statem state functions with extra callback argument %% to determine if it is an SSL/TLS or DTLS gen_statem machine -export([init/4, error/4, hello/4, user_hello/4, abbreviated/4, certify/4, cipher/4, - connection/4, death_row/4, downgrade/4]). + connection/4, death_row/4, death_row/2, downgrade/4]). %% gen_statem callbacks -export([terminate/3, format_status/2]). %% Erlang Distribution export --export([get_sslsocket/1, handshake_complete/3]). +-export([get_sslsocket/1, dist_handshake_complete/2]). %%==================================================================== %% Setup @@ -318,8 +318,8 @@ internal_renegotiation(ConnectionPid, #{current_write := WriteState}) -> get_sslsocket(ConnectionPid) -> call(ConnectionPid, get_sslsocket). -handshake_complete(ConnectionPid, Node, DHandle) -> - call(ConnectionPid, {handshake_complete, Node, DHandle}). +dist_handshake_complete(ConnectionPid, DHandle) -> + gen_statem:cast(ConnectionPid, {dist_handshake_complete, DHandle}). %%-------------------------------------------------------------------- -spec prf(pid(), binary() | 'master_secret', binary(), @@ -452,50 +452,36 @@ read_application_data(Data, #state{user_application = {_Mon, Pid}, end, case get_data(SOpts, BytesToRead, Buffer1) of {ok, ClientData, Buffer} -> % Send data - case State0 of - #state{ - ssl_options = #ssl_options{erl_dist = true}, - protocol_specific = #{d_handle := DHandle}} -> - State = - State0#state{ - user_data_buffer = Buffer, - bytes_to_read = undefined}, - try erlang:dist_ctrl_put_data(DHandle, ClientData) of - _ - when SOpts#socket_options.active =:= false; - Buffer =:= <<>> -> - %% Passive mode, wait for active once or recv - %% Active and empty, get more data - Connection:next_record_if_active(State); - _ -> %% We have more data - read_application_data(<<>>, State) - catch error:_ -> - death_row(State, disconnect) - end; - _ -> - SocketOpt = - deliver_app_data(Connection:pids(State0), - Transport, Socket, SOpts, - ClientData, Pid, RecvFrom, Tracker, Connection), - cancel_timer(Timer), - State = - State0#state{ - user_data_buffer = Buffer, + #state{ssl_options = #ssl_options{erl_dist = Dist}, + erl_dist_data = DistData} = State0, + case Dist andalso is_dist_up(DistData) of + true -> + dist_app_data(ClientData, State0#state{user_data_buffer = Buffer, + bytes_to_read = undefined}); + _ -> + SocketOpt = + deliver_app_data(Connection:pids(State0), + Transport, Socket, SOpts, + ClientData, Pid, RecvFrom, Tracker, Connection), + cancel_timer(Timer), + State = + State0#state{ + user_data_buffer = Buffer, start_or_recv_from = undefined, timer = undefined, bytes_to_read = undefined, socket_options = SocketOpt - }, - if - SocketOpt#socket_options.active =:= false; - Buffer =:= <<>> -> - %% Passive mode, wait for active once or recv + }, + if + SocketOpt#socket_options.active =:= false; + Buffer =:= <<>> -> + %% Passive mode, wait for active once or recv %% Active and empty, get more data - Connection:next_record_if_active(State); - true -> %% We have more data - read_application_data(<<>>, State) - end - end; + Connection:next_record_if_active(State); + true -> %% We have more data + read_application_data(<<>>, State) + end + end; {more, Buffer} -> % no reply, we need more data Connection:next_record(State0#state{user_data_buffer = Buffer}); {passive, Buffer} -> @@ -505,6 +491,35 @@ read_application_data(Data, #state{user_application = {_Mon, Pid}, Transport, Socket, SOpts, Buffer1, Pid, RecvFrom, Tracker, Connection), stop(normal, State0) end. + +dist_app_data(ClientData, #state{protocol_cb = Connection, + erl_dist_data = #{dist_handle := undefined, + dist_buffer := DistBuff} = DistData} = State) -> + Connection:next_record_if_active(State#state{erl_dist_data = DistData#{dist_buffer => [ClientData, DistBuff]}}); +dist_app_data(ClientData, #state{erl_dist_data = #{dist_handle := DHandle, + dist_buffer := DistBuff} = ErlDistData, + protocol_cb = Connection, + user_data_buffer = Buffer, + socket_options = SOpts} = State) -> + Data = merge_dist_data(DistBuff, ClientData), + try erlang:dist_ctrl_put_data(DHandle, Data) of + _ when SOpts#socket_options.active =:= false; + Buffer =:= <<>> -> + %% Passive mode, wait for active once or recv + %% Active and empty, get more data + Connection:next_record_if_active(State#state{erl_dist_data = ErlDistData#{dist_buffer => <<>>}}); + _ -> %% We have more data + read_application_data(<<>>, State) + catch error:_ -> + death_row(State, disconnect) + end. + +merge_dist_data(<<>>, ClientData) -> + ClientData; +merge_dist_data(DistBuff, <<>>) -> + DistBuff; +merge_dist_data(DistBuff, ClientData) -> + [DistBuff, ClientData]. %%==================================================================== %% Help functions for tls|dtls_connection.erl %%==================================================================== @@ -604,12 +619,6 @@ init({call, From}, {start, {Opts, EmOpts}, Timeout}, socket_options = SockOpts} = State0, Connection) -> try SslOpts = ssl:handle_options(Opts, OrigSSLOptions), - case SslOpts of - #ssl_options{erl_dist = true} -> - process_flag(priority, max); - _ -> - ok - end, State = ssl_config(SslOpts, Role, State0), init({call, From}, {start, Timeout}, State#state{ssl_options = SslOpts, @@ -1045,25 +1054,6 @@ connection({call, From}, negotiated_protocol, #state{negotiated_protocol = SelectedProtocol} = State, _) -> hibernate_after(?FUNCTION_NAME, State, [{reply, From, {ok, SelectedProtocol}}]); -connection({call, From}, {handshake_complete, _Node, DHandle}, - #state{ssl_options = #ssl_options{erl_dist = true}, - socket_options = SockOpts, - protocol_specific = ProtocolSpecific} = State, Connection) -> - %% From now on we execute on normal priority - process_flag(priority, normal), - try erlang:dist_ctrl_get_data_notification(DHandle) of - _ -> - NewState = - State#state{ - socket_options = - SockOpts#socket_options{active = true}, - protocol_specific = - ProtocolSpecific#{d_handle => DHandle}}, - {Record, NewerState} = Connection:next_record_if_active(NewState), - Connection:next_event(connection, Record, NewerState, [{reply, From, ok}]) - catch error:_ -> - death_row(State, disconnect) - end; connection({call, From}, Msg, State, Connection) -> handle_call(Msg, From, ?FUNCTION_NAME, State, Connection); connection(cast, {internal_renegotiate, WriteState}, #state{protocol_cb = Connection, @@ -1071,28 +1061,18 @@ connection(cast, {internal_renegotiate, WriteState}, #state{protocol_cb = Connec = State, Connection) -> Connection:renegotiate(State#state{renegotiation = {true, internal}, connection_states = ConnectionStates#{current_write => WriteState}}, []); -connection(info, dist_data = Msg, #state{ssl_options = #ssl_options{erl_dist = true}, - protocol_specific = #{d_handle := DHandle}} = State, _) -> - eat_msgs(Msg), - try send_dist_data(?FUNCTION_NAME, State, DHandle, []) - catch error:_ -> - death_row(State, disconnect) - end; -connection(info, {send, From, Ref, Data}, #state{ssl_options = #ssl_options{erl_dist = true}, - protocol_specific = #{d_handle := _}}, _) -> - %% This is for testing only! - %% - %% Needed by some OTP distribution - %% test suites... - From ! {Ref, ok}, - {keep_state_and_data, - [{next_event, {call, {self(), undefined}}, - {application_data, iolist_to_binary(Data)}}]}; -connection(info, tick = Msg, #state{ssl_options = #ssl_options{erl_dist = true}, - protocol_specific = #{d_handle := _}},_) -> - eat_msgs(Msg), - {keep_state_and_data, - [{next_event, {call, {self(), undefined}}, {application_data, <<>>}}]}; +connection(cast, {dist_handshake_complete, DHandle}, + #state{ssl_options = #ssl_options{erl_dist = true}, + erl_dist_data = ErlDistData, + socket_options = SockOpts} = State0, Connection) -> + process_flag(priority, normal), + State1 = + State0#state{ + socket_options = + SockOpts#socket_options{active = true}, + erl_dist_data = ErlDistData#{dist_handle => DHandle}}, + {Record, State} = dist_app_data(<<>>, State1), + Connection:next_event(connection, Record, State); connection(info, Msg, State, _) -> handle_info(Msg, ?FUNCTION_NAME, State); connection(internal, {recv, _}, State, Connection) -> @@ -1107,13 +1087,10 @@ connection(Type, Msg, State, Connection) -> %%-------------------------------------------------------------------- %% We just wait for the owner to die which triggers the monitor, %% or the socket may die too -death_row( - info, {'DOWN', MonitorRef, _, _, Reason}, - #state{user_application={MonitorRef,_Pid}}, - _) -> +death_row(info, {'DOWN', MonitorRef, _, _, Reason}, + #state{user_application={MonitorRef,_Pid}},_) -> {stop, {shutdown, Reason}}; -death_row( - info, {'EXIT', Socket, Reason}, #state{socket = Socket}, _) -> +death_row(info, {'EXIT', Socket, Reason}, #state{socket = Socket}, _) -> {stop, {shutdown, Reason}}; death_row(state_timeout, Reason, _State, _Connection) -> {stop, {shutdown,Reason}}; @@ -1176,7 +1153,14 @@ handle_common_event(internal, {application_data, Data}, StateName, State0, Conne {stop, _, _} = Stop-> Stop; {Record, State} -> - Connection:next_event(StateName, Record, State) + case Connection:next_event(StateName, Record, State) of + {next_state, StateName, State} -> + hibernate_after(StateName, State, []); + {next_state, StateName, State, Actions} -> + hibernate_after(StateName, State, Actions); + {stop, _, _} = Stop -> + Stop + end end; handle_common_event(internal, #change_cipher_spec{type = <<1>>}, StateName, #state{negotiated_version = Version} = State, _) -> @@ -1262,12 +1246,8 @@ handle_call({set_opts, Opts0}, From, StateName, handle_call(renegotiate, From, StateName, _, _) when StateName =/= connection -> {keep_state_and_data, [{reply, From, {error, already_renegotiating}}]}; -handle_call( - get_sslsocket, From, _StateName, - #state{transport_cb = Transport, socket = Socket, tracker = Tracker}, - Connection) -> - SslSocket = - Connection:socket(self(), Transport, Socket, Connection, Tracker), +handle_call(get_sslsocket, From, _StateName, State, Connection) -> + SslSocket = Connection:socket(State), {keep_state_and_data, [{reply, From, SslSocket}]}; handle_call({prf, Secret, Label, Seed, WantedLength}, From, _, @@ -1316,23 +1296,18 @@ handle_info({ErrorTag, Socket, Reason}, StateName, #state{socket = Socket, handle_normal_shutdown(?ALERT_REC(?FATAL, ?CLOSE_NOTIFY), StateName, State), stop(normal, State); -handle_info( - {'DOWN', MonitorRef, _, _, Reason}, _, - #state{ - user_application = {MonitorRef, _Pid}, - ssl_options = #ssl_options{erl_dist = true}}) -> +handle_info({'DOWN', MonitorRef, _, _, Reason}, _, + #state{user_application = {MonitorRef, _Pid}, + ssl_options = #ssl_options{erl_dist = true}}) -> {stop, {shutdown, Reason}}; -handle_info( - {'DOWN', MonitorRef, _, _, _}, _, - #state{user_application = {MonitorRef, _Pid}}) -> +handle_info({'DOWN', MonitorRef, _, _, _}, _, + #state{user_application = {MonitorRef, _Pid}}) -> {stop, normal}; -handle_info( - {'EXIT', Pid, _Reason}, StateName, - #state{user_application = {_MonitorRef, Pid}} = State) -> +handle_info({'EXIT', Pid, _Reason}, StateName, + #state{user_application = {_MonitorRef, Pid}} = State) -> %% It seems the user application has linked to us %% - ignore that and let the monitor handle this {next_state, StateName, State}; - %%% So that terminate will be run when supervisor issues shutdown handle_info({'EXIT', _Sup, shutdown}, _StateName, State) -> stop(shutdown, State); @@ -1380,7 +1355,7 @@ terminate({shutdown, transport_closed} = Reason, socket = Socket, transport_cb = Transport} = State) -> handle_trusted_certs_db(State), Connection:close(Reason, Socket, Transport, undefined, undefined); -terminate({shutdown, own_alert}, _StateName, #state{%%send_queue = SendQueue, +terminate({shutdown, own_alert}, _StateName, #state{ protocol_cb = Connection, socket = Socket, transport_cb = Transport} = State) -> @@ -2753,25 +2728,6 @@ new_emulated([], EmOpts) -> new_emulated(NewEmOpts, _) -> NewEmOpts. %%---------------Erlang distribution -------------------------------------- - -send_dist_data(StateName, State, DHandle, Acc) -> - case erlang:dist_ctrl_get_data(DHandle) of - none -> - erlang:dist_ctrl_get_data_notification(DHandle), - hibernate_after(StateName, State, lists:reverse(Acc)); - Data -> - send_dist_data( - StateName, State, DHandle, - [{next_event, {call, {self(), undefined}}, {application_data, Data}} - |Acc]) - end. - -%% Overload mitigation -eat_msgs(Msg) -> - receive Msg -> eat_msgs(Msg) - after 0 -> ok - end. - %% When acting as distribution controller map the exit reason %% to follow the documented nodedown_reason for net_kernel stop(Reason, State) -> @@ -2791,3 +2747,8 @@ erl_dist_stop_reason( end; erl_dist_stop_reason(Reason, _State) -> Reason. + +is_dist_up(#{dist_handle := Handle}) when Handle =/= undefined -> + true; +is_dist_up(_) -> + false. diff --git a/lib/ssl/src/ssl_connection.hrl b/lib/ssl/src/ssl_connection.hrl index 504a141a4a..66e3182313 100644 --- a/lib/ssl/src/ssl_connection.hrl +++ b/lib/ssl/src/ssl_connection.hrl @@ -75,6 +75,7 @@ cert_db_ref :: certdb_ref() | 'undefined', bytes_to_read :: undefined | integer(), %% bytes to read in passive mode user_data_buffer :: undefined | binary() | secret_printout(), + erl_dist_data = #{} :: map(), renegotiation :: undefined | {boolean(), From::term() | internal | peer}, start_or_recv_from :: term(), timer :: undefined | reference(), % start_or_recive_timer diff --git a/lib/ssl/src/ssl_internal.hrl b/lib/ssl/src/ssl_internal.hrl index ae1c3ea47c..fd246e2550 100644 --- a/lib/ssl/src/ssl_internal.hrl +++ b/lib/ssl/src/ssl_internal.hrl @@ -120,7 +120,7 @@ %% undefined if not hibernating, or number of ms of %% inactivity after which ssl_connection will go into %% hibernation - hibernate_after :: timeout(), + hibernate_after :: timeout(), %% This option should only be set to true by inet_tls_dist erl_dist = false :: boolean(), alpn_advertised_protocols = undefined :: [binary()] | undefined , diff --git a/lib/ssl/src/tls_connection.erl b/lib/ssl/src/tls_connection.erl index ef52421523..8277569281 100644 --- a/lib/ssl/src/tls_connection.erl +++ b/lib/ssl/src/tls_connection.erl @@ -69,6 +69,9 @@ %% gen_statem callbacks -export([callback_mode/0, terminate/3, code_change/4, format_status/2]). + +-define(DIST_CNTRL_SPAWN_OPTS, [{priority, max}]). + %%==================================================================== %% Internal application API %%==================================================================== @@ -93,7 +96,7 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} = User, {CbModule, _,_, _} = CbInfo, Timeout) -> try - {ok, Sender} = tls_sender:start(), + {ok, Sender} = tls_sender:start([{spawn_opt, ?DIST_CNTRL_SPAWN_OPTS}]), {ok, Pid} = tls_connection_sup:start_child_dist([Role, Sender, Host, Port, Socket, Opts, User, CbInfo]), {ok, SslSocket} = ssl_connection:socket_control(?MODULE, Socket, [Pid, Sender], CbModule, Tracker), @@ -113,8 +116,14 @@ start_fsm(Role, Host, Port, Socket, {#ssl_options{erl_dist = true},_, Tracker} = start_link(Role, Sender, Host, Port, Socket, Options, User, CbInfo) -> {ok, proc_lib:spawn_link(?MODULE, init, [[Role, Sender, Host, Port, Socket, Options, User, CbInfo]])}. -init([Role, Sender, Host, Port, Socket, Options, User, CbInfo]) -> +init([Role, Sender, Host, Port, Socket, {SslOpts, _, _} = Options, User, CbInfo]) -> process_flag(trap_exit, true), + case SslOpts#ssl_options.erl_dist of + true -> + process_flag(priority, max); + _ -> + ok + end, State0 = #state{protocol_specific = Map} = initial_state(Role, Sender, Host, Port, Socket, Options, User, CbInfo), try @@ -646,9 +655,11 @@ code_change(_OldVsn, StateName, State, _) -> %%-------------------------------------------------------------------- initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Tracker}, User, {CbModule, DataTag, CloseTag, ErrorTag}) -> - #ssl_options{beast_mitigation = BeastMitigation} = SSLOptions, + #ssl_options{beast_mitigation = BeastMitigation, + erl_dist = IsErlDist} = SSLOptions, ConnectionStates = tls_record:init_connection_states(Role, BeastMitigation), + ErlDistData = erl_dist_data(IsErlDist), SessionCacheCb = case application:get_env(ssl, session_cb) of {ok, Cb} when is_atom(Cb) -> Cb; @@ -670,6 +681,7 @@ initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trac host = Host, port = Port, socket = Socket, + erl_dist_data = ErlDistData, connection_states = ConnectionStates, protocol_buffers = #protocol_buffers{}, user_application = {UserMonitor, User}, @@ -684,8 +696,16 @@ initial_state(Role, Sender, Host, Port, Socket, {SSLOptions, SocketOptions, Trac protocol_specific = #{sender => {SendMonitor, Sender}} }. -initialize_tls_sender(#state{socket = Socket, +erl_dist_data(true) -> + #{dist_handle => undefined, + dist_buffer => <<>>}; +erl_dist_data(false) -> + #{}. + +initialize_tls_sender(#state{role = Role, + socket = Socket, socket_options = SockOpts, + tracker = Tracker, protocol_cb = Connection, transport_cb = Transport, negotiated_version = Version, @@ -693,8 +713,10 @@ initialize_tls_sender(#state{socket = Socket, connection_states = #{current_write := ConnectionWriteState}, protocol_specific = #{sender := {_, Sender}}}) -> Init = #{current_write => ConnectionWriteState, + role => Role, socket => Socket, socket_options => SockOpts, + tracker => Tracker, protocol_cb => Connection, transport_cb => Transport, negotiated_version => Version, @@ -772,6 +794,9 @@ handle_info({CloseTag, Socket}, StateName, %% and then receive the final message. next_event(StateName, no_record, State) end; +handle_info({'DOWN', Mon, _, _, _}, _, #state{ssl_options = #ssl_options{erl_dist = true}, + protocol_specific = #{sender:= {Mon, _}}} = State) -> + ssl_connection:death_row(State, disconnect); handle_info({'DOWN', Mon, _, _, Reason}, _, #state{protocol_specific = #{sender:= {Mon, _}}} = State) -> {stop, {shudown, sender_died, Reason}, State}; handle_info(Msg, StateName, State) -> diff --git a/lib/ssl/src/tls_sender.erl b/lib/ssl/src/tls_sender.erl index 4aeb13284f..2746d89048 100644 --- a/lib/ssl/src/tls_sender.erl +++ b/lib/ssl/src/tls_sender.erl @@ -27,8 +27,8 @@ -include("ssl_handshake.hrl"). %% API --export([start/0, initialize/2, send_data/2, send_alert/2, renegotiate/1, - update_connection_state/3]). +-export([start/0, start/1, initialize/2, send_data/2, send_alert/2, renegotiate/1, + update_connection_state/3, dist_tls_socket/1, dist_handshake_complete/3]). %% gen_statem callbacks -export([callback_mode/0, init/1, terminate/3, code_change/4]). @@ -38,13 +38,16 @@ -record(data, {connection_pid, connection_states = #{}, + role, socket, socket_options, + tracker, protocol_cb, transport_cb, negotiated_version, renegotiate_at, - connection_monitor + connection_monitor, + dist_handle }). %%%=================================================================== @@ -53,9 +56,24 @@ -spec start() -> {ok, Pid :: pid()} | ignore | {error, Error :: term()}. +-spec start(list()) -> {ok, Pid :: pid()} | + ignore | + {error, Error :: term()}. + +%% Description: Start sender process to avoid dead lock that +%% may happen when a socket is busy (busy port) and the +%% same process is sending and receiving +%%-------------------------------------------------------------------- start() -> gen_statem:start(?MODULE, [], []). +start(SpawnOpts) -> + gen_statem:start_link(?MODULE, [], SpawnOpts). +%%-------------------------------------------------------------------- +-spec initialize(pid(), map()) -> ok. +%% Description: So TLS connection process can initialize it sender +%% process. +%%-------------------------------------------------------------------- initialize(Pid, InitMsg) -> gen_statem:call(Pid, {self(), InitMsg}). @@ -82,6 +100,12 @@ renegotiate(Pid) -> update_connection_state(Pid, NewState, Version) -> gen_statem:cast(Pid, {new_write, NewState, Version}). +dist_handshake_complete(ConnectionPid, Node, DHandle) -> + gen_statem:call(ConnectionPid, {dist_handshake_complete, Node, DHandle}). + +dist_tls_socket(Pid) -> + gen_statem:call(Pid, dist_get_tls_socket). + %%%=================================================================== %%% gen_statem callbacks %%%=================================================================== @@ -105,8 +129,10 @@ init(_) -> gen_statem:event_handler_result(atom()). %%-------------------------------------------------------------------- init({call, From}, {Pid, #{current_write := WriteState, + role := Role, socket := Socket, socket_options := SockOpts, + tracker := Tracker, protocol_cb := Connection, transport_cb := Transport, negotiated_version := Version, @@ -118,8 +144,10 @@ init({call, From}, {Pid, #{current_write := WriteState, connection_monitor = Monitor, connection_states = ConnectionStates#{current_write => WriteState}, + role = Role, socket = Socket, socket_options = SockOpts, + tracker = Tracker, protocol_cb = Connection, transport_cb = Transport, negotiated_version = Version, @@ -137,13 +165,28 @@ connection({call, From}, renegotiate, #data{connection_states = #{current_write := Write}} = StateData) -> {next_state, handshake, StateData, [{reply, From, {ok, Write}}]}; connection({call, From}, {application_data, AppData}, - #data{socket_options = SockOpts} = StateData) -> + #data{socket_options = SockOpts} = StateData) -> case encode_packet(AppData, SockOpts) of {error, _} = Error -> {next_state, ?FUNCTION_NAME, StateData, [{reply, From, Error}]}; Data -> send_application_data(Data, From, ?FUNCTION_NAME, StateData) end; +connection({call, From}, dist_get_tls_socket, + #data{protocol_cb = Connection, + transport_cb = Transport, + socket = Socket, + connection_pid = Pid, + tracker = Tracker} = StateData) -> + TLSSocket = Connection:socket([Pid, self()], Transport, Socket, Connection, Tracker), + {next_state, ?FUNCTION_NAME, StateData, [{reply, From, {ok, TLSSocket}}]}; +connection({call, From}, {dist_handshake_complete, _Node, DHandle}, #data{connection_pid = Pid} = StateData) -> + ok = erlang:dist_ctrl_input_handler(DHandle, Pid), + ok = ssl_connection:dist_handshake_complete(Pid, DHandle), + %% From now on we execute on normal priority + process_flag(priority, normal), + Events = dist_data_events(DHandle, []), + {next_state, ?FUNCTION_NAME, StateData#data{dist_handle = DHandle}, [{reply, From, ok} | Events]}; connection(cast, #alert{} = Alert, StateData0) -> StateData = send_tls_alert(Alert, StateData0), {next_state, ?FUNCTION_NAME, StateData}; @@ -153,6 +196,23 @@ connection(cast, {new_write, WritesState, Version}, StateData#data{connection_states = ConnectionStates0#{current_write => WritesState}, negotiated_version = Version}}; +connection(info, dist_data, #data{dist_handle = DHandle} = StateData) -> + Events = dist_data_events(DHandle, []), + {next_state, ?FUNCTION_NAME, StateData, Events}; +connection(info, tick, StateData) -> + consume_ticks(), + {next_state, ?FUNCTION_NAME, StateData, + [{next_event, {call, {self(), undefined}}, + {application_data, <<>>}}]}; +connection(info, {send, From, Ref, Data}, _StateData) -> + %% This is for testing only! + %% + %% Needed by some OTP distribution + %% test suites... + From ! {Ref, ok}, + {keep_state_and_data, + [{next_event, {call, {self(), undefined}}, + {application_data, iolist_to_binary(Data)}}]}; connection(info, Msg, StateData) -> handle_info(Msg, ?FUNCTION_NAME, StateData). %%-------------------------------------------------------------------- @@ -209,9 +269,10 @@ send_tls_alert(Alert, #data{negotiated_version = Version, Connection:send(Transport, Socket, BinMsg), StateData0#data{connection_states = ConnectionStates}. -send_application_data(Data, {FromPid, _} = From, StateName, +send_application_data(Data, From, StateName, #data{connection_pid = Pid, socket = Socket, + dist_handle = DistHandle, negotiated_version = Version, protocol_cb = Connection, transport_cb = Transport, @@ -227,9 +288,9 @@ send_application_data(Data, {FromPid, _} = From, StateName, Connection:encode_data(Data, Version, ConnectionStates0), StateData = StateData0#data{connection_states = ConnectionStates}, case Connection:send(Transport, Socket, Msgs) of - ok when FromPid =:= Pid -> + ok when DistHandle =/= undefined -> {next_state, StateName, StateData, []}; - Error when FromPid =:= Pid -> + Error when DistHandle =/= undefined -> ssl_connection:stop({shutdown, Error}, StateData); ok -> {next_state, StateName, StateData, [{reply, From, ok}]}; @@ -279,3 +340,22 @@ call(FsmPid, Event) -> exit:{{shutdown, _},_} -> {error, closed} end. + +%%---------------Erlang distribution -------------------------------------- + +dist_data_events(DHandle, Events) -> + case erlang:dist_ctrl_get_data(DHandle) of + none -> + erlang:dist_ctrl_get_data_notification(DHandle), + lists:reverse(Events); + Data -> + Event = {next_event, {call, {self(), undefined}}, {application_data, Data}}, + dist_data_events(DHandle, [Event | Events]) + end. + +consume_ticks() -> + receive tick -> + consume_ticks() + after 0 -> + ok + end. -- cgit v1.2.3