From 00cc1f385f94823a0684deee001b643091e235b0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= Date: Thu, 26 Sep 2019 13:16:56 +0200 Subject: Add reply_to option to ws_upgrade; remove notowner entirely The reply_to option is also propagated when we switch protocols. --- src/gun.erl | 60 ++++++++++++++++++++++++++++-------------------------------- 1 file changed, 28 insertions(+), 32 deletions(-) (limited to 'src/gun.erl') diff --git a/src/gun.erl b/src/gun.erl index 3154b9b..ddd38c8 100644 --- a/src/gun.erl +++ b/src/gun.erl @@ -142,7 +142,6 @@ ws_opts => ws_opts() }. -export_type([opts/0]). -%% @todo Add an option to disable/enable the notowner behavior. -type connect_destination() :: #{ host := inet:hostname() | inet:ip_address(), @@ -229,6 +228,7 @@ flow => pos_integer(), keepalive => timeout(), protocols => [{binary(), module()}], + reply_to => pid(), silence_pings => boolean() }. -export_type([ws_opts/0]). @@ -447,7 +447,7 @@ close(ServerPid) -> -spec shutdown(pid()) -> ok. shutdown(ServerPid) -> - gen_statem:cast(ServerPid, {shutdown, self()}). + gen_statem:cast(ServerPid, shutdown). %% Requests. @@ -843,7 +843,8 @@ ws_upgrade(ServerPid, Path, Headers) -> ws_upgrade(ServerPid, Path, Headers, Opts) -> ok = gun_ws:check_options(Opts), StreamRef = make_ref(), - gen_statem:cast(ServerPid, {ws_upgrade, self(), StreamRef, Path, Headers, Opts}), + ReplyTo = maps:get(reply_to, Opts, self()), + gen_statem:cast(ServerPid, {ws_upgrade, ReplyTo, StreamRef, Path, Headers, Opts}), StreamRef. %% @todo ws_send/2 will need to be deprecated in favor of a variant with StreamRef. @@ -1011,20 +1012,20 @@ ensure_alpn_sni(Protocols0, TransOpts0, #state{origin_host=OriginHost}) -> end. %% Normal TLS handshake. -tls_handshake(internal, {tls_handshake, HandshakeEvent, Protocols}, +tls_handshake(internal, {tls_handshake, HandshakeEvent, Protocols, ReplyTo}, State0=#state{socket=Socket, transport=gun_tcp}) -> case normal_tls_handshake(Socket, State0, HandshakeEvent, Protocols) of {ok, TLSSocket, NewProtocol, State} -> commands([ {switch_transport, gun_tls, TLSSocket}, - {switch_protocol, NewProtocol} + {switch_protocol, NewProtocol, ReplyTo} ], State); {error, Reason, State} -> commands({error, Reason}, State) end; %% TLS over TLS. tls_handshake(internal, {tls_handshake, - HandshakeEvent0=#{tls_opts := TLSOpts0, timeout := TLSTimeout}, Protocols}, + HandshakeEvent0=#{tls_opts := TLSOpts0, timeout := TLSTimeout}, Protocols, ReplyTo}, State=#state{socket=Socket, transport=Transport, origin_host=OriginHost, origin_port=OriginPort, event_handler=EvHandler, event_handler_state=EvHandlerState0}) -> TLSOpts = ensure_alpn_sni(Protocols, TLSOpts0, State), @@ -1034,20 +1035,20 @@ tls_handshake(internal, {tls_handshake, }, EvHandlerState = EvHandler:tls_handshake_start(HandshakeEvent, EvHandlerState0), {ok, ProxyPid} = gun_tls_proxy:start_link(OriginHost, OriginPort, - TLSOpts, TLSTimeout, Socket, Transport, {HandshakeEvent, Protocols}), + TLSOpts, TLSTimeout, Socket, Transport, {HandshakeEvent, Protocols, ReplyTo}), commands([{switch_transport, gun_tls_proxy, ProxyPid}], State#state{ socket=ProxyPid, transport=gun_tls_proxy, event_handler_state=EvHandlerState}); %% When using gun_tls_proxy we need a separate message to know whether %% the handshake succeeded and whether we need to switch to a different protocol. -tls_handshake(info, {gun_tls_proxy, Socket, {ok, Negotiated}, {HandshakeEvent, Protocols}}, +tls_handshake(info, {gun_tls_proxy, Socket, {ok, Negotiated}, {HandshakeEvent, Protocols, ReplyTo}}, State0=#state{socket=Socket, event_handler=EvHandler, event_handler_state=EvHandlerState0}) -> NewProtocol = protocol_negotiated(Negotiated, Protocols), EvHandlerState = EvHandler:tls_handshake_end(HandshakeEvent#{ socket => Socket, protocol => NewProtocol }, EvHandlerState0), - commands([{switch_protocol, NewProtocol}], State0#state{event_handler_state=EvHandlerState}); -tls_handshake(info, {gun_tls_proxy, Socket, Error = {error, Reason}, {HandshakeEvent, _}}, + commands([{switch_protocol, NewProtocol, ReplyTo}], State0#state{event_handler_state=EvHandlerState}); +tls_handshake(info, {gun_tls_proxy, Socket, Error = {error, Reason}, {HandshakeEvent, _, _}}, State=#state{socket=Socket, event_handler=EvHandler, event_handler_state=EvHandlerState0}) -> EvHandlerState = EvHandler:tls_handshake_end(HandshakeEvent#{ error => Reason @@ -1099,10 +1100,10 @@ connected_data_only(cast, Msg, _) connected_data_only(Type, Event, State) -> handle_common_connected(Type, Event, ?FUNCTION_NAME, State). -connected_ws_only(cast, {ws_send, Owner, Frames}, State=#state{ - owner=Owner, protocol=Protocol=gun_ws, protocol_state=ProtoState, +connected_ws_only(cast, {ws_send, ReplyTo, Frames}, State=#state{ + protocol=Protocol=gun_ws, protocol_state=ProtoState, event_handler=EvHandler, event_handler_state=EvHandlerState0}) -> - {Commands, EvHandlerState} = Protocol:send(Frames, ProtoState, EvHandler, EvHandlerState0), + {Commands, EvHandlerState} = Protocol:ws_send(Frames, ProtoState, ReplyTo, EvHandler, EvHandlerState0), commands(Commands, State#state{event_handler_state=EvHandlerState}); connected_ws_only(cast, Msg, _) when element(1, Msg) =:= headers; element(1, Msg) =:= request; element(1, Msg) =:= data; @@ -1155,22 +1156,22 @@ connected(cast, {connect, ReplyTo, StreamRef, Destination, Headers, InitialFlow} %% Public Websocket interface. %% @todo Maybe make an interface in the protocol module instead of checking on protocol name. %% An interface would also make sure that HTTP/1.0 can't upgrade. -connected(cast, {ws_upgrade, Owner, StreamRef, Path, Headers}, State=#state{opts=Opts}) -> +connected(cast, {ws_upgrade, ReplyTo, StreamRef, Path, Headers}, State=#state{opts=Opts}) -> WsOpts = maps:get(ws_opts, Opts, #{}), - connected(cast, {ws_upgrade, Owner, StreamRef, Path, Headers, WsOpts}, State); -connected(cast, {ws_upgrade, Owner, StreamRef, Path, Headers, WsOpts}, - State=#state{owner=Owner, origin_host=Host, origin_port=Port, + connected(cast, {ws_upgrade, ReplyTo, StreamRef, Path, Headers, WsOpts}, State); +connected(cast, {ws_upgrade, ReplyTo, StreamRef, Path, Headers, WsOpts}, + State=#state{origin_host=Host, origin_port=Port, protocol=Protocol, protocol_state=ProtoState, event_handler=EvHandler, event_handler_state=EvHandlerState0}) when Protocol =:= gun_http -> EvHandlerState1 = EvHandler:ws_upgrade(#{ stream_ref => StreamRef, - reply_to => Owner, %% Only the owner can upgrade the connection at this time. + reply_to => ReplyTo, opts => WsOpts }, EvHandlerState0), %% @todo Can fail if HTTP/1.0. {ProtoState2, EvHandlerState} = Protocol:ws_upgrade(ProtoState, - StreamRef, Host, Port, Path, Headers, WsOpts, + StreamRef, ReplyTo, Host, Port, Path, Headers, WsOpts, EvHandler, EvHandlerState1), {keep_state, State#state{protocol_state=ProtoState2, event_handler_state=EvHandlerState}}; @@ -1272,6 +1273,7 @@ handle_common_connected_no_input(Type, Event, StateName, State) -> %% Common events. handle_common(cast, {set_owner, CurrentOwner, NewOwner}, _, State=#state{owner=CurrentOwner, status={up, CurrentOwnerRef}}) -> + %% @todo This should probably trigger an event. demonitor(CurrentOwnerRef, [flush]), NewOwnerRef = monitor(process, NewOwner), {keep_state, State#state{owner=NewOwner, status={up, NewOwnerRef}}}; @@ -1280,8 +1282,8 @@ handle_common(cast, {set_owner, CurrentOwner, _}, _, #state{owner=CurrentOwner}) CurrentOwner ! {gun_error, self(), {badstate, "The owner of the connection cannot be changed when the connection is shutting down."}}, keep_state_and_state; -handle_common(cast, {shutdown, Owner}, StateName, State=#state{ - owner=Owner, status=Status, socket=Socket, transport=Transport, protocol=Protocol}) -> +handle_common(cast, shutdown, StateName, State=#state{ + status=Status, socket=Socket, transport=Transport, protocol=Protocol}) -> case {Socket, Protocol} of {undefined, _} -> {stop, shutdown}; @@ -1318,12 +1320,6 @@ handle_common(info, {'DOWN', OwnerRef, process, Owner, Reason}, StateName, State end; handle_common({call, From}, _, _, _) -> {keep_state_and_data, {reply, From, {error, bad_call}}}; -%% @todo The ReplyTo patch disabled the notowner behavior. -%% We need to add an option to enforce this behavior if needed. -handle_common(cast, Any, _, #state{owner=Owner}) when element(2, Any) =/= Owner -> - element(2, Any) ! {gun_error, self(), {notowner, - "Operations are restricted to the owner of the connection."}}, - keep_state_and_data; %% We postpone all HTTP/Websocket operations until we are connected. handle_common(cast, _, StateName, _) when StateName =/= connected -> {keep_state_and_data, postpone}; @@ -1381,8 +1377,8 @@ commands([{switch_transport, Transport, Socket}|Tail], State=#state{ commands(Tail, active(State#state{socket=Socket, transport=Transport, messages=Transport:messages(), protocol_state=ProtoState, event_handler_state=EvHandlerState})); -commands([{switch_protocol, Protocol0}], State0=#state{ - owner=Owner, opts=Opts, socket=Socket, transport=Transport, protocol=CurrentProtocol, +commands([{switch_protocol, Protocol0, ReplyTo}], State0=#state{ + opts=Opts, socket=Socket, transport=Transport, protocol=CurrentProtocol, event_handler=EvHandler, event_handler_state=EvHandlerState0}) -> {Protocol, ProtoOpts} = case Protocol0 of {P, PO} -> {protocol_handler(P), PO}; @@ -1392,10 +1388,10 @@ commands([{switch_protocol, Protocol0}], State0=#state{ end, %% When we switch_protocol from socks we must send a gun_socks_up message. _ = case CurrentProtocol of - gun_socks -> Owner ! {gun_socks_up, self(), Protocol:name()}; + gun_socks -> ReplyTo ! {gun_socks_up, self(), Protocol:name()}; _ -> ok end, - {StateName, ProtoState} = Protocol:init(Owner, Socket, Transport, ProtoOpts), + {StateName, ProtoState} = Protocol:init(ReplyTo, Socket, Transport, ProtoOpts), EvHandlerState = EvHandler:protocol_changed(#{protocol => Protocol:name()}, EvHandlerState0), %% We cancel the existing keepalive and, depending on the protocol, %% we enable keepalive again, effectively resetting the timer. @@ -1406,7 +1402,7 @@ commands([{switch_protocol, Protocol0}], State0=#state{ false -> {next_state, StateName, State} end; %% Perform a TLS handshake. -commands([TLSHandshake={tls_handshake, _, _}], State) -> +commands([TLSHandshake={tls_handshake, _, _, _}], State) -> {next_state, tls_handshake, State, {next_event, internal, TLSHandshake}}. -- cgit v1.2.3