diff options
Diffstat (limited to 'src/cowboy_websocket.erl')
-rw-r--r-- | src/cowboy_websocket.erl | 101 |
1 files changed, 56 insertions, 45 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index df50162..073d7c6 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -57,9 +57,8 @@ frag_state = undefined :: frag_state(), utf8_state = <<>> :: binary(), deflate_frame = false :: boolean(), - inflate_state :: any(), - inflate_buffer = <<>> :: binary(), - deflate_state :: any() + inflate_state :: undefined | port(), + deflate_state :: undefined | port() }). %% @doc Upgrade an HTTP request to the Websocket protocol. @@ -121,7 +120,6 @@ websocket_extensions(State, Req) -> {ok, State#state{ deflate_frame = true, inflate_state = Inflate, - inflate_buffer = <<>>, deflate_state = Deflate }, Req2}; _ -> @@ -331,45 +329,49 @@ websocket_data(State, Req, HandlerState, Data) -> websocket_data(State=#state{frag_state=undefined}, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 0) -> websocket_payload(State#state{frag_state={nofin, Opcode, <<>>}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv); + Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Subsequent frame fragments. websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 0) -> websocket_payload(State, Req, HandlerState, - 0, Len, MaskKey, <<>>, Data, Rsv); + 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Final frame fragment. websocket_data(State=#state{frag_state={nofin, Opcode, SoFar}}, Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 1) -> websocket_payload(State#state{frag_state={fin, Opcode, SoFar}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv); + Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Unfragmented frame. websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 1) -> websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, <<>>, Data, Rsv). + Opcode, Len, MaskKey, <<>>, 0, Data, Rsv). -spec websocket_payload(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), binary(), rsv()) + opcode(), non_neg_integer(), mask_key(), binary(), non_neg_integer(), + binary(), rsv()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). %% Close control frames with a payload MUST contain a valid close code. websocket_payload(State, Req, HandlerState, - Opcode=8, Len, MaskKey, <<>>, << MaskedCode:2/binary, Rest/bits >>, Rsv) -> + Opcode=8, Len, MaskKey, <<>>, 0, + << MaskedCode:2/binary, Rest/bits >>, Rsv) -> Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>), if Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006; (Code > 1011) and (Code < 3000); Code > 4999 -> websocket_close(State, Req, HandlerState, {error, badframe}); true -> websocket_payload(State, Req, HandlerState, - Opcode, Len - 2, MaskKey, Unmasked, Rest, Rsv) + Opcode, Len - 2, MaskKey, Unmasked, byte_size(MaskedCode), + Rest, Rsv) end; %% Text frames and close control frames MUST have a payload that is valid UTF-8. websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse ((Opcode =:= 8) andalso (Unmasked =/= <<>>))) -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of false -> @@ -377,14 +379,16 @@ websocket_payload(State=#state{utf8_state=Incomplete}, Utf8State -> websocket_payload_loop(State2#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, Rsv) + << Unmasked/binary, Unmasked3/binary >>, + UnmaskedLen + byte_size(Data), Rsv) end; websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of <<>> -> @@ -396,10 +400,11 @@ websocket_payload(State=#state{utf8_state=Incomplete}, end; %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints. websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when byte_size(Data) < Len -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of false -> @@ -407,14 +412,16 @@ websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, Utf8State -> websocket_payload_loop(State2#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, Rsv) + << Unmasked/binary, Unmasked3/binary >>, + UnmaskedLen + byte_size(Data), Rsv) end; websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) -> + Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), + {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, Fin =:= fin, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of <<>> -> websocket_dispatch(State2#state{utf8_state= <<>>}, @@ -429,20 +436,23 @@ websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, end; %% Other frames have a binary payload. websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) when byte_size(Data) < Len -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), websocket_payload_loop(State2, Req, HandlerState, - Opcode, Len - byte_size(Data), MaskKey, Unmasked3, Rsv); + Opcode, Len - byte_size(Data), MaskKey, + << Unmasked/binary, Unmasked3/binary >>, UnmaskedLen + byte_size(Data), + Rsv); websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv) -> + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), - websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, Unmasked3). + websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, + << Unmasked/binary, Unmasked3/binary >>). -spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) -> {binary(), #state{}}. @@ -450,14 +460,13 @@ websocket_inflate_frame(Data, << Rsv1:1, _:2 >>, _, #state{deflate_frame = DeflateFrame} = State) when DeflateFrame =:= false orelse Rsv1 =:= 0 -> {Data, State}; -websocket_inflate_frame(Data, << 1:1, _:2 >>, false, - #state{inflate_buffer = Buffer} = State) -> - {<<>>, State#state{inflate_buffer = << Buffer/binary, Data/binary >>}}; -websocket_inflate_frame(Data, << 1:1, _:2 >>, true, - #state{inflate_state = Inflate, inflate_buffer = Buffer} = State) -> - Deflated = << Buffer/binary, Data/binary, 0:8, 0:8, 255:8, 255:8 >>, - Result = zlib:inflate(Inflate, Deflated), - {iolist_to_binary(Result), State#state{inflate_buffer = <<>>}}. +websocket_inflate_frame(Data, << 1:1, _:2 >>, false, State) -> + Result = zlib:inflate(State#state.inflate_state, Data), + {iolist_to_binary(Result), State}; +websocket_inflate_frame(Data, << 1:1, _:2 >>, true, State) -> + Result = zlib:inflate(State#state.inflate_state, + << Data/binary, 0:8, 0:8, 255:8, 255:8 >>), + {iolist_to_binary(Result), State}. -spec websocket_unmask(B, mask_key(), B) -> B when B::binary(). websocket_unmask(<<>>, _, Unmasked) -> @@ -516,19 +525,20 @@ is_utf8(_) -> false. -spec websocket_payload_loop(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), rsv()) + opcode(), non_neg_integer(), mask_key(), binary(), + non_neg_integer(), rsv()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). websocket_payload_loop(State=#state{socket=Socket, transport=Transport, messages={OK, Closed, Error}, timeout_ref=TRef}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Rsv) -> + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) -> Transport:setopts(Socket, [{active, once}]), receive {OK, Socket, Data} -> State2 = handler_loop_timeout(State), websocket_payload(State2, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv); + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv); {Closed, Socket} -> handler_terminate(State, Req, HandlerState, {error, closed}); {Error, Socket, Reason} -> @@ -537,13 +547,13 @@ websocket_payload_loop(State=#state{socket=Socket, transport=Transport, websocket_close(State, Req, HandlerState, {normal, timeout}); {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> websocket_payload_loop(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Rsv); + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv); Message -> handler_call(State, Req, HandlerState, <<>>, websocket_info, Message, fun (State2, Req2, HandlerState2, _) -> websocket_payload_loop(State2, Req2, HandlerState2, - Opcode, Len, MaskKey, Unmasked, Rsv) + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) end) end. @@ -665,19 +675,20 @@ websocket_opcode(close) -> 8; websocket_opcode(ping) -> 9; websocket_opcode(pong) -> 10. --spec websocket_deflate_frame(opcode(), binary(), #state{}) -> {binary(), <<_:3>>, #state{}}. +-spec websocket_deflate_frame(opcode(), binary(), #state{}) -> + {binary(), rsv(), #state{}}. websocket_deflate_frame(Opcode, Payload, State=#state{deflate_frame = DeflateFrame}) when DeflateFrame =:= false orelse Opcode >= 8 -> - {Payload, <<0:3>>, State}; + {Payload, << 0:3 >>, State}; websocket_deflate_frame(_, Payload, State=#state{deflate_state = Deflate}) -> Deflated = iolist_to_binary(zlib:deflate(Deflate, Payload, sync)), DeflatedBodyLength = erlang:size(Deflated) - 4, Deflated1 = case Deflated of - <<Body:DeflatedBodyLength/binary, 0:8, 0:8, 255:8, 255:8>> -> Body; + << Body:DeflatedBodyLength/binary, 0:8, 0:8, 255:8, 255:8 >> -> Body; _ -> Deflated end, - {Deflated1, <<1:1, 0:2>>, State}. + {Deflated1, << 1:1, 0:2 >>, State}. -spec websocket_send(frame(), #state{}) -> {ok, #state{}} | {shutdown, #state{}} | {{error, atom()}, #state{}}. |