diff options
Diffstat (limited to 'src/cowboy_websocket.erl')
-rw-r--r-- | src/cowboy_websocket.erl | 90 |
1 files changed, 49 insertions, 41 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 3c34908..9a2862e 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -18,6 +18,7 @@ -behaviour(cowboy_sub_protocol). -export([upgrade/6]). +-export([takeover/7]). -export([handler_loop/4]). -type terminate_reason() :: normal | stop | timeout @@ -50,7 +51,6 @@ -optional_callbacks([terminate/3]). -record(state, { - env :: cowboy_middleware:env(), socket = undefined :: inet:socket(), transport = undefined :: module(), handler :: module(), @@ -65,25 +65,22 @@ extensions = #{} :: map() }). +%% Stream process. + -spec upgrade(Req, Env, module(), any(), timeout(), run | hibernate) - -> {ok, Req, Env} | {suspend, module(), atom(), [any()]} + -> {ok, Req, Env} when Req::cowboy_req:req(), Env::cowboy_middleware:env(). upgrade(Req, Env, Handler, HandlerState, Timeout, Hibernate) -> - {_, Ref} = lists:keyfind(listener, 1, Env), - ranch:remove_connection(Ref), - [Socket, Transport] = cowboy_req:get([socket, transport], Req), - State = #state{env=Env, socket=Socket, transport=Transport, handler=Handler, - timeout=Timeout, hibernate=Hibernate =:= hibernate}, + State = #state{handler=Handler, timeout=Timeout, + hibernate=Hibernate =:= hibernate}, + %% @todo We need to fail if HTTP/2. try websocket_upgrade(State, Req) of {ok, State2, Req2} -> - websocket_handshake(State2, Req2, HandlerState) + websocket_handshake(State2, Req2, HandlerState, Env) catch _:_ -> - receive - {cowboy_req, resp_sent} -> ok - after 0 -> - _ = cowboy_req:reply(400, Req), - exit(normal) - end + %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection. + cowboy_req:reply(400, Req), + {ok, Req, Env} end. -spec websocket_upgrade(#state{}, Req) @@ -99,14 +96,15 @@ websocket_upgrade(State, Req) -> orelse (IntVersion =:= 13), Key = cowboy_req:header(<<"sec-websocket-key">>, Req), false = Key =:= undefined, - websocket_extensions(State#state{key=Key}, - cowboy_req:set_meta(websocket_version, IntVersion, Req)). + websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion}). -spec websocket_extensions(#state{}, Req) -> {ok, #state{}, Req} when Req::cowboy_req:req(). websocket_extensions(State, Req) -> - [Compress] = cowboy_req:get([resp_compress], Req), - Req2 = cowboy_req:set_meta(websocket_compress, false, Req), + %% @todo Proper options for this. +% [Compress] = cowboy_req:get([resp_compress], Req), + Compress = false, + Req2 = Req#{websocket_compress => false}, case {Compress, cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req2)} of {true, Extensions} when Extensions =/= undefined -> websocket_extensions(State, Req2, Extensions, []); @@ -123,7 +121,7 @@ websocket_extensions(State=#state{extensions=Extensions}, Req, [{<<"permessage-d Opts = #{level => best_compression, mem_level => 8, strategy => default}, case cow_ws:negotiate_permessage_deflate(Params, Extensions, Opts) of {ok, RespExt, Extensions2} -> - Req2 = cowboy_req:set_meta(websocket_compress, true, Req), + Req2 = Req#{websocket_compress => true}, websocket_extensions(State#state{extensions=Extensions2}, Req2, Tail, [<<", ">>, RespExt|RespHeader]); ignore -> @@ -143,33 +141,46 @@ websocket_extensions(State=#state{extensions=Extensions}, Req, [{<<"x-webkit-def websocket_extensions(State, Req, [_|Tail], RespHeader) -> websocket_extensions(State, Req, Tail, RespHeader). --spec websocket_handshake(#state{}, Req, any()) - -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} - when Req::cowboy_req:req(). -websocket_handshake(State=#state{transport=Transport, key=Key}, Req, HandlerState) -> +-spec websocket_handshake(#state{}, Req, any(), Env) + -> {ok, Req, Env} + when Req::cowboy_req:req(), Env::cowboy_middleware:env(). +websocket_handshake(State=#state{transport=Transport, key=Key}, + Req=#{pid := Pid, streamid := StreamID}, HandlerState, Env) -> Challenge = base64:encode(crypto:hash(sha, << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>)), - Req2 = cowboy_req:upgrade_reply(101, [ - {<<"upgrade">>, <<"websocket">>}, - {<<"sec-websocket-accept">>, Challenge} - ], Req), - %% Flush the resp_sent message before moving on. - receive {cowboy_req, resp_sent} -> ok after 0 -> ok end, - State2 = handler_loop_timeout(State), + Headers = #{ + %% @todo Hmm should those be here or in cowboy_http? + <<"connection">> => <<"Upgrade">>, + <<"upgrade">> => <<"websocket">>, + <<"sec-websocket-accept">> => Challenge + }, + Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {Req, State, HandlerState}}}, + {ok, Req, Env}. + +%% Connection process. + +takeover(Parent, Ref, Socket, Transport, Opts, Buffer, {Req0, State=#state{handler=Handler}, HandlerState0}) -> + ranch:remove_connection(Ref), + %% @todo Remove Req from Websocket callbacks. + %% @todo Allow sending a reply from websocket_init. + %% @todo Try/catch. + {ok, Req, HandlerState} = case erlang:function_exported(Handler, websocket_init, 2) of + true -> Handler:websocket_init(Req0, HandlerState0); + false -> {ok, Req0, HandlerState0} + end, + State2 = handler_loop_timeout(State#state{socket=Socket, transport=Transport}), handler_before_loop(State2#state{key=undefined, - messages=Transport:messages()}, Req2, HandlerState, <<>>). + messages=Transport:messages()}, Req, HandlerState, Buffer). -spec handler_before_loop(#state{}, Req, any(), binary()) -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). handler_before_loop(State=#state{ socket=Socket, transport=Transport, hibernate=true}, Req, HandlerState, SoFar) -> Transport:setopts(Socket, [{active, once}]), - {suspend, ?MODULE, handler_loop, - [State#state{hibernate=false}, Req, HandlerState, SoFar]}; + proc_lib:hibernate(?MODULE, handler_loop, + [State#state{hibernate=false}, Req, HandlerState, SoFar]); handler_before_loop(State=#state{socket=Socket, transport=Transport}, Req, HandlerState, SoFar) -> Transport:setopts(Socket, [{active, once}]), @@ -186,7 +197,6 @@ handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) -> -spec handler_loop(#state{}, Req, any(), binary()) -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error}, timeout_ref=TRef}, Req, HandlerState, SoFar) -> @@ -210,7 +220,6 @@ handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error}, -spec websocket_data(#state{}, Req, any(), binary()) -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). websocket_data(State=#state{frag_state=FragState, extensions=Extensions}, Req, HandlerState, Data) -> case cow_ws:parse_header(Data, Extensions, FragState) of @@ -298,7 +307,6 @@ websocket_dispatch(State=#state{socket=Socket, transport=Transport, frag_state=F -spec handler_call(#state{}, Req, any(), binary(), atom(), any(), fun()) -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). handler_call(State=#state{handler=Handler}, Req, HandlerState, RemainingData, Callback, Message, NextState) -> @@ -358,7 +366,7 @@ handler_call(State=#state{handler=Handler}, Req, HandlerState, {mfa, {Handler, Callback, 3}}, {stacktrace, erlang:get_stacktrace()}, {msg, Message}, - {req, cowboy_req:to_list(Req)}, + {req, Req}, {state, HandlerState} ]}) end. @@ -407,7 +415,7 @@ websocket_close(State=#state{socket=Socket, transport=Transport, extensions=Exte -spec handler_terminate(#state{}, Req, any(), terminate_reason()) -> {ok, Req, cowboy_middleware:env()} when Req::cowboy_req:req(). -handler_terminate(#state{env=Env, handler=Handler}, +handler_terminate(#state{handler=Handler}, Req, HandlerState, Reason) -> cowboy_handler:terminate(Reason, Req, HandlerState, Handler), - {ok, Req, [{result, closed}|Env]}. + exit(normal). |