diff options
-rw-r--r-- | src/cowboy_websocket.erl | 37 | ||||
-rw-r--r-- | test/rfc7231_SUITE.erl | 22 |
2 files changed, 46 insertions, 13 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index d38c673..c9794b6 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -97,7 +97,12 @@ upgrade(Req0, Env, Handler, HandlerState, Opts) -> State0 = #state{handler=Handler, timeout=Timeout, compress=Compress, req=FilteredReq}, try websocket_upgrade(State0, Req0) of {ok, State, Req} -> - websocket_handshake(State, Req, HandlerState, Env) + websocket_handshake(State, Req, HandlerState, Env); + {error, upgrade_required} -> + {ok, cowboy_req:reply(426, #{ + <<"connection">> => <<"upgrade">>, + <<"upgrade">> => <<"websocket">> + }, Req0), Env} catch _:_ -> %% @todo Probably log something here? %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection. @@ -108,17 +113,25 @@ upgrade(Req0, Env, Handler, HandlerState, Opts) -> -spec websocket_upgrade(#state{}, Req) -> {ok, #state{}, Req} when Req::cowboy_req:req(). websocket_upgrade(State, Req) -> - ConnTokens = cowboy_req:parse_header(<<"connection">>, Req), - true = lists:member(<<"upgrade">>, ConnTokens), - %% @todo Should probably send a 426 if the Upgrade header is missing. - [<<"websocket">>] = cowboy_req:parse_header(<<"upgrade">>, Req), - Version = cowboy_req:header(<<"sec-websocket-version">>, Req), - IntVersion = binary_to_integer(Version), - true = (IntVersion =:= 7) orelse (IntVersion =:= 8) - orelse (IntVersion =:= 13), - Key = cowboy_req:header(<<"sec-websocket-key">>, Req), - false = Key =:= undefined, - websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion}). + ConnTokens = cowboy_req:parse_header(<<"connection">>, Req, []), + case lists:member(<<"upgrade">>, ConnTokens) of + false -> + {error, upgrade_required}; + true -> + UpgradeTokens = cowboy_req:parse_header(<<"upgrade">>, Req, []), + case lists:member(<<"websocket">>, UpgradeTokens) of + false -> + {error, upgrade_required}; + true -> + Version = cowboy_req:header(<<"sec-websocket-version">>, Req), + IntVersion = binary_to_integer(Version), + true = (IntVersion =:= 7) orelse (IntVersion =:= 8) + orelse (IntVersion =:= 13), + Key = cowboy_req:header(<<"sec-websocket-key">>, Req), + false = Key =:= undefined, + websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion}) + end + end. -spec websocket_extensions(#state{}, Req) -> {ok, #state{}, Req} when Req::cowboy_req:req(). diff --git a/test/rfc7231_SUITE.erl b/test/rfc7231_SUITE.erl index f1b3415..9eddf72 100644 --- a/test/rfc7231_SUITE.erl +++ b/test/rfc7231_SUITE.erl @@ -41,7 +41,8 @@ init_dispatch(_) -> {"*", asterisk_h, []}, {"/", hello_h, []}, {"/echo/:key", echo_h, []}, - {"/resp/:key[/:arg]", resp_h, []} + {"/resp/:key[/:arg]", resp_h, []}, + {"/ws", ws_init_h, []} ]}]). %% @todo The documentation should list what methods, headers and status codes @@ -514,6 +515,25 @@ status_code_426(Config) -> {response, _, 426, _} = gun:await(ConnPid, Ref), ok. +status_code_426_upgrade_header(Config) -> + case config(protocol, Config) of + http -> + do_status_code_426_upgrade_header(Config); + http2 -> + doc("HTTP/2 does not support the HTTP/1.1 Upgrade mechanism.") + end. + +do_status_code_426_upgrade_header(Config) -> + doc("A 426 response must include a upgrade header. (RFC7231 6.5.15)"), + ConnPid = gun_open(Config), + Ref = gun:get(ConnPid, "/ws?ok", [ + {<<"accept-encoding">>, <<"gzip">>} + ]), + {response, _, 426, Headers} = gun:await(ConnPid, Ref), + {_, <<"upgrade">>} = lists:keyfind(<<"connection">>, 1, Headers), + {_, <<"websocket">>} = lists:keyfind(<<"upgrade">>, 1, Headers), + ok. + status_code_500(Config) -> doc("The 500 Internal Server Error status code can be sent. (RFC7231 6.6.1)"), ConnPid = gun_open(Config), |