diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cowboy_websocket.erl | 552 |
1 files changed, 116 insertions, 436 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 36190a5..681470f 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -20,21 +20,8 @@ -export([upgrade/6]). -export([handler_loop/4]). --type close_code() :: 1000..4999. --export_type([close_code/0]). - --type frame() :: close | ping | pong - | {text | binary | close | ping | pong, iodata()} - | {close, close_code(), iodata()}. --export_type([frame/0]). - --type opcode() :: 0 | 1 | 2 | 8 | 9 | 10. --type mask_key() :: 0..16#ffffffff. --type frag_state() :: undefined - | {nofin, opcode(), binary()} | {fin, opcode(), binary()}. --type rsv() :: << _:3 >>. -type terminate_reason() :: normal | stop | timeout - | remote | {remote, close_code(), binary()} + | remote | {remote, cow_ws:close_code(), binary()} | {error, badencoding | badframe | closed | atom()} | {crash, error | exit | throw, any()}. @@ -47,15 +34,15 @@ -callback websocket_handle({text | binary | ping | pong, binary()}, Req, State) -> {ok, Req, State} | {ok, Req, State, hibernate} - | {reply, frame() | [frame()], Req, State} - | {reply, frame() | [frame()], Req, State, hibernate} + | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State} + | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State, hibernate} | {stop, Req, State} when Req::cowboy_req:req(), State::any(). -callback websocket_info(any(), Req, State) -> {ok, Req, State} | {ok, Req, State, hibernate} - | {reply, frame() | [frame()], Req, State} - | {reply, frame() | [frame()], Req, State, hibernate} + | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State} + | {reply, cow_ws:frame() | [cow_ws:frame()], Req, State, hibernate} | {stop, Req, State} when Req::cowboy_req:req(), State::any(). %% @todo optional -callback terminate(terminate_reason(), cowboy_req:req(), state()) -> ok. @@ -70,11 +57,11 @@ timeout_ref = undefined :: undefined | reference(), messages = undefined :: undefined | {atom(), atom(), atom()}, hibernate = false :: boolean(), - frag_state = undefined :: frag_state(), + frag_state = undefined :: cow_ws:frag_state(), + frag_buffer = <<>> :: binary(), utf8_state = <<>> :: binary(), - deflate_frame = false :: boolean(), - inflate_state :: undefined | port(), - deflate_state :: undefined | port() + recv_extensions = #{} :: map(), + send_extensions = #{} :: map() }). -spec upgrade(Req, Env, module(), any(), timeout(), run | hibernate) @@ -135,9 +122,8 @@ websocket_extensions(State, Req) -> % the zlib headers. ok = zlib:deflateInit(Deflate, best_compression, deflated, -15, 8, default), {ok, State#state{ - deflate_frame = true, - inflate_state = Inflate, - deflate_state = Deflate + recv_extensions = #{deflate_frame => Inflate}, + send_extensions = #{deflate_frame => Deflate} }, cowboy_req:set_meta(websocket_compress, true, Req)}; _ -> {ok, State, cowboy_req:set_meta(websocket_compress, false, Req)} @@ -149,16 +135,16 @@ websocket_extensions(State, Req) -> | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). websocket_handshake(State=#state{ - transport=Transport, key=Key, deflate_frame=DeflateFrame}, + transport=Transport, key=Key, recv_extensions=Extensions}, Req, HandlerState) -> Challenge = base64:encode(crypto:hash(sha, << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>)), - Extensions = case DeflateFrame of - false -> []; - true -> [{<<"sec-websocket-extensions">>, <<"x-webkit-deflate-frame">>}] + ExtHeader = case Extensions of + #{deflate_frame := _} -> [{<<"sec-websocket-extensions">>, <<"x-webkit-deflate-frame">>}]; + _ -> [] end, Req2 = cowboy_req:upgrade_reply(101, [{<<"upgrade">>, <<"websocket">>}, - {<<"sec-websocket-accept">>, Challenge}|Extensions], Req), + {<<"sec-websocket-accept">>, Challenge}|ExtHeader], Req), %% Flush the resp_sent message before moving on. receive {cowboy_req, resp_sent} -> ok after 0 -> ok end, State2 = handler_loop_timeout(State), @@ -213,299 +199,59 @@ handler_loop(State=#state{socket=Socket, messages={OK, Closed, Error}, SoFar, websocket_info, Message, fun handler_before_loop/4) end. -%% All frames passing through this function are considered valid, -%% with the only exception of text and close frames with a payload -%% which may still contain errors. -spec websocket_data(#state{}, Req, any(), binary()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). -%% RSV bits MUST be 0 unless an extension is negotiated -%% that defines meanings for non-zero values. -websocket_data(State, Req, HandlerState, << _:1, Rsv:3, _/bits >>) - when Rsv =/= 0, State#state.deflate_frame =:= false -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% Invalid opcode. Note that these opcodes may be used by extensions. -websocket_data(State, Req, HandlerState, << _:4, Opcode:4, _/bits >>) - when Opcode > 2, Opcode =/= 8, Opcode =/= 9, Opcode =/= 10 -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% Control frames MUST NOT be fragmented. -websocket_data(State, Req, HandlerState, << 0:1, _:3, Opcode:4, _/bits >>) - when Opcode >= 8 -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% A frame MUST NOT use the zero opcode unless fragmentation was initiated. -websocket_data(State=#state{frag_state=undefined}, Req, HandlerState, - << _:4, 0:4, _/bits >>) -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% Non-control opcode when expecting control message or next fragment. -websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState, - << _:4, Opcode:4, _/bits >>) - when Opcode =/= 0, Opcode < 8 -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% Close control frame length MUST be 0 or >= 2. -websocket_data(State, Req, HandlerState, << _:4, 8:4, _:1, 1:7, _/bits >>) -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% Close control frame with incomplete close code. Need more data. -websocket_data(State, Req, HandlerState, - Data = << _:4, 8:4, 1:1, Len:7, _/bits >>) - when Len > 1, byte_size(Data) < 8 -> - handler_before_loop(State, Req, HandlerState, Data); -%% 7 bits payload length. -websocket_data(State, Req, HandlerState, << Fin:1, Rsv:3/bits, Opcode:4, 1:1, - Len:7, MaskKey:32, Rest/bits >>) - when Len < 126 -> - websocket_data(State, Req, HandlerState, - Opcode, Len, MaskKey, Rest, Rsv, Fin); -%% 16 bits payload length. -websocket_data(State, Req, HandlerState, << Fin:1, Rsv:3/bits, Opcode:4, 1:1, - 126:7, Len:16, MaskKey:32, Rest/bits >>) - when Len > 125, Opcode < 8 -> - websocket_data(State, Req, HandlerState, - Opcode, Len, MaskKey, Rest, Rsv, Fin); -%% 63 bits payload length. -websocket_data(State, Req, HandlerState, << Fin:1, Rsv:3/bits, Opcode:4, 1:1, - 127:7, 0:1, Len:63, MaskKey:32, Rest/bits >>) - when Len > 16#ffff, Opcode < 8 -> - websocket_data(State, Req, HandlerState, - Opcode, Len, MaskKey, Rest, Rsv, Fin); -%% When payload length is over 63 bits, the most significant bit MUST be 0. -websocket_data(State, Req, HandlerState, << _:8, 1:1, 127:7, 1:1, _:7, _/bits >>) -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% All frames sent from the client to the server are masked. -websocket_data(State, Req, HandlerState, << _:8, 0:1, _/bits >>) -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% For the next two clauses, it can be one of the following: -%% -%% * The minimal number of bytes MUST be used to encode the length -%% * All control frames MUST have a payload length of 125 bytes or less -websocket_data(State, Req, HandlerState, << _:9, 126:7, _:48, _/bits >>) -> - websocket_close(State, Req, HandlerState, {error, badframe}); -websocket_data(State, Req, HandlerState, << _:9, 127:7, _:96, _/bits >>) -> - websocket_close(State, Req, HandlerState, {error, badframe}); -%% Need more data. -websocket_data(State, Req, HandlerState, Data) -> - handler_before_loop(State, Req, HandlerState, Data). - -%% Initialize or update fragmentation state. --spec websocket_data(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), rsv(), 0 | 1) - -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} - when Req::cowboy_req:req(). -%% The opcode is only included in the first frame fragment. -websocket_data(State=#state{frag_state=undefined}, Req, HandlerState, - Opcode, Len, MaskKey, Data, Rsv, 0) -> - websocket_payload(State#state{frag_state={nofin, Opcode, <<>>}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); -%% Subsequent frame fragments. -websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState, - 0, Len, MaskKey, Data, Rsv, 0) -> - websocket_payload(State, Req, HandlerState, - 0, Len, MaskKey, <<>>, 0, Data, Rsv); -%% Final frame fragment. -websocket_data(State=#state{frag_state={nofin, Opcode, SoFar}}, - Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 1) -> - websocket_payload(State#state{frag_state={fin, Opcode, SoFar}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); -%% Unfragmented frame. -websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 1) -> - websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, <<>>, 0, Data, Rsv). - --spec websocket_payload(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), non_neg_integer(), - binary(), rsv()) - -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} - when Req::cowboy_req:req(). -%% Close control frames with a payload MUST contain a valid close code. -websocket_payload(State, Req, HandlerState, - Opcode=8, Len, MaskKey, <<>>, 0, - << MaskedCode:2/binary, Rest/bits >>, Rsv) -> - Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>), - if Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006; - (Code > 1011) and (Code < 3000); Code > 4999 -> +websocket_data(State=#state{frag_state=FragState, recv_extensions=Extensions}, Req, HandlerState, Data) -> + case cow_ws:parse_header(Data, Extensions, FragState) of + %% All frames sent from the client to the server are masked. + {_, _, _, _, undefined, _} -> websocket_close(State, Req, HandlerState, {error, badframe}); - true -> - websocket_payload(State, Req, HandlerState, - Opcode, Len - 2, MaskKey, Unmasked, byte_size(MaskedCode), - Rest, Rsv) - end; -%% Text frames and close control frames MUST have a payload that is valid UTF-8. -websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, - Data, Rsv) - when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse - ((Opcode =:= 8) andalso (Unmasked =/= <<>>))) -> - Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, UnmaskedLen), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), - case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of - false -> - websocket_close(State2, Req, HandlerState, {error, badencoding}); - Utf8State -> - websocket_payload_loop(State2#state{utf8_state=Utf8State}, - Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, - UnmaskedLen + byte_size(Data), Rsv) - end; -websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, - Data, Rsv) - when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) -> - << End:Len/binary, Rest/bits >> = Data, - Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, UnmaskedLen), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), - case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of - <<>> -> - websocket_dispatch(State2#state{utf8_state= <<>>}, - Req, HandlerState, Rest, Opcode, - << Unmasked/binary, Unmasked3/binary >>); - _ -> - websocket_close(State2, Req, HandlerState, {error, badencoding}) - end; -%% Fragmented text frames may cut payload in the middle of UTF-8 codepoints. -websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, - Data, Rsv) - when byte_size(Data) < Len -> - Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, UnmaskedLen), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), - case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of - false -> - websocket_close(State2, Req, HandlerState, {error, badencoding}); - Utf8State -> - websocket_payload_loop(State2#state{utf8_state=Utf8State}, - Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, - UnmaskedLen + byte_size(Data), Rsv) - end; -websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, - Data, Rsv) -> - << End:Len/binary, Rest/bits >> = Data, - Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, UnmaskedLen), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, Fin =:= fin, State), - case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of - <<>> -> - websocket_dispatch(State2#state{utf8_state= <<>>}, - Req, HandlerState, Rest, Opcode, - << Unmasked/binary, Unmasked3/binary >>); - Utf8State when is_binary(Utf8State), Fin =:= nofin -> - websocket_dispatch(State2#state{utf8_state=Utf8State}, - Req, HandlerState, Rest, Opcode, - << Unmasked/binary, Unmasked3/binary >>); - _ -> - websocket_close(State, Req, HandlerState, {error, badencoding}) - end; -%% Other frames have a binary payload. -websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) - when byte_size(Data) < Len -> - Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, UnmaskedLen), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), - websocket_payload_loop(State2, Req, HandlerState, - Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, UnmaskedLen + byte_size(Data), - Rsv); -websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) -> - << End:Len/binary, Rest/bits >> = Data, - Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, UnmaskedLen), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), - websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, - << Unmasked/binary, Unmasked3/binary >>). - --spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) -> - {binary(), #state{}}. -websocket_inflate_frame(Data, << Rsv1:1, _:2 >>, _, - #state{deflate_frame = DeflateFrame} = State) - when DeflateFrame =:= false orelse Rsv1 =:= 0 -> - {Data, State}; -websocket_inflate_frame(Data, << 1:1, _:2 >>, false, State) -> - Result = zlib:inflate(State#state.inflate_state, Data), - {iolist_to_binary(Result), State}; -websocket_inflate_frame(Data, << 1:1, _:2 >>, true, State) -> - Result = zlib:inflate(State#state.inflate_state, - << Data/binary, 0:8, 0:8, 255:8, 255:8 >>), - {iolist_to_binary(Result), State}. - --spec websocket_unmask(B, mask_key(), B) -> B when B::binary(). -websocket_unmask(<<>>, _, Unmasked) -> - Unmasked; -websocket_unmask(<< O:32, Rest/bits >>, MaskKey, Acc) -> - T = O bxor MaskKey, - websocket_unmask(Rest, MaskKey, << Acc/binary, T:32 >>); -websocket_unmask(<< O:24 >>, MaskKey, Acc) -> - << MaskKey2:24, _:8 >> = << MaskKey:32 >>, - T = O bxor MaskKey2, - << Acc/binary, T:24 >>; -websocket_unmask(<< O:16 >>, MaskKey, Acc) -> - << MaskKey2:16, _:16 >> = << MaskKey:32 >>, - T = O bxor MaskKey2, - << Acc/binary, T:16 >>; -websocket_unmask(<< O:8 >>, MaskKey, Acc) -> - << MaskKey2:8, _:24 >> = << MaskKey:32 >>, - T = O bxor MaskKey2, - << Acc/binary, T:8 >>. + %% No payload. + {Type, FragState2, _, 0, _, Rest} -> + websocket_dispatch(State#state{frag_state=FragState2}, Req, HandlerState, Type, <<>>, undefined, Rest); + {Type, FragState2, Rsv, Len, MaskKey, Rest} -> + websocket_payload(State#state{frag_state=FragState2}, Req, HandlerState, Type, Len, MaskKey, Rsv, Rest); + more -> + handler_before_loop(State, Req, HandlerState, Data); + error -> + websocket_close(State, Req, HandlerState, {error, badframe}) + end. -%% Because we unmask on the fly we need to continue from the right mask byte. --spec rotate_mask_key(mask_key(), non_neg_integer()) -> mask_key(). -rotate_mask_key(MaskKey, UnmaskedLen) -> - Left = UnmaskedLen rem 4, - Right = 4 - Left, - (MaskKey bsl (Left * 8)) + (MaskKey bsr (Right * 8)). +websocket_payload(State, Req, HandlerState, Type = close, Len, MaskKey, Rsv, Data) -> + case cow_ws:parse_close_code(Data, MaskKey) of + {ok, CloseCode, Rest} -> + websocket_payload(State, Req, HandlerState, Type, Len - 2, MaskKey, Rsv, CloseCode, <<>>, 2, Rest); + error -> + websocket_close(State, Req, HandlerState, {error, badframe}) + end; +websocket_payload(State, Req, HandlerState, Type, Len, MaskKey, Rsv, Data) -> + websocket_payload(State, Req, HandlerState, Type, Len, MaskKey, Rsv, undefined, <<>>, 0, Data). -%% Returns <<>> if the argument is valid UTF-8, false if not, -%% or the incomplete part of the argument if we need more data. --spec is_utf8(binary()) -> false | binary(). -is_utf8(Valid = <<>>) -> - Valid; -is_utf8(<< _/utf8, Rest/bits >>) -> - is_utf8(Rest); -%% 2 bytes. Codepages C0 and C1 are invalid; fail early. -is_utf8(<< 2#1100000:7, _/bits >>) -> - false; -is_utf8(Incomplete = << 2#110:3, _:5 >>) -> - Incomplete; -%% 3 bytes. -is_utf8(Incomplete = << 2#1110:4, _:4 >>) -> - Incomplete; -is_utf8(Incomplete = << 2#1110:4, _:4, 2#10:2, _:6 >>) -> - Incomplete; -%% 4 bytes. Codepage F4 may have invalid values greater than 0x10FFFF. -is_utf8(<< 2#11110100:8, 2#10:2, High:6, _/bits >>) when High >= 2#10000 -> - false; -is_utf8(Incomplete = << 2#11110:5, _:3 >>) -> - Incomplete; -is_utf8(Incomplete = << 2#11110:5, _:3, 2#10:2, _:6 >>) -> - Incomplete; -is_utf8(Incomplete = << 2#11110:5, _:3, 2#10:2, _:6, 2#10:2, _:6 >>) -> - Incomplete; -%% Invalid. -is_utf8(_) -> - false. +websocket_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, recv_extensions=Extensions}, + Req, HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data) -> + case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, Type, Len, FragState, Extensions, Rsv) of + {ok, Payload, Utf8State, Rest} -> + websocket_dispatch(State#state{utf8_state=Utf8State}, + Req, HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode, Rest); + {more, Payload, Utf8State} -> + websocket_payload_loop(State#state{utf8_state=Utf8State}, + Req, HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode, + << Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data)); + error -> + websocket_close(State, Req, HandlerState, {error, badencoding}) + end. --spec websocket_payload_loop(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), - non_neg_integer(), rsv()) - -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} - when Req::cowboy_req:req(). websocket_payload_loop(State=#state{socket=Socket, transport=Transport, messages={OK, Closed, Error}, timeout_ref=TRef}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) -> + Req, HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) -> Transport:setopts(Socket, [{active, once}]), receive {OK, Socket, Data} -> State2 = handler_loop_timeout(State), websocket_payload(State2, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv); + Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data); {Closed, Socket} -> handler_terminate(State, Req, HandlerState, {error, closed}); {Error, Socket, Reason} -> @@ -514,53 +260,46 @@ websocket_payload_loop(State=#state{socket=Socket, transport=Transport, websocket_close(State, Req, HandlerState, timeout); {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> websocket_payload_loop(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv); + Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen); Message -> handler_call(State, Req, HandlerState, <<>>, websocket_info, Message, fun (State2, Req2, HandlerState2, _) -> websocket_payload_loop(State2, Req2, HandlerState2, - Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) + Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) end) end. --spec websocket_dispatch(#state{}, Req, any(), binary(), opcode(), binary()) - -> {ok, Req, cowboy_middleware:env()} - | {suspend, module(), atom(), [any()]} - when Req::cowboy_req:req(). %% Continuation frame. -websocket_dispatch(State=#state{frag_state={nofin, Opcode, SoFar}}, - Req, HandlerState, RemainingData, 0, Payload) -> - websocket_data(State#state{frag_state={nofin, Opcode, - << SoFar/binary, Payload/binary >>}}, Req, HandlerState, RemainingData); +websocket_dispatch(State=#state{frag_state={nofin, _}, frag_buffer=SoFar}, + Req, HandlerState, fragment, Payload, _, RemainingData) -> + websocket_data(State#state{frag_buffer= << SoFar/binary, Payload/binary >>}, Req, HandlerState, RemainingData); %% Last continuation frame. -websocket_dispatch(State=#state{frag_state={fin, Opcode, SoFar}}, - Req, HandlerState, RemainingData, 0, Payload) -> - websocket_dispatch(State#state{frag_state=undefined}, Req, HandlerState, - RemainingData, Opcode, << SoFar/binary, Payload/binary >>); +websocket_dispatch(State=#state{frag_state={fin, Type}, frag_buffer=SoFar}, + Req, HandlerState, fragment, Payload, CloseCode, RemainingData) -> + websocket_dispatch(State#state{frag_state=undefined, frag_buffer= <<>>}, Req, HandlerState, + Type, << SoFar/binary, Payload/binary >>, CloseCode, RemainingData); %% Text frame. -websocket_dispatch(State, Req, HandlerState, RemainingData, 1, Payload) -> +websocket_dispatch(State, Req, HandlerState, text, Payload, _, RemainingData) -> handler_call(State, Req, HandlerState, RemainingData, websocket_handle, {text, Payload}, fun websocket_data/4); %% Binary frame. -websocket_dispatch(State, Req, HandlerState, RemainingData, 2, Payload) -> +websocket_dispatch(State, Req, HandlerState, binary, Payload, _, RemainingData) -> handler_call(State, Req, HandlerState, RemainingData, websocket_handle, {binary, Payload}, fun websocket_data/4); %% Close control frame. -websocket_dispatch(State, Req, HandlerState, _RemainingData, 8, <<>>) -> +websocket_dispatch(State, Req, HandlerState, close, _, undefined, _) -> websocket_close(State, Req, HandlerState, remote); -websocket_dispatch(State, Req, HandlerState, _RemainingData, 8, - << Code:16, Payload/bits >>) -> +websocket_dispatch(State, Req, HandlerState, close, Payload, Code, _) -> websocket_close(State, Req, HandlerState, {remote, Code, Payload}); %% Ping control frame. Send a pong back and forward the ping to the handler. -websocket_dispatch(State=#state{socket=Socket, transport=Transport}, - Req, HandlerState, RemainingData, 9, Payload) -> - Len = payload_length_to_binary(byte_size(Payload)), - Transport:send(Socket, << 1:1, 0:3, 10:4, 0:1, Len/bits, Payload/binary >>), +websocket_dispatch(State=#state{socket=Socket, transport=Transport, send_extensions=Extensions}, + Req, HandlerState, ping, Payload, _, RemainingData) -> + Transport:send(Socket, cow_ws:frame({pong, Payload}, Extensions)), handler_call(State, Req, HandlerState, RemainingData, websocket_handle, {ping, Payload}, fun websocket_data/4); %% Pong control frame. -websocket_dispatch(State, Req, HandlerState, RemainingData, 10, Payload) -> +websocket_dispatch(State, Req, HandlerState, pong, Payload, _, RemainingData) -> handler_call(State, Req, HandlerState, RemainingData, websocket_handle, {pong, Payload}, fun websocket_data/4). @@ -579,42 +318,42 @@ handler_call(State=#state{handler=Handler}, Req, HandlerState, {reply, Payload, Req2, HandlerState2} when is_list(Payload) -> case websocket_send_many(Payload, State) of - {ok, State2} -> - NextState(State2, Req2, HandlerState2, RemainingData); - {stop, State2} -> - handler_terminate(State2, Req2, HandlerState2, stop); - {{error, _} = Error, State2} -> - handler_terminate(State2, Req2, HandlerState2, Error) + ok -> + NextState(State, Req2, HandlerState2, RemainingData); + stop -> + handler_terminate(State, Req2, HandlerState2, stop); + Error = {error, _} -> + handler_terminate(State, Req2, HandlerState2, Error) end; {reply, Payload, Req2, HandlerState2, hibernate} when is_list(Payload) -> case websocket_send_many(Payload, State) of - {ok, State2} -> - NextState(State2#state{hibernate=true}, + ok -> + NextState(State#state{hibernate=true}, Req2, HandlerState2, RemainingData); - {stop, State2} -> - handler_terminate(State2, Req2, HandlerState2, stop); - {{error, _} = Error, State2} -> - handler_terminate(State2, Req2, HandlerState2, Error) + stop -> + handler_terminate(State, Req2, HandlerState2, stop); + Error = {error, _} -> + handler_terminate(State, Req2, HandlerState2, Error) end; {reply, Payload, Req2, HandlerState2} -> case websocket_send(Payload, State) of - {ok, State2} -> - NextState(State2, Req2, HandlerState2, RemainingData); - {stop, State2} -> - handler_terminate(State2, Req2, HandlerState2, stop); - {{error, _} = Error, State2} -> - handler_terminate(State2, Req2, HandlerState2, Error) + ok -> + NextState(State, Req2, HandlerState2, RemainingData); + stop -> + handler_terminate(State, Req2, HandlerState2, stop); + Error = {error, _} -> + handler_terminate(State, Req2, HandlerState2, Error) end; {reply, Payload, Req2, HandlerState2, hibernate} -> case websocket_send(Payload, State) of - {ok, State2} -> - NextState(State2#state{hibernate=true}, + ok -> + NextState(State#state{hibernate=true}, Req2, HandlerState2, RemainingData); - {stop, State2} -> - handler_terminate(State2, Req2, HandlerState2, stop); - {{error, _} = Error, State2} -> - handler_terminate(State2, Req2, HandlerState2, Error) + stop -> + handler_terminate(State, Req2, HandlerState2, stop); + Error = {error, _} -> + handler_terminate(State, Req2, HandlerState2, Error) end; {stop, Req2, HandlerState2} -> websocket_close(State, Req2, HandlerState2, stop) @@ -630,103 +369,44 @@ handler_call(State=#state{handler=Handler}, Req, HandlerState, ]) end. -websocket_opcode(text) -> 1; -websocket_opcode(binary) -> 2; -websocket_opcode(close) -> 8; -websocket_opcode(ping) -> 9; -websocket_opcode(pong) -> 10. - --spec websocket_deflate_frame(opcode(), binary(), #state{}) -> - {binary(), rsv(), #state{}}. -websocket_deflate_frame(Opcode, Payload, - State=#state{deflate_frame = DeflateFrame}) - when DeflateFrame =:= false orelse Opcode >= 8 -> - {Payload, << 0:3 >>, State}; -websocket_deflate_frame(_, Payload, State=#state{deflate_state = Deflate}) -> - Deflated = iolist_to_binary(zlib:deflate(Deflate, Payload, sync)), - DeflatedBodyLength = erlang:size(Deflated) - 4, - Deflated1 = case Deflated of - << Body:DeflatedBodyLength/binary, 0:8, 0:8, 255:8, 255:8 >> -> Body; - _ -> Deflated - end, - {Deflated1, << 1:1, 0:2 >>, State}. - --spec websocket_send(frame(), #state{}) --> {ok, #state{}} | {stop, #state{}} | {{error, atom()}, #state{}}. -websocket_send(Type = close, State=#state{socket=Socket, transport=Transport}) -> - Opcode = websocket_opcode(Type), - case Transport:send(Socket, << 1:1, 0:3, Opcode:4, 0:8 >>) of - ok -> {stop, State}; - Error -> {Error, State} - end; -websocket_send(Type, State=#state{socket=Socket, transport=Transport}) - when Type =:= ping; Type =:= pong -> - Opcode = websocket_opcode(Type), - {Transport:send(Socket, << 1:1, 0:3, Opcode:4, 0:8 >>), State}; -websocket_send({close, Payload}, State) -> - websocket_send({close, 1000, Payload}, State); -websocket_send({Type = close, StatusCode, Payload}, State=#state{ - socket=Socket, transport=Transport}) -> - Opcode = websocket_opcode(Type), - Len = 2 + iolist_size(Payload), - %% Control packets must not be > 125 in length. - true = Len =< 125, - BinLen = payload_length_to_binary(Len), - Transport:send(Socket, - [<< 1:1, 0:3, Opcode:4, 0:1, BinLen/bits, StatusCode:16 >>, Payload]), - {stop, State}; -websocket_send({Type, Payload0}, State=#state{socket=Socket, transport=Transport}) -> - Opcode = websocket_opcode(Type), - {Payload, Rsv, State2} = websocket_deflate_frame(Opcode, iolist_to_binary(Payload0), State), - Len = iolist_size(Payload), - %% Control packets must not be > 125 in length. - true = if Type =:= ping; Type =:= pong -> - Len =< 125; - true -> - true - end, - BinLen = payload_length_to_binary(Len), - {Transport:send(Socket, - [<< 1:1, Rsv/bits, Opcode:4, 0:1, BinLen/bits >>, Payload]), State2}. - --spec payload_length_to_binary(0..16#7fffffffffffffff) - -> << _:7 >> | << _:23 >> | << _:71 >>. -payload_length_to_binary(N) -> - case N of - N when N =< 125 -> << N:7 >>; - N when N =< 16#ffff -> << 126:7, N:16 >>; - N when N =< 16#7fffffffffffffff -> << 127:7, N:64 >> +-spec websocket_send(cow_ws:frame(), #state{}) -> ok | stop | {error, atom()}. +websocket_send(Frame, #state{socket=Socket, transport=Transport, send_extensions=Extensions}) -> + Res = Transport:send(Socket, cow_ws:frame(Frame, Extensions)), + case Frame of + close -> stop; + {close, _} -> stop; + {close, _, _} -> stop; + _ -> Res end. --spec websocket_send_many([frame()], #state{}) - -> {ok, #state{}} | {stop, #state{}} | {{error, atom()}, #state{}}. -websocket_send_many([], State) -> - {ok, State}; +-spec websocket_send_many([cow_ws:frame()], #state{}) -> ok | stop | {error, atom()}. +websocket_send_many([], _) -> + ok; websocket_send_many([Frame|Tail], State) -> case websocket_send(Frame, State) of - {ok, State2} -> websocket_send_many(Tail, State2); - {stop, State2} -> {stop, State2}; - {Error, State2} -> {Error, State2} + ok -> websocket_send_many(Tail, State); + stop -> stop; + Error -> Error end. -spec websocket_close(#state{}, Req, any(), terminate_reason()) -> {ok, Req, cowboy_middleware:env()} when Req::cowboy_req:req(). -websocket_close(State=#state{socket=Socket, transport=Transport}, +websocket_close(State=#state{socket=Socket, transport=Transport, send_extensions=Extensions}, Req, HandlerState, Reason) -> case Reason of Normal when Normal =:= stop; Normal =:= timeout -> - Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1000:16 >>); + Transport:send(Socket, cow_ws:frame({close, 1000, <<>>}, Extensions)); {error, badframe} -> - Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1002:16 >>); + Transport:send(Socket, cow_ws:frame({close, 1002, <<>>}, Extensions)); {error, badencoding} -> - Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1007:16 >>); + Transport:send(Socket, cow_ws:frame({close, 1007, <<>>}, Extensions)); {crash, _, _} -> - Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, 1011:16 >>); + Transport:send(Socket, cow_ws:frame({close, 1011, <<>>}, Extensions)); remote -> - Transport:send(Socket, << 1:1, 0:3, 8:4, 0:8 >>); + Transport:send(Socket, cow_ws:frame(close, Extensions)); {remote, Code, _} -> - Transport:send(Socket, << 1:1, 0:3, 8:4, 0:1, 2:7, Code:16 >>) + Transport:send(Socket, cow_ws:frame({close, Code, <<>>}, Extensions)) end, handler_terminate(State, Req, HandlerState, Reason). |