From 8d690995085088a9af66b898b68abc0c9580ad01 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= Date: Mon, 14 Jan 2013 16:20:33 +0100 Subject: Improve websocket close handling We now always send a failure reason (bad protocol, bad encoding, etc.) unless the closure was initiated by the client and it didn't send a close code. We now check that the close frames have a payload that is valid UTF-8, unless they don't have a payload at all. We now do not crash the process anymore when bad opcodes are sent, or when the opcode 0 is sent before fragmentation was initiated. Overall this makes us closer to full compliance with the RFC. --- src/cowboy_websocket.erl | 72 ++++++++++++++++++++++++++++++++++++++---------- test/ws_SUITE.erl | 12 ++++---- 2 files changed, 63 insertions(+), 21 deletions(-) diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index a10b008..4553aef 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -196,7 +196,9 @@ handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error}, SoFar, websocket_info, Message, fun handler_before_loop/4) end. -%% All frames passing through this function are considered valid. +%% All frames passing through this function are considered valid, +%% with the only exception of text and close frames with a payload +%% which may still contain errors. -spec websocket_data(#state{}, Req, any(), binary()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} @@ -206,19 +208,31 @@ handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error}, websocket_data(State, Req, HandlerState, << _:1, Rsv:3, _/bits >>) when Rsv =/= 0 -> websocket_close(State, Req, HandlerState, {error, badframe}); +%% Invalid opcode. Note that these opcodes may be used by extensions. +websocket_data(State, Req, HandlerState, << _:4, Opcode:4, _/bits >>) + when Opcode > 2, Opcode =/= 8, Opcode =/= 9, Opcode =/= 10 -> + websocket_close(State, Req, HandlerState, {error, badframe}); %% Control frames MUST NOT be fragmented. websocket_data(State, Req, HandlerState, << 0:1, _:3, Opcode:4, _/bits >>) when Opcode >= 8 -> websocket_close(State, Req, HandlerState, {error, badframe}); -%% A fragmented message MUST start a non-zero opcode. +%% A frame MUST NOT use the zero opcode unless fragmentation was initiated. websocket_data(State=#state{frag_state=undefined}, Req, HandlerState, - << 0:1, _:3, 0:4, _/bits >>) -> + << _:4, 0:4, _/bits >>) -> websocket_close(State, Req, HandlerState, {error, badframe}); %% Non-control opcode when expecting control message or next fragment. websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState, << _:4, Opcode:4, _/bits >>) when Opcode =/= 0, Opcode < 8 -> websocket_close(State, Req, HandlerState, {error, badframe}); +%% Close control frame length MUST be 0 or >= 2. +websocket_data(State, Req, HandlerState, << _:4, 8:4, _:1, 1:7, _/bits >>) -> + websocket_close(State, Req, HandlerState, {error, badframe}); +%% Close control frame with incomplete close code. Need more data. +websocket_data(State, Req, HandlerState, + Data = << _:4, 8:4, 1:1, Len:7, _/bits >>) + when Len > 1, byte_size(Data) < 8 -> + handler_before_loop(State, Req, HandlerState, Data); %% 7 bits payload length. websocket_data(State, Req, HandlerState, << Fin:1, _Rsv:3, Opcode:4, 1:1, Len:7, MaskKey:32, Rest/bits >>) @@ -286,22 +300,35 @@ websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, 1) -> -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). -%% Text frames must have a payload that is valid UTF-8. +%% Close control frames with a payload MUST contain a valid close code. +websocket_payload(State, Req, HandlerState, + Opcode=8, Len, MaskKey, <<>>, << MaskedCode:2/binary, Rest/bits >>) -> + Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>), + if Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006; + (Code > 1011) and (Code < 3000); Code > 4999 -> + websocket_close(State, Req, HandlerState, {error, badframe}); + true -> + websocket_payload(State, Req, HandlerState, + Opcode, Len - 2, MaskKey, Unmasked, Rest) + end; +%% Text frames and close control frames MUST have a payload that is valid UTF-8. websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode=1, Len, MaskKey, Unmasked, Data) - when byte_size(Data) < Len -> + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data) + when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse + ((Opcode =:= 8) andalso (Unmasked =/= <<>>))) -> Unmasked2 = websocket_unmask(Data, rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), case is_utf8(<< Incomplete/binary, Unmasked2/binary >>) of false -> - websocket_close(State, Req, HandlerState, {error, badframe}); + websocket_close(State, Req, HandlerState, {error, badencoding}); Utf8State -> websocket_payload_loop(State#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, << Unmasked/binary, Unmasked2/binary >>) end; websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode=1, Len, MaskKey, Unmasked, Data) -> + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data) + when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), @@ -311,7 +338,7 @@ websocket_payload(State=#state{utf8_state=Incomplete}, Req, HandlerState, Rest, Opcode, << Unmasked/binary, Unmasked2/binary >>); _ -> - websocket_close(State, Req, HandlerState, {error, badframe}) + websocket_close(State, Req, HandlerState, {error, badencoding}) end; %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints. websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, @@ -321,7 +348,7 @@ websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), case is_utf8(<< Incomplete/binary, Unmasked2/binary >>) of false -> - websocket_close(State, Req, HandlerState, {error, badframe}); + websocket_close(State, Req, HandlerState, {error, badencoding}); Utf8State -> websocket_payload_loop(State#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, @@ -342,7 +369,7 @@ websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, Req, HandlerState, Rest, Opcode, << Unmasked/binary, Unmasked2/binary >>); _ -> - websocket_close(State, Req, HandlerState, {error, badframe}) + websocket_close(State, Req, HandlerState, {error, badencoding}) end; %% Other frames have a binary payload. websocket_payload(State, Req, HandlerState, @@ -470,9 +497,11 @@ websocket_dispatch(State, Req, HandlerState, RemainingData, 2, Payload) -> handler_call(State, Req, HandlerState, RemainingData, websocket_handle, {binary, Payload}, fun websocket_data/4); %% Close control frame. -%% @todo Handle the optional Payload. -websocket_dispatch(State, Req, HandlerState, _RemainingData, 8, _Payload) -> - websocket_close(State, Req, HandlerState, {normal, closed}); +websocket_dispatch(State, Req, HandlerState, _RemainingData, 8, <<>>) -> + websocket_close(State, Req, HandlerState, {remote, closed}); +websocket_dispatch(State, Req, HandlerState, _RemainingData, 8, + << Code:16, Payload/bits >>) -> + websocket_close(State, Req, HandlerState, {remote, Code, Payload}); %% Ping control frame. Send a pong back and forward the ping to the handler. websocket_dispatch(State=#state{socket=Socket, transport=Transport}, Req, HandlerState, RemainingData, 9, Payload) -> @@ -621,7 +650,20 @@ websocket_send_many([Frame|Tail], State) -> when Req::cowboy_req:req(). websocket_close(State=#state{socket=Socket, transport=Transport}, Req, HandlerState, Reason) -> - Transport:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>), + case Reason of + {normal, _} -> + Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1000:16 >>); + {error, badframe} -> + Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1002:16 >>); + {error, badencoding} -> + Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1007:16 >>); + {error, handler} -> + Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1011:16 >>); + {remote, closed} -> + Transport:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>); + {remote, Code, _} -> + Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, Code:16 >>) + end, handler_terminate(State, Req, HandlerState, Reason). -spec handler_terminate(#state{}, Req, any(), atom() | {atom(), atom()}) diff --git a/test/ws_SUITE.erl b/test/ws_SUITE.erl index cc1d557..92fd98b 100644 --- a/test/ws_SUITE.erl +++ b/test/ws_SUITE.erl @@ -396,7 +396,7 @@ ws_send_many(Config) -> << 1:1, 0:3, 1:4, 0:1, 3:7, "one", 1:1, 0:3, 1:4, 0:1, 3:7, "two", 1:1, 0:3, 1:4, 0:1, 6:7, "seven!" >> = Many, - ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>), %% close + ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 1:1, 0:7, 0:32 >>), %% close {ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000), {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. @@ -450,7 +450,7 @@ ws_text_fragments(Config) -> << 16#9f >>, << 16#4d >>, << 16#51 >>, << 16#58 >>]), {ok, << 1:1, 0:3, 1:4, 0:1, 15:7, "HelloHelloHello" >>} = gen_tcp:recv(Socket, 0, 6000), - ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>), %% close + ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 1:1, 0:7, 0:32 >>), %% close {ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000), {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. @@ -477,7 +477,7 @@ ws_timeout_hibernate(Config) -> {'Upgrade', "websocket"} = lists:keyfind('Upgrade', 1, Headers), {"sec-websocket-accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="} = lists:keyfind("sec-websocket-accept", 1, Headers), - {ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000), + {ok, << 1:1, 0:3, 8:4, 0:1, 2:7, 1000:16 >>} = gen_tcp:recv(Socket, 0, 6000), {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. @@ -504,7 +504,7 @@ ws_timeout_cancel(Config) -> {'Upgrade', "websocket"} = lists:keyfind('Upgrade', 1, Headers), {"sec-websocket-accept", "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="} = lists:keyfind("sec-websocket-accept", 1, Headers), - {ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000), + {ok, << 1:1, 0:3, 8:4, 0:1, 2:7, 1000:16 >>} = gen_tcp:recv(Socket, 0, 6000), {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. @@ -538,7 +538,7 @@ ws_timeout_reset(Config) -> = gen_tcp:recv(Socket, 0, 6000), ok = timer:sleep(500) end || _ <- [1, 2, 3, 4]], - {ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000), + {ok, << 1:1, 0:3, 8:4, 0:1, 2:7, 1000:16 >>} = gen_tcp:recv(Socket, 0, 6000), {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. @@ -566,7 +566,7 @@ ws_upgrade_with_opts(Config) -> = lists:keyfind("sec-websocket-accept", 1, Headers), {ok, Response} = gen_tcp:recv(Socket, 9, 6000), << 1:1, 0:3, 1:4, 0:1, 7:7, "success" >> = Response, - ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>), %% close + ok = gen_tcp:send(Socket, << 1:1, 0:3, 8:4, 1:1, 0:7, 0:32 >>), %% close {ok, << 1:1, 0:3, 8:4, 0:8 >>} = gen_tcp:recv(Socket, 0, 6000), {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. -- cgit v1.2.3