From 5ef4a15b48bfc1b5ca867b893b7cbd1b535175f7 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= Date: Mon, 3 Dec 2012 14:13:46 +0100 Subject: Allow passing the Req and an updated Opts when upgrading protocols --- src/cowboy_protocol.erl | 4 +++- test/ws_SUITE.erl | 37 +++++++++++++++++++++++++++++++++-- test/ws_upgrade_with_opts_handler.erl | 28 ++++++++++++++++++++++++++ 3 files changed, 66 insertions(+), 3 deletions(-) create mode 100644 test/ws_upgrade_with_opts_handler.erl diff --git a/src/cowboy_protocol.erl b/src/cowboy_protocol.erl index 901a43c..cbedbfa 100644 --- a/src/cowboy_protocol.erl +++ b/src/cowboy_protocol.erl @@ -491,7 +491,9 @@ handler_init(Req, State=#state{transport=Transport}, Handler, Opts) -> handler_terminate(Req2, Handler, HandlerState); %% @todo {upgrade, transport, Module} {upgrade, protocol, Module} -> - upgrade_protocol(Req, State, Handler, Opts, Module) + upgrade_protocol(Req, State, Handler, Opts, Module); + {upgrade, protocol, Module, Req2, Opts2} -> + upgrade_protocol(Req2, State, Handler, Opts2, Module) catch Class:Reason -> error_terminate(500, State), error_logger:error_msg( diff --git a/test/ws_SUITE.erl b/test/ws_SUITE.erl index b63f41c..5047d97 100644 --- a/test/ws_SUITE.erl +++ b/test/ws_SUITE.erl @@ -35,6 +35,7 @@ -export([ws_send_many/1]). -export([ws_text_fragments/1]). -export([ws_timeout_hibernate/1]). +-export([ws_upgrade_with_opts/1]). %% ct. @@ -52,7 +53,8 @@ groups() -> ws_send_close_payload, ws_send_many, ws_text_fragments, - ws_timeout_hibernate + ws_timeout_hibernate, + ws_upgrade_with_opts ], [{ws, [], BaseTests}]. @@ -107,7 +109,9 @@ init_dispatch() -> {close, <<"some text!">>}, {text, <<"won't be received">>}]} ]}, - {[<<"ws_timeout_hibernate">>], ws_timeout_hibernate_handler, []} + {[<<"ws_timeout_hibernate">>], ws_timeout_hibernate_handler, []}, + {[<<"ws_upgrade_with_opts">>], ws_upgrade_with_opts_handler, + <<"failure">>} ]} ]. @@ -502,6 +506,35 @@ ws_timeout_hibernate(Config) -> {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. +ws_upgrade_with_opts(Config) -> + {port, Port} = lists:keyfind(port, 1, Config), + {ok, Socket} = gen_tcp:connect("localhost", Port, + [binary, {active, false}, {packet, raw}]), + ok = gen_tcp:send(Socket, [ + "GET /ws_upgrade_with_opts HTTP/1.1\r\n" + "Host: localhost\r\n" + "Connection: Upgrade\r\n" + "Upgrade: websocket\r\n" + "Sec-WebSocket-Origin: http://localhost\r\n" + "Sec-WebSocket-Version: 8\r\n" + "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" + "\r\n"]), + {ok, Handshake} = gen_tcp:recv(Socket, 0, 6000), + {ok, {http_response, {1, 1}, 101, "Switching Protocols"}, Rest} + = erlang:decode_packet(http, Handshake, []), + [Headers, <<>>] = websocket_headers( + erlang:decode_packet(httph, Rest, []), []), + {'Connection', "Upgrade"} = lists:keyfind('Connection', 1, Headers), + {'Upgrade', "websocket"} = lists:keyfind('Upgrade', 1, Headers), + {"sec-websocket-accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="} + = lists:keyfind("sec-websocket-accept", 1, Headers), + {ok, Response} = gen_tcp:recv(Socket, 9, 6000), + << 1:1, 0:3, 1:4, 0:1, 7:7, "success" >> = Response, + ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>), %% close + {ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000), + {error, closed} = gen_tcp:recv(Socket, 0, 6000), + ok. + %% Internal. websocket_headers({ok, http_eoh, Rest}, Acc) -> diff --git a/test/ws_upgrade_with_opts_handler.erl b/test/ws_upgrade_with_opts_handler.erl new file mode 100644 index 0000000..02d755e --- /dev/null +++ b/test/ws_upgrade_with_opts_handler.erl @@ -0,0 +1,28 @@ +%% Feel free to use, reuse and abuse the code in this file. + +-module(ws_upgrade_with_opts_handler). +-behaviour(cowboy_websocket_handler). + +-export([init/3]). +-export([websocket_init/3]). +-export([websocket_handle/3]). +-export([websocket_info/3]). +-export([websocket_terminate/3]). + +init(_Any, Req, _Opts) -> + {upgrade, protocol, cowboy_websocket, Req, <<"success">>}. + +websocket_init(_TransportName, Req, Response) -> + Req2 = cowboy_req:compact(Req), + erlang:send_after(10, self(), send_response), + {ok, Req2, Response}. + +websocket_handle(_Frame, Req, State) -> + {ok, Req, State}. + +websocket_info(send_response, Req, State = Response) + when is_binary(Response) -> + {reply, {text, Response}, Req, State}. + +websocket_terminate(_Reason, _Req, _State) -> + ok. -- cgit v1.2.3