aboutsummaryrefslogtreecommitdiffstats
path: root/src/cowboy_websocket.erl
diff options
context:
space:
mode:
Diffstat (limited to 'src/cowboy_websocket.erl')
-rw-r--r--src/cowboy_websocket.erl90
1 files changed, 49 insertions, 41 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl
index 3c34908..9a2862e 100644
--- a/src/cowboy_websocket.erl
+++ b/src/cowboy_websocket.erl
@@ -18,6 +18,7 @@
-behaviour(cowboy_sub_protocol).
-export([upgrade/6]).
+-export([takeover/7]).
-export([handler_loop/4]).
-type terminate_reason() :: normal | stop | timeout
@@ -50,7 +51,6 @@
-optional_callbacks([terminate/3]).
-record(state, {
- env :: cowboy_middleware:env(),
socket = undefined :: inet:socket(),
transport = undefined :: module(),
handler :: module(),
@@ -65,25 +65,22 @@
extensions = #{} :: map()
}).
+%% Stream process.
+
-spec upgrade(Req, Env, module(), any(), timeout(), run | hibernate)
- -> {ok, Req, Env} | {suspend, module(), atom(), [any()]}
+ -> {ok, Req, Env}
when Req::cowboy_req:req(), Env::cowboy_middleware:env().
upgrade(Req, Env, Handler, HandlerState, Timeout, Hibernate) ->
- {_, Ref} = lists:keyfind(listener, 1, Env),
- ranch:remove_connection(Ref),
- [Socket, Transport] = cowboy_req:get([socket, transport], Req),
- State = #state{env=Env, socket=Socket, transport=Transport, handler=Handler,
- timeout=Timeout, hibernate=Hibernate =:= hibernate},
+ State = #state{handler=Handler, timeout=Timeout,
+ hibernate=Hibernate =:= hibernate},
+ %% @todo We need to fail if HTTP/2.
try websocket_upgrade(State, Req) of
{ok, State2, Req2} ->
- websocket_handshake(State2, Req2, HandlerState)
+ websocket_handshake(State2, Req2, HandlerState, Env)
catch _:_ ->
- receive
- {cowboy_req, resp_sent} -> ok
- after 0 ->
- _ = cowboy_req:reply(400, Req),
- exit(normal)
- end
+ %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection.
+ cowboy_req:reply(400, Req),
+ {ok, Req, Env}
end.
-spec websocket_upgrade(#state{}, Req)
@@ -99,14 +96,15 @@ websocket_upgrade(State, Req) ->
orelse (IntVersion =:= 13),
Key = cowboy_req:header(<<"sec-websocket-key">>, Req),
false = Key =:= undefined,
- websocket_extensions(State#state{key=Key},
- cowboy_req:set_meta(websocket_version, IntVersion, Req)).
+ websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion}).
-spec websocket_extensions(#state{}, Req)
-> {ok, #state{}, Req} when Req::cowboy_req:req().
websocket_extensions(State, Req) ->
- [Compress] = cowboy_req:get([resp_compress], Req),
- Req2 = cowboy_req:set_meta(websocket_compress, false, Req),
+ %% @todo Proper options for this.
+% [Compress] = cowboy_req:get([resp_compress], Req),
+ Compress = false,
+ Req2 = Req#{websocket_compress => false},
case {Compress, cowboy_req:parse_header(<<"sec-websocket-extensions">>, Req2)} of
{true, Extensions} when Extensions =/= undefined ->
websocket_extensions(State, Req2, Extensions, []);
@@ -123,7 +121,7 @@ websocket_extensions(State=#state{extensions=Extensions}, Req, [{<<"permessage-d
Opts = #{level => best_compression, mem_level => 8, strategy => default},
case cow_ws:negotiate_permessage_deflate(Params, Extensions, Opts) of
{ok, RespExt, Extensions2} ->
- Req2 = cowboy_req:set_meta(websocket_compress, true, Req),
+ Req2 = Req#{websocket_compress => true},
websocket_extensions(State#state{extensions=Extensions2},
Req2, Tail, [<<", ">>, RespExt|RespHeader]);
ignore ->
@@ -143,33 +141,46 @@ websocket_extensions(State=#state{extensions=Extensions}, Req, [{<<"x-webkit-def
websocket_extensions(State, Req, [_|Tail], RespHeader) ->
websocket_extensions(State, Req, Tail, RespHeader).
--spec websocket_handshake(#state{}, Req, any())
- -> {ok, Req, cowboy_middleware:env()}
- | {suspend, module(), atom(), [any()]}
- when Req::cowboy_req:req().
-websocket_handshake(State=#state{transport=Transport, key=Key}, Req, HandlerState) ->
+-spec websocket_handshake(#state{}, Req, any(), Env)
+ -> {ok, Req, Env}
+ when Req::cowboy_req:req(), Env::cowboy_middleware:env().
+websocket_handshake(State=#state{transport=Transport, key=Key},
+ Req=#{pid := Pid, streamid := StreamID}, HandlerState, Env) ->
Challenge = base64:encode(crypto:hash(sha,
<< Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>)),
- Req2 = cowboy_req:upgrade_reply(101, [
- {<<"upgrade">>, <<"websocket">>},
- {<<"sec-websocket-accept">>, Challenge}
- ], Req),
- %% Flush the resp_sent message before moving on.
- receive {cowboy_req, resp_sent} -> ok after 0 -> ok end,
- State2 = handler_loop_timeout(State),
+ Headers = #{
+ %% @todo Hmm should those be here or in cowboy_http?
+ <<"connection">> => <<"Upgrade">>,
+ <<"upgrade">> => <<"websocket">>,
+ <<"sec-websocket-accept">> => Challenge
+ },
+ Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {Req, State, HandlerState}}},
+ {ok, Req, Env}.
+
+%% Connection process.
+
+takeover(Parent, Ref, Socket, Transport, Opts, Buffer, {Req0, State=#state{handler=Handler}, HandlerState0}) ->
+ ranch:remove_connection(Ref),
+ %% @todo Remove Req from Websocket callbacks.
+ %% @todo Allow sending a reply from websocket_init.
+ %% @todo Try/catch.
+ {ok, Req, HandlerState} = case erlang:function_exported(Handler, websocket_init, 2) of
+ true -> Handler:websocket_init(Req0, HandlerState0);
+ false -> {ok, Req0, HandlerState0}
+ end,
+ State2 = handler_loop_timeout(State#state{socket=Socket, transport=Transport}),
handler_before_loop(State2#state{key=undefined,
- messages=Transport:messages()}, Req2, HandlerState, <<>>).
+ messages=Transport:messages()}, Req, HandlerState, Buffer).
-spec handler_before_loop(#state{}, Req, any(), binary())
-> {ok, Req, cowboy_middleware:env()}
- | {suspend, module(), atom(), [any()]}
when Req::cowboy_req:req().
handler_before_loop(State=#state{
socket=Socket, transport=Transport, hibernate=true},
Req, HandlerState, SoFar) ->
Transport:setopts(Socket, [{active, once}]),
- {suspend, ?MODULE, handler_loop,
- [State#state{hibernate=false}, Req, HandlerState, SoFar]};
+ proc_lib:hibernate(?MODULE, handler_loop,
+ [State#state{hibernate=false}, Req, HandlerState, SoFar]);
handler_before_loop(State=#state{socket=Socket, transport=Transport},
Req, HandlerState, SoFar) ->
Transport:setopts(Socket, [{active, once}]),
@@ -186,7 +197,6 @@ handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) ->
-spec handler_loop(#state{}, Req, any(), binary())
-> {ok, Req, cowboy_middleware:env()}
- | {suspend, module(), atom(), [any()]}
when Req::cowboy_req:req().
handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error},
timeout_ref=TRef}, Req, HandlerState, SoFar) ->
@@ -210,7 +220,6 @@ handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error},
-spec websocket_data(#state{}, Req, any(), binary())
-> {ok, Req, cowboy_middleware:env()}
- | {suspend, module(), atom(), [any()]}
when Req::cowboy_req:req().
websocket_data(State=#state{frag_state=FragState, extensions=Extensions}, Req, HandlerState, Data) ->
case cow_ws:parse_header(Data, Extensions, FragState) of
@@ -298,7 +307,6 @@ websocket_dispatch(State=#state{socket=Socket, transport=Transport, frag_state=F
-spec handler_call(#state{}, Req, any(), binary(), atom(), any(), fun())
-> {ok, Req, cowboy_middleware:env()}
- | {suspend, module(), atom(), [any()]}
when Req::cowboy_req:req().
handler_call(State=#state{handler=Handler}, Req, HandlerState,
RemainingData, Callback, Message, NextState) ->
@@ -358,7 +366,7 @@ handler_call(State=#state{handler=Handler}, Req, HandlerState,
{mfa, {Handler, Callback, 3}},
{stacktrace, erlang:get_stacktrace()},
{msg, Message},
- {req, cowboy_req:to_list(Req)},
+ {req, Req},
{state, HandlerState}
]})
end.
@@ -407,7 +415,7 @@ websocket_close(State=#state{socket=Socket, transport=Transport, extensions=Exte
-spec handler_terminate(#state{}, Req, any(), terminate_reason())
-> {ok, Req, cowboy_middleware:env()}
when Req::cowboy_req:req().
-handler_terminate(#state{env=Env, handler=Handler},
+handler_terminate(#state{handler=Handler},
Req, HandlerState, Reason) ->
cowboy_handler:terminate(Reason, Req, HandlerState, Handler),
- {ok, Req, [{result, closed}|Env]}.
+ exit(normal).