diff options
Diffstat (limited to 'src/cowboy_websocket.erl')
-rw-r--r-- | src/cowboy_websocket.erl | 69 |
1 files changed, 41 insertions, 28 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 6c58818..918c9e6 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -329,45 +329,49 @@ websocket_data(State, Req, HandlerState, Data) -> websocket_data(State=#state{frag_state=undefined}, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 0) -> websocket_payload(State#state{frag_state={nofin, Opcode, <<>>}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv); + Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Subsequent frame fragments. websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 0) -> websocket_payload(State, Req, HandlerState, - 0, Len, MaskKey, <<>>, Data, Rsv); + 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Final frame fragment. websocket_data(State=#state{frag_state={nofin, Opcode, SoFar}}, Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 1) -> websocket_payload(State#state{frag_state={fin, Opcode, SoFar}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv); + Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Unfragmented frame. websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 1) -> websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, <<>>, Data, Rsv). + Opcode, Len, MaskKey, <<>>, 0, Data, Rsv). -spec websocket_payload(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), binary(), rsv()) + opcode(), non_neg_integer(), mask_key(), binary(), non_neg_integer(), + binary(), rsv()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). %% Close control frames with a payload MUST contain a valid close code. websocket_payload(State, Req, HandlerState, - Opcode=8, Len, MaskKey, <<>>, << MaskedCode:2/binary, Rest/bits >>, Rsv) -> + Opcode=8, Len, MaskKey, <<>>, 0, + << MaskedCode:2/binary, Rest/bits >>, Rsv) -> Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>), if Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006; (Code > 1011) and (Code < 3000); Code > 4999 -> websocket_close(State, Req, HandlerState, {error, badframe}); true -> websocket_payload(State, Req, HandlerState, - Opcode, Len - 2, MaskKey, Unmasked, Rest, Rsv) + Opcode, Len - 2, MaskKey, Unmasked, byte_size(MaskedCode), + Rest, Rsv) end; %% Text frames and close control frames MUST have a payload that is valid UTF-8. websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse ((Opcode =:= 8) andalso (Unmasked =/= <<>>))) -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of false -> @@ -375,14 +379,16 @@ websocket_payload(State=#state{utf8_state=Incomplete}, Utf8State -> websocket_payload_loop(State2#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, Rsv) + << Unmasked/binary, Unmasked3/binary >>, + UnmaskedLen + byte_size(Data), Rsv) end; websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of <<>> -> @@ -394,10 +400,11 @@ websocket_payload(State=#state{utf8_state=Incomplete}, end; %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints. websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when byte_size(Data) < Len -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of false -> @@ -405,13 +412,15 @@ websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, Utf8State -> websocket_payload_loop(State2#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, Rsv) + << Unmasked/binary, Unmasked3/binary >>, + UnmaskedLen + byte_size(Data), Rsv) end; websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) -> + Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of <<>> -> @@ -427,20 +436,23 @@ websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, end; %% Other frames have a binary payload. websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) when byte_size(Data) < Len -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), websocket_payload_loop(State2, Req, HandlerState, - Opcode, Len - byte_size(Data), MaskKey, Unmasked3, Rsv); + Opcode, Len - byte_size(Data), MaskKey, + << Unmasked/binary, Unmasked3/binary >>, UnmaskedLen + byte_size(Data), + Rsv); websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv) -> + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), - websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, Unmasked3). + websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, + << Unmasked/binary, Unmasked3/binary >>). -spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) -> {binary(), #state{}}. @@ -513,19 +525,20 @@ is_utf8(_) -> false. -spec websocket_payload_loop(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), rsv()) + opcode(), non_neg_integer(), mask_key(), binary(), + non_neg_integer(), rsv()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). websocket_payload_loop(State=#state{socket=Socket, transport=Transport, messages={OK, Closed, Error}, timeout_ref=TRef}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Rsv) -> + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) -> Transport:setopts(Socket, [{active, once}]), receive {OK, Socket, Data} -> State2 = handler_loop_timeout(State), websocket_payload(State2, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv); + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv); {Closed, Socket} -> handler_terminate(State, Req, HandlerState, {error, closed}); {Error, Socket, Reason} -> @@ -534,13 +547,13 @@ websocket_payload_loop(State=#state{socket=Socket, transport=Transport, websocket_close(State, Req, HandlerState, {normal, timeout}); {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> websocket_payload_loop(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Rsv); + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv); Message -> handler_call(State, Req, HandlerState, <<>>, websocket_info, Message, fun (State2, Req2, HandlerState2, _) -> websocket_payload_loop(State2, Req2, HandlerState2, - Opcode, Len, MaskKey, Unmasked, Rsv) + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) end) end. |