From 25ae2028d6a9ce516b01f0ec126abeab00eb329d Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= <essen@dev-extend.eu>
Date: Mon, 10 Oct 2011 09:09:15 +0200
Subject: Add {shutdown, Req} to websocket_init/3 to fail a websocket upgrade

This change allows application developers to refuse websocket upgrades
by returning {shutdown, Req}. The application can also send a reply
with a custom error before returning from websocket_init/3, otherwise
an error 400 is sent.

Note that right now Cowboy closes the connection immediately. Also note
that neither terminate/3 nor websocket_terminate/3 will be called when
the connection is shutdown by websocket_init/3.
---
 src/cowboy_http_websocket.erl            | 26 +++++++++++++++++++++++---
 test/http_SUITE.erl                      | 24 ++++++++++++++++++++++--
 test/websocket_handler_init_shutdown.erl | 30 ++++++++++++++++++++++++++++++
 3 files changed, 75 insertions(+), 5 deletions(-)
 create mode 100644 test/websocket_handler_init_shutdown.erl

diff --git a/src/cowboy_http_websocket.erl b/src/cowboy_http_websocket.erl
index 61917b4..039e2b6 100644
--- a/src/cowboy_http_websocket.erl
+++ b/src/cowboy_http_websocket.erl
@@ -124,7 +124,9 @@ handler_init(State=#state{handler=Handler, opts=Opts},
 				Req2, HandlerState);
 		{ok, Req2, HandlerState, Timeout, hibernate} ->
 			websocket_handshake(State#state{timeout=Timeout,
-				hibernate=true}, Req2, HandlerState)
+				hibernate=true}, Req2, HandlerState);
+		{shutdown, Req2} ->
+			upgrade_denied(Req2)
 	catch Class:Reason ->
 		upgrade_error(Req),
 		error_logger:error_msg(
@@ -135,9 +137,27 @@ handler_init(State=#state{handler=Handler, opts=Opts},
 	end.
 
 -spec upgrade_error(#http_req{}) -> ok.
-upgrade_error(Req=#http_req{socket=Socket, transport=Transport}) ->
-	{ok, _Req} = cowboy_http_req:reply(400, [], [],
+upgrade_error(Req) ->
+	{ok, Req2} = cowboy_http_req:reply(400, [], [],
 		Req#http_req{resp_state=waiting}),
+	upgrade_terminate(Req2).
+
+%% @see cowboy_http_protocol:ensure_response/1
+-spec upgrade_denied(#http_req{}) -> ok.
+upgrade_denied(Req=#http_req{resp_state=done}) ->
+	upgrade_terminate(Req);
+upgrade_denied(Req=#http_req{resp_state=waiting}) ->
+	{ok, Req2} = cowboy_http_req:reply(400, [], [], Req),
+	upgrade_terminate(Req2);
+upgrade_denied(Req=#http_req{method='HEAD', resp_state=chunks}) ->
+	upgrade_terminate(Req);
+upgrade_denied(Req=#http_req{socket=Socket, transport=Transport,
+		resp_state=chunks}) ->
+	Transport:send(Socket, <<"0\r\n\r\n">>),
+	upgrade_terminate(Req).
+
+-spec upgrade_terminate(#http_req{}) -> ok.
+upgrade_terminate(#http_req{socket=Socket, transport=Transport}) ->
 	Transport:close(Socket).
 
 -spec websocket_handshake(#state{}, #http_req{}, any()) -> ok.
diff --git a/test/http_SUITE.erl b/test/http_SUITE.erl
index 10d26b8..187d5af 100644
--- a/test/http_SUITE.erl
+++ b/test/http_SUITE.erl
@@ -20,7 +20,7 @@
 	init_per_group/2, end_per_group/2]). %% ct.
 -export([chunked_response/1, headers_dupe/1, headers_huge/1,
 	keepalive_nl/1, nc_rand/1, pipeline/1, raw/1,
-	ws0/1, ws8/1, ws8_single_bytes/1,
+	ws0/1, ws8/1, ws8_single_bytes/1, ws8_init_shutdown/1,
 	ws_timeout_hibernate/1]). %% http.
 -export([http_200/1, http_404/1]). %% http and https.
 -export([http_10_hostless/1]). %% misc.
@@ -34,7 +34,7 @@ groups() ->
 	BaseTests = [http_200, http_404],
 	[{http, [], [chunked_response, headers_dupe, headers_huge,
 		keepalive_nl, nc_rand, pipeline, raw,
-		ws0, ws8, ws8_single_bytes,
+		ws0, ws8, ws8_single_bytes, ws8_init_shutdown,
 		ws_timeout_hibernate] ++ BaseTests},
 	{https, [], BaseTests}, {misc, [], [http_10_hostless]}].
 
@@ -95,6 +95,7 @@ init_http_dispatch() ->
 			{[<<"chunked_response">>], chunked_handler, []},
 			{[<<"websocket">>], websocket_handler, []},
 			{[<<"ws_timeout_hibernate">>], ws_timeout_hibernate_handler, []},
+			{[<<"ws_init_shutdown">>], websocket_handler_init_shutdown, []},
 			{[<<"init_shutdown">>], http_handler_init_shutdown, []},
 			{[<<"headers">>, <<"dupe">>], http_handler,
 				[{headers, [{<<"Connection">>, <<"close">>}]}]},
@@ -394,6 +395,25 @@ ws_timeout_hibernate(Config) ->
 	{error, closed} = gen_tcp:recv(Socket, 0, 6000),
 	ok.
 
+ws8_init_shutdown(Config) ->
+	{port, Port} = lists:keyfind(port, 1, Config),
+	{ok, Socket} = gen_tcp:connect("localhost", Port,
+		[binary, {active, false}, {packet, raw}]),
+	ok = gen_tcp:send(Socket, [
+		"GET /ws_init_shutdown HTTP/1.1\r\n"
+		"Host: localhost\r\n"
+		"Connection: Upgrade\r\n"
+		"Upgrade: websocket\r\n"
+		"Sec-WebSocket-Origin: http://localhost\r\n"
+		"Sec-WebSocket-Version: 8\r\n"
+		"Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n"
+		"\r\n"]),
+	{ok, Handshake} = gen_tcp:recv(Socket, 0, 6000),
+	{ok, {http_response, {1, 1}, 403, "Forbidden"}, _Rest}
+		= erlang:decode_packet(http, Handshake, []),
+	{error, closed} = gen_tcp:recv(Socket, 0, 6000),
+	ok.
+
 websocket_headers({ok, http_eoh, Rest}, Acc) ->
 	[Acc, Rest];
 websocket_headers({ok, {http_header, _I, Key, _R, Value}, Rest}, Acc) ->
diff --git a/test/websocket_handler_init_shutdown.erl b/test/websocket_handler_init_shutdown.erl
new file mode 100644
index 0000000..ad76336
--- /dev/null
+++ b/test/websocket_handler_init_shutdown.erl
@@ -0,0 +1,30 @@
+%% Feel free to use, reuse and abuse the code in this file.
+
+-module(websocket_handler_init_shutdown).
+-behaviour(cowboy_http_handler).
+-behaviour(cowboy_http_websocket_handler).
+-export([init/3, handle/2, terminate/2]).
+-export([websocket_init/3, websocket_handle/3,
+	websocket_info/3, websocket_terminate/3]).
+
+init(_Any, _Req, _Opts) ->
+	{upgrade, protocol, cowboy_http_websocket}.
+
+handle(_Req, _State) ->
+	exit(badarg).
+
+terminate(_Req, _State) ->
+	exit(badarg).
+
+websocket_init(_TransportName, Req, _Opts) ->
+	Req2 = cowboy_http_req:reply(403, [], [], Req),
+	{shutdown, Req2}.
+
+websocket_handle(_Frame, _Req, _State) ->
+	exit(badarg).
+
+websocket_info(_Info, _Req, _State) ->
+	exit(badarg).
+
+websocket_terminate(_Reason, _Req, _State) ->
+	exit(badarg).
-- 
cgit v1.2.3