From 4edef3c45c70051dc5e8d49a1308e86b73be6e8f Mon Sep 17 00:00:00 2001 From: Ali Sabil Date: Tue, 2 Jul 2013 10:58:12 +0200 Subject: Remove usage of the inflate buffer --- src/cowboy_websocket.erl | 21 +++++++++------------ 1 file changed, 9 insertions(+), 12 deletions(-) (limited to 'src') diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index df50162..c9220f0 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -57,9 +57,8 @@ frag_state = undefined :: frag_state(), utf8_state = <<>> :: binary(), deflate_frame = false :: boolean(), - inflate_state :: any(), - inflate_buffer = <<>> :: binary(), - deflate_state :: any() + inflate_state :: undefined | port(), + deflate_state :: undefined | port() }). %% @doc Upgrade an HTTP request to the Websocket protocol. @@ -121,7 +120,6 @@ websocket_extensions(State, Req) -> {ok, State#state{ deflate_frame = true, inflate_state = Inflate, - inflate_buffer = <<>>, deflate_state = Deflate }, Req2}; _ -> @@ -450,14 +448,13 @@ websocket_inflate_frame(Data, << Rsv1:1, _:2 >>, _, #state{deflate_frame = DeflateFrame} = State) when DeflateFrame =:= false orelse Rsv1 =:= 0 -> {Data, State}; -websocket_inflate_frame(Data, << 1:1, _:2 >>, false, - #state{inflate_buffer = Buffer} = State) -> - {<<>>, State#state{inflate_buffer = << Buffer/binary, Data/binary >>}}; -websocket_inflate_frame(Data, << 1:1, _:2 >>, true, - #state{inflate_state = Inflate, inflate_buffer = Buffer} = State) -> - Deflated = << Buffer/binary, Data/binary, 0:8, 0:8, 255:8, 255:8 >>, - Result = zlib:inflate(Inflate, Deflated), - {iolist_to_binary(Result), State#state{inflate_buffer = <<>>}}. +websocket_inflate_frame(Data, << 1:1, _:2 >>, false, State) -> + Result = zlib:inflate(State#state.inflate_state, Data), + {iolist_to_binary(Result), State}; +websocket_inflate_frame(Data, << 1:1, _:2 >>, true, State) -> + Result = zlib:inflate(State#state.inflate_state, + << Data/binary, 0:8, 0:8, 255:8, 255:8 >>), + {iolist_to_binary(Result), State}. -spec websocket_unmask(B, mask_key(), B) -> B when B::binary(). websocket_unmask(<<>>, _, Unmasked) -> -- cgit v1.2.3 From 373f2e8134a5c583fc5179fbd76986a5f6db8a37 Mon Sep 17 00:00:00 2001 From: Ali Sabil Date: Tue, 2 Jul 2013 11:02:32 +0200 Subject: Fix coding style in websocket_deflate_frame/3 --- src/cowboy_websocket.erl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) (limited to 'src') diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index c9220f0..6c58818 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -662,19 +662,20 @@ websocket_opcode(close) -> 8; websocket_opcode(ping) -> 9; websocket_opcode(pong) -> 10. --spec websocket_deflate_frame(opcode(), binary(), #state{}) -> {binary(), <<_:3>>, #state{}}. +-spec websocket_deflate_frame(opcode(), binary(), #state{}) -> + {binary(), <<_:3>>, #state{}}. websocket_deflate_frame(Opcode, Payload, State=#state{deflate_frame = DeflateFrame}) when DeflateFrame =:= false orelse Opcode >= 8 -> - {Payload, <<0:3>>, State}; + {Payload, << 0:3 >>, State}; websocket_deflate_frame(_, Payload, State=#state{deflate_state = Deflate}) -> Deflated = iolist_to_binary(zlib:deflate(Deflate, Payload, sync)), DeflatedBodyLength = erlang:size(Deflated) - 4, Deflated1 = case Deflated of - <> -> Body; + << Body:DeflatedBodyLength/binary, 0:8, 0:8, 255:8, 255:8 >> -> Body; _ -> Deflated end, - {Deflated1, <<1:1, 0:2>>, State}. + {Deflated1, << 1:1, 0:2 >>, State}. -spec websocket_send(frame(), #state{}) -> {ok, #state{}} | {shutdown, #state{}} | {{error, atom()}, #state{}}. -- cgit v1.2.3 From a3b9438d16eceb61584b1637b2de4d66c0aadfc5 Mon Sep 17 00:00:00 2001 From: Ali Sabil Date: Tue, 2 Jul 2013 11:09:27 +0200 Subject: Fix websocket unmasking when compression is enabled The unmasking logic was based on the length of inflated data instead of the length of the deflated data. This meant data would get corrupted when we receive a websocket frame split across multiple TCP packets. --- src/cowboy_websocket.erl | 69 ++++++++++++++++++++++++++++-------------------- 1 file changed, 41 insertions(+), 28 deletions(-) (limited to 'src') diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 6c58818..918c9e6 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -329,45 +329,49 @@ websocket_data(State, Req, HandlerState, Data) -> websocket_data(State=#state{frag_state=undefined}, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 0) -> websocket_payload(State#state{frag_state={nofin, Opcode, <<>>}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv); + Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Subsequent frame fragments. websocket_data(State=#state{frag_state={nofin, _, _}}, Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 0) -> websocket_payload(State, Req, HandlerState, - 0, Len, MaskKey, <<>>, Data, Rsv); + 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Final frame fragment. websocket_data(State=#state{frag_state={nofin, Opcode, SoFar}}, Req, HandlerState, 0, Len, MaskKey, Data, Rsv, 1) -> websocket_payload(State#state{frag_state={fin, Opcode, SoFar}}, - Req, HandlerState, 0, Len, MaskKey, <<>>, Data, Rsv); + Req, HandlerState, 0, Len, MaskKey, <<>>, 0, Data, Rsv); %% Unfragmented frame. websocket_data(State, Req, HandlerState, Opcode, Len, MaskKey, Data, Rsv, 1) -> websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, <<>>, Data, Rsv). + Opcode, Len, MaskKey, <<>>, 0, Data, Rsv). -spec websocket_payload(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), binary(), rsv()) + opcode(), non_neg_integer(), mask_key(), binary(), non_neg_integer(), + binary(), rsv()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). %% Close control frames with a payload MUST contain a valid close code. websocket_payload(State, Req, HandlerState, - Opcode=8, Len, MaskKey, <<>>, << MaskedCode:2/binary, Rest/bits >>, Rsv) -> + Opcode=8, Len, MaskKey, <<>>, 0, + << MaskedCode:2/binary, Rest/bits >>, Rsv) -> Unmasked = << Code:16 >> = websocket_unmask(MaskedCode, MaskKey, <<>>), if Code < 1000; Code =:= 1004; Code =:= 1005; Code =:= 1006; (Code > 1011) and (Code < 3000); Code > 4999 -> websocket_close(State, Req, HandlerState, {error, badframe}); true -> websocket_payload(State, Req, HandlerState, - Opcode, Len - 2, MaskKey, Unmasked, Rest, Rsv) + Opcode, Len - 2, MaskKey, Unmasked, byte_size(MaskedCode), + Rest, Rsv) end; %% Text frames and close control frames MUST have a payload that is valid UTF-8. websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when (byte_size(Data) < Len) andalso ((Opcode =:= 1) orelse ((Opcode =:= 8) andalso (Unmasked =/= <<>>))) -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of false -> @@ -375,14 +379,16 @@ websocket_payload(State=#state{utf8_state=Incomplete}, Utf8State -> websocket_payload_loop(State2#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, Rsv) + << Unmasked/binary, Unmasked3/binary >>, + UnmaskedLen + byte_size(Data), Rsv) end; websocket_payload(State=#state{utf8_state=Incomplete}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when Opcode =:= 1; (Opcode =:= 8) and (Unmasked =/= <<>>) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of <<>> -> @@ -394,10 +400,11 @@ websocket_payload(State=#state{utf8_state=Incomplete}, end; %% Fragmented text frames may cut payload in the middle of UTF-8 codepoints. websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) + Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) when byte_size(Data) < Len -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of false -> @@ -405,13 +412,15 @@ websocket_payload(State=#state{frag_state={_, 1, _}, utf8_state=Incomplete}, Utf8State -> websocket_payload_loop(State2#state{utf8_state=Utf8State}, Req, HandlerState, Opcode, Len - byte_size(Data), MaskKey, - << Unmasked/binary, Unmasked3/binary >>, Rsv) + << Unmasked/binary, Unmasked3/binary >>, + UnmaskedLen + byte_size(Data), Rsv) end; websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, - Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, Data, Rsv) -> + Req, HandlerState, Opcode=0, Len, MaskKey, Unmasked, UnmaskedLen, + Data, Rsv) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), <<>>), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of <<>> -> @@ -427,20 +436,23 @@ websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, end; %% Other frames have a binary payload. websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv) + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) when byte_size(Data) < Len -> Unmasked2 = websocket_unmask(Data, - rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, false, State), websocket_payload_loop(State2, Req, HandlerState, - Opcode, Len - byte_size(Data), MaskKey, Unmasked3, Rsv); + Opcode, Len - byte_size(Data), MaskKey, + << Unmasked/binary, Unmasked3/binary >>, UnmaskedLen + byte_size(Data), + Rsv); websocket_payload(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv) -> + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv) -> << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, - rotate_mask_key(MaskKey, byte_size(Unmasked)), Unmasked), + rotate_mask_key(MaskKey, UnmaskedLen), <<>>), {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), - websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, Unmasked3). + websocket_dispatch(State2, Req, HandlerState, Rest, Opcode, + << Unmasked/binary, Unmasked3/binary >>). -spec websocket_inflate_frame(binary(), rsv(), boolean(), #state{}) -> {binary(), #state{}}. @@ -513,19 +525,20 @@ is_utf8(_) -> false. -spec websocket_payload_loop(#state{}, Req, any(), - opcode(), non_neg_integer(), mask_key(), binary(), rsv()) + opcode(), non_neg_integer(), mask_key(), binary(), + non_neg_integer(), rsv()) -> {ok, Req, cowboy_middleware:env()} | {suspend, module(), atom(), [any()]} when Req::cowboy_req:req(). websocket_payload_loop(State=#state{socket=Socket, transport=Transport, messages={OK, Closed, Error}, timeout_ref=TRef}, - Req, HandlerState, Opcode, Len, MaskKey, Unmasked, Rsv) -> + Req, HandlerState, Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) -> Transport:setopts(Socket, [{active, once}]), receive {OK, Socket, Data} -> State2 = handler_loop_timeout(State), websocket_payload(State2, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Data, Rsv); + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Data, Rsv); {Closed, Socket} -> handler_terminate(State, Req, HandlerState, {error, closed}); {Error, Socket, Reason} -> @@ -534,13 +547,13 @@ websocket_payload_loop(State=#state{socket=Socket, transport=Transport, websocket_close(State, Req, HandlerState, {normal, timeout}); {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> websocket_payload_loop(State, Req, HandlerState, - Opcode, Len, MaskKey, Unmasked, Rsv); + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv); Message -> handler_call(State, Req, HandlerState, <<>>, websocket_info, Message, fun (State2, Req2, HandlerState2, _) -> websocket_payload_loop(State2, Req2, HandlerState2, - Opcode, Len, MaskKey, Unmasked, Rsv) + Opcode, Len, MaskKey, Unmasked, UnmaskedLen, Rsv) end) end. -- cgit v1.2.3 From 6f0b8804bca12cfe3e8dc5798c819e3f16c6936e Mon Sep 17 00:00:00 2001 From: Ali Sabil Date: Tue, 2 Jul 2013 12:36:26 +0200 Subject: Fix handling of websocket fragmented deflated frames --- src/cowboy_websocket.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 918c9e6..54dbcd9 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -421,7 +421,7 @@ websocket_payload(State=#state{frag_state={Fin, 1, _}, utf8_state=Incomplete}, << End:Len/binary, Rest/bits >> = Data, Unmasked2 = websocket_unmask(End, rotate_mask_key(MaskKey, UnmaskedLen), <<>>), - {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, true, State), + {Unmasked3, State2} = websocket_inflate_frame(Unmasked2, Rsv, Fin =:= fin, State), case is_utf8(<< Incomplete/binary, Unmasked3/binary >>) of <<>> -> websocket_dispatch(State2#state{utf8_state= <<>>}, -- cgit v1.2.3 From c5c9c398ffea608d22c410668b9b79b80b5bdb91 Mon Sep 17 00:00:00 2001 From: Ali Sabil Date: Mon, 8 Jul 2013 09:49:35 +0200 Subject: Use the proper typespec for the websocket_deflate_frame rsv bits --- src/cowboy_websocket.erl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'src') diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 54dbcd9..073d7c6 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -676,7 +676,7 @@ websocket_opcode(ping) -> 9; websocket_opcode(pong) -> 10. -spec websocket_deflate_frame(opcode(), binary(), #state{}) -> - {binary(), <<_:3>>, #state{}}. + {binary(), rsv(), #state{}}. websocket_deflate_frame(Opcode, Payload, State=#state{deflate_frame = DeflateFrame}) when DeflateFrame =:= false orelse Opcode >= 8 -> -- cgit v1.2.3