From c50d6aa09c9028dca3365516d30f1242cfd43306 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= Date: Sat, 5 Oct 2019 13:04:21 +0200 Subject: Don't discard data following a Websocket upgrade request While the protocol does not allow sending data before receiving a successful Websocket upgrade response, we do not want to discard that data if it does come in. --- src/cowboy_http.erl | 100 +++++++++++++++++++++++------------------------ src/cowboy_websocket.erl | 8 +++- test/sys_SUITE.erl | 6 +-- test/ws_SUITE.erl | 27 +++++++++++-- 4 files changed, 81 insertions(+), 60 deletions(-) diff --git a/src/cowboy_http.erl b/src/cowboy_http.erl index 021657a..b4a6995 100644 --- a/src/cowboy_http.erl +++ b/src/cowboy_http.erl @@ -111,6 +111,7 @@ transport :: module(), proxy_header :: undefined | ranch_proxy_header:proxy_info(), opts = #{} :: cowboy:opts(), + buffer = <<>> :: binary(), %% Some options may be overriden for the current stream. overriden_opts = #{} :: cowboy:opts(), @@ -175,7 +176,7 @@ init(Parent, Ref, Socket, Transport, ProxyHeader, Opts) -> parent=Parent, ref=Ref, socket=Socket, transport=Transport, proxy_header=ProxyHeader, opts=Opts, peer=Peer, sock=Sock, cert=Cert, - last_streamid=LastStreamID}), <<>>); + last_streamid=LastStreamID})); {{error, Reason}, _, _} -> terminate(undefined, {socket_error, Reason, 'A socket error occurred when retrieving the peer name.'}); @@ -187,22 +188,22 @@ init(Parent, Ref, Socket, Transport, ProxyHeader, Opts) -> 'A socket error occurred when retrieving the client TLS certificate.'}) end. -before_loop(State=#state{socket=Socket, transport=Transport}, Buffer) -> +before_loop(State=#state{socket=Socket, transport=Transport}) -> %% @todo disable this when we get to the body, until the stream asks for it? %% Perhaps have a threshold for how much we're willing to read before waiting. Transport:setopts(Socket, [{active, once}]), - loop(State, Buffer). + loop(State). loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts, - timer=TimerRef, children=Children, in_streamid=InStreamID, - last_streamid=LastStreamID, streams=Streams}, Buffer) -> + buffer=Buffer, timer=TimerRef, children=Children, in_streamid=InStreamID, + last_streamid=LastStreamID, streams=Streams}) -> Messages = Transport:messages(), InactivityTimeout = maps:get(inactivity_timeout, Opts, 300000), receive %% Discard data coming in after the last request %% we want to process was received fully. {OK, Socket, _} when OK =:= element(1, Messages), InStreamID > LastStreamID -> - before_loop(State, Buffer); + before_loop(State); %% Socket messages. {OK, Socket, Data} when OK =:= element(1, Messages) -> %% Only reset the timeout if it is idle_timeout (active streams). @@ -218,30 +219,30 @@ loop(State=#state{parent=Parent, socket=Socket, transport=Transport, opts=Opts, %% Timeouts. {timeout, Ref, {shutdown, Pid}} -> cowboy_children:shutdown_timeout(Children, Ref, Pid), - loop(State, Buffer); + loop(State); {timeout, TimerRef, Reason} -> timeout(State, Reason); {timeout, _, _} -> - loop(State, Buffer); + loop(State); %% System messages. {'EXIT', Parent, Reason} -> terminate(State, {stop, {exit, Reason}, 'Parent process terminated.'}); {system, From, Request} -> - sys:handle_system_msg(Request, From, Parent, ?MODULE, [], {State, Buffer}); + sys:handle_system_msg(Request, From, Parent, ?MODULE, [], State); %% Messages pertaining to a stream. {{Pid, StreamID}, Msg} when Pid =:= self() -> - loop(info(State, StreamID, Msg), Buffer); + loop(info(State, StreamID, Msg)); %% Exit signal from children. Msg = {'EXIT', Pid, _} -> - loop(down(State, Pid, Msg), Buffer); + loop(down(State, Pid, Msg)); %% Calls from supervisor module. {'$gen_call', From, Call} -> cowboy_children:handle_supervisor_call(Call, From, Children, ?MODULE), - loop(State, Buffer); + loop(State); %% Unknown messages. Msg -> cowboy:log(warning, "Received stray message ~p.~n", [Msg], Opts), - loop(State, Buffer) + loop(State) after InactivityTimeout -> terminate(State, {internal_error, timeout, 'No message or data received before timeout.'}) end. @@ -293,12 +294,12 @@ timeout(State, idle_timeout) -> 'Connection idle longer than configuration allows.'}). parse(<<>>, State) -> - before_loop(State, <<>>); + before_loop(State#state{buffer= <<>>}); %% Do not process requests that come in after the last request %% and discard the buffer if any to save memory. parse(_, State=#state{in_streamid=InStreamID, in_state=#ps_request_line{}, last_streamid=LastStreamID}) when InStreamID > LastStreamID -> - before_loop(State, <<>>); + before_loop(State#state{buffer= <<>>}); parse(Buffer, State=#state{in_state=#ps_request_line{empty_lines=EmptyLines}}) -> after_parse(parse_request(Buffer, State, EmptyLines)); parse(Buffer, State=#state{in_state=PS=#ps_header{headers=Headers, name=undefined}}) -> @@ -317,7 +318,7 @@ parse(Buffer, State=#state{in_state=#ps_body{}}) -> after_parse({request, Req=#{streamid := StreamID, method := Method, headers := Headers, version := Version}, - State0=#state{opts=Opts, streams=Streams0}, Buffer}) -> + State0=#state{opts=Opts, buffer=Buffer, streams=Streams0}}) -> try cowboy_stream:init(StreamID, Req, Opts) of {Commands, StreamState} -> TE = maps:get(<<"te">>, Headers, undefined), @@ -339,8 +340,8 @@ after_parse({request, Req=#{streamid := StreamID, method := Method, end; %% Streams are sequential so the body is always about the last stream created %% unless that stream has terminated. -after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts, - streams=Streams0=[Stream=#stream{id=StreamID, state=StreamState0}|_]}, Buffer}) -> +after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts, buffer=Buffer, + streams=Streams0=[Stream=#stream{id=StreamID, state=StreamState0}|_]}}) -> try cowboy_stream:data(StreamID, IsFin, Data, StreamState0) of {Commands, StreamState} -> Streams = lists:keyreplace(StreamID, #stream.id, Streams0, @@ -355,17 +356,17 @@ after_parse({data, StreamID, IsFin, Data, State=#state{opts=Opts, end; %% No corresponding stream. We must skip the body of the previous request %% in order to process the next one. -after_parse({data, _, _, _, State, Buffer}) -> - before_loop(State, Buffer); -after_parse({more, State, Buffer}) -> - before_loop(State, Buffer). +after_parse({data, _, _, _, State}) -> + before_loop(State); +after_parse({more, State}) -> + before_loop(State). %% Request-line. -spec parse_request(Buffer, State, non_neg_integer()) - -> {request, cowboy_req:req(), State, Buffer} - | {data, cowboy_stream:streamid(), cowboy_stream:fin(), binary(), State, Buffer} - | {more, State, Buffer} + -> {request, cowboy_req:req(), State} + | {data, cowboy_stream:streamid(), cowboy_stream:fin(), binary(), State} + | {more, State} when Buffer::binary(), State::#state{}. %% Empty lines must be using \r\n. parse_request(<< $\n, _/bits >>, State, _) -> @@ -384,7 +385,7 @@ parse_request(Buffer, State=#state{opts=Opts, in_streamid=InStreamID}, EmptyLine error_terminate(414, State, {connection_error, limit_reached, 'The request-line length is larger than configuration allows. (RFC7230 3.1.1)'}); nomatch -> - {more, State#state{in_state=#ps_request_line{empty_lines=EmptyLines}}, Buffer}; + {more, State#state{buffer=Buffer, in_state=#ps_request_line{empty_lines=EmptyLines}}}; 1 when EmptyLines =:= MaxEmptyLines -> error_terminate(400, State, {connection_error, limit_reached, 'More empty lines were received than configuration allows. (RFC7230 3.5)'}); @@ -527,7 +528,7 @@ before_parse_headers(Rest, State, M, A, P, Q, V) -> %% We need two or more bytes in the buffer to continue. parse_header(Rest, State=#state{in_state=PS}, Headers) when byte_size(Rest) < 2 -> - {more, State#state{in_state=PS#ps_header{headers=Headers}}, Rest}; + {more, State#state{buffer=Rest, in_state=PS#ps_header{headers=Headers}}}; parse_header(<< $\r, $\n, Rest/bits >>, S, Headers) -> request(Rest, S, Headers); parse_header(Buffer, State=#state{opts=Opts, in_state=PS}, Headers) -> @@ -554,7 +555,7 @@ parse_header_colon(Buffer, State=#state{opts=Opts, in_state=PS}, Headers) -> %% so check if we have an LF and abort with an error if we do. case match_eol(Buffer, 0) of nomatch -> - {more, State#state{in_state=PS#ps_header{headers=Headers}}, Buffer}; + {more, State#state{buffer=Buffer, in_state=PS#ps_header{headers=Headers}}}; _ -> error_terminate(400, State#state{in_state=PS#ps_header{headers=Headers}}, {connection_error, protocol_error, @@ -596,7 +597,7 @@ parse_hd_before_value(Buffer, State=#state{opts=Opts, in_state=PS}, H, N) -> {connection_error, limit_reached, 'A header value is larger than configuration allows. (RFC7230 3.2.5, RFC6585 5)'}); nomatch -> - {more, State#state{in_state=PS#ps_header{headers=H, name=N}}, Buffer}; + {more, State#state{buffer=Buffer, in_state=PS#ps_header{headers=H, name=N}}}; _ -> parse_hd_value(Buffer, State, H, N, <<>>) end. @@ -766,7 +767,7 @@ request(Buffer, State0=#state{ref=Ref, transport=Transport, peer=Peer, sock=Sock false -> State0#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}} end, - {request, Req, State, Buffer}; + {request, Req, State#state{buffer=Buffer}}; {true, HTTP2Settings} -> %% We save the headers in case the upgrade will fail %% and we need to pass them to cowboy_stream:early_error. @@ -835,28 +836,28 @@ parse_body(Buffer, State=#state{in_streamid=StreamID, in_state= try TDecode(Buffer, TState0) of more -> %% @todo Asks for 0 or more bytes. - {more, State, Buffer}; + {more, State#state{buffer=Buffer}}; {more, Data, TState} -> %% @todo Asks for 0 or more bytes. - {data, StreamID, nofin, Data, State#state{in_state= - PS#ps_body{received=Received + byte_size(Data), - transfer_decode_state=TState}}, <<>>}; + {data, StreamID, nofin, Data, State#state{buffer= <<>>, + in_state=PS#ps_body{received=Received + byte_size(Data), + transfer_decode_state=TState}}}; {more, Data, _Length, TState} when is_integer(_Length) -> %% @todo Asks for Length more bytes. - {data, StreamID, nofin, Data, State#state{in_state= - PS#ps_body{received=Received + byte_size(Data), - transfer_decode_state=TState}}, <<>>}; + {data, StreamID, nofin, Data, State#state{buffer= <<>>, + in_state=PS#ps_body{received=Received + byte_size(Data), + transfer_decode_state=TState}}}; {more, Data, Rest, TState} -> %% @todo Asks for 0 or more bytes. - {data, StreamID, nofin, Data, State#state{in_state= - PS#ps_body{received=Received + byte_size(Data), - transfer_decode_state=TState}}, Rest}; + {data, StreamID, nofin, Data, State#state{buffer=Rest, + in_state=PS#ps_body{received=Received + byte_size(Data), + transfer_decode_state=TState}}}; {done, _HasTrailers, Rest} -> {data, StreamID, fin, <<>>, set_timeout( - State#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}), Rest}; + State#state{buffer=Rest, in_streamid=StreamID + 1, in_state=#ps_request_line{}})}; {done, Data, _HasTrailers, Rest} -> {data, StreamID, fin, Data, set_timeout( - State#state{in_streamid=StreamID + 1, in_state=#ps_request_line{}}), Rest} + State#state{buffer=Rest, in_streamid=StreamID + 1, in_state=#ps_request_line{}})} catch _:_ -> Reason = {connection_error, protocol_error, 'Failure to decode the content. (RFC7230 4)'}, @@ -1094,7 +1095,7 @@ commands(State=#state{socket=Socket, transport=Transport, streams=Streams, out_s commands(State#state{out_state=done}, StreamID, Tail); %% Protocol takeover. commands(State0=#state{ref=Ref, parent=Parent, socket=Socket, transport=Transport, - out_state=OutState, opts=Opts, children=Children}, StreamID, + out_state=OutState, opts=Opts, buffer=Buffer, children=Children}, StreamID, [{switch_protocol, Headers, Protocol, InitialState}|_Tail]) -> %% @todo This should be the last stream running otherwise we need to wait before switching. %% @todo If there's streams opened after this one, fail instead of 101. @@ -1117,10 +1118,7 @@ commands(State0=#state{ref=Ref, parent=Parent, socket=Socket, transport=Transpor %% Terminate children processes and flush any remaining messages from the mailbox. cowboy_children:terminate(Children), flush(Parent), - %% @todo This is no good because commands return a state normally and here it doesn't - %% we need to let this module go entirely. Perhaps it should be handled directly in - %% cowboy_clear/cowboy_tls? - Protocol:takeover(Parent, Ref, Socket, Transport, Opts, <<>>, InitialState); + Protocol:takeover(Parent, Ref, Socket, Transport, Opts, Buffer, InitialState); %% Set options dynamically. commands(State0=#state{overriden_opts=Opts}, StreamID, [{set_options, SetOpts}|Tail]) -> @@ -1446,12 +1444,12 @@ terminate_linger_loop(State=#state{socket=Socket, transport=Transport}, TimerRef %% System callbacks. --spec system_continue(_, _, {#state{}, binary()}) -> ok. -system_continue(_, _, {State, Buffer}) -> - loop(State, Buffer). +-spec system_continue(_, _, #state{}) -> ok. +system_continue(_, _, State) -> + loop(State). -spec system_terminate(any(), _, _, {#state{}, binary()}) -> no_return(). -system_terminate(Reason, _, _, {State, _}) -> +system_terminate(Reason, _, _, State) -> terminate(State, {stop, {exit, Reason}, 'sys:terminate/2,3 was called.'}). -spec system_code_change(Misc, _, _, _) -> {ok, Misc} when Misc::{#state{}, binary()}. diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 9540b75..5cc061a 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -291,10 +291,14 @@ takeover(Parent, Ref, Socket, Transport, _Opts, Buffer, State = loop_timeout(State0#state{parent=Parent, ref=Ref, socket=Socket, transport=Transport, key=undefined, messages=Messages}), + %% We call parse_header/3 immediately because there might be + %% some data in the buffer that was sent along with the handshake. + %% While it is not allowed by the protocol to send frames immediately, + %% we still want to process that data if any. case erlang:function_exported(Handler, websocket_init, 1) of true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer}, - websocket_init, undefined, fun before_loop/3); - false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer}) + websocket_init, undefined, fun parse_header/3); + false -> parse_header(State, HandlerState, #ps_header{buffer=Buffer}) end. before_loop(State=#state{active=false}, HandlerState, ParseState) -> diff --git a/test/sys_SUITE.erl b/test/sys_SUITE.erl index c7c0e4c..175219c 100644 --- a/test/sys_SUITE.erl +++ b/test/sys_SUITE.erl @@ -602,9 +602,8 @@ sys_get_state_h1(Config) -> {ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []), timer:sleep(100), Pid = get_remote_pid_tcp(Socket), - {State, Buffer} = sys:get_state(Pid), + State = sys:get_state(Pid), state = element(1, State), - true = is_binary(Buffer), ok. sys_get_state_h2(Config) -> @@ -726,9 +725,8 @@ sys_replace_state_h1(Config) -> {ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), []), timer:sleep(100), Pid = get_remote_pid_tcp(Socket), - {State, Buffer} = sys:replace_state(Pid, fun(S) -> S end), + State = sys:replace_state(Pid, fun(S) -> S end), state = element(1, State), - true = is_binary(Buffer), ok. sys_replace_state_h2(Config) -> diff --git a/test/ws_SUITE.erl b/test/ws_SUITE.erl index 4e35ec7..c99830c 100644 --- a/test/ws_SUITE.erl +++ b/test/ws_SUITE.erl @@ -304,6 +304,18 @@ do_ws_deflate_opts_z(Path, Config) -> {error, closed} = gen_tcp:recv(Socket, 0, 6000), ok. +ws_first_frame_with_handshake(Config) -> + doc("Client sends the first frame immediately with the handshake. " + "This is invalid according to the protocol but we still want " + "to accept it if the handshake is successful."), + Mask = 16#37fa213d, + MaskedHello = do_mask(<<"Hello">>, Mask, <<>>), + {ok, Socket, _} = do_handshake("/ws_echo", "", + <<1:1, 0:3, 1:4, 1:1, 5:7, Mask:32, MaskedHello/binary>>, + Config), + {ok, <<1:1, 0:3, 1:4, 0:1, 5:7, "Hello">>} = gen_tcp:recv(Socket, 0, 6000), + ok. + ws_init_return_ok(Config) -> doc("Handler does nothing."), {ok, Socket, _} = do_handshake("/ws_init?ok", Config), @@ -636,9 +648,12 @@ ws_webkit_deflate_single_bytes(Config) -> %% Internal. do_handshake(Path, Config) -> - do_handshake(Path, "", Config). + do_handshake(Path, "", "", Config). do_handshake(Path, ExtraHeaders, Config) -> + do_handshake(Path, ExtraHeaders, "", Config). + +do_handshake(Path, ExtraHeaders, ExtraData, Config) -> {ok, Socket} = gen_tcp:connect("localhost", config(port, Config), [binary, {active, false}]), ok = gen_tcp:send(Socket, [ @@ -650,10 +665,16 @@ do_handshake(Path, ExtraHeaders, Config) -> "Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==\r\n" "Upgrade: websocket\r\n", ExtraHeaders, - "\r\n"]), + "\r\n", + ExtraData]), {ok, Handshake} = gen_tcp:recv(Socket, 0, 6000), {ok, {http_response, {1, 1}, 101, _}, Rest} = erlang:decode_packet(http, Handshake, []), - [Headers, <<>>] = do_decode_headers(erlang:decode_packet(httph, Rest, []), []), + [Headers, Data] = do_decode_headers(erlang:decode_packet(httph, Rest, []), []), + %% Queue extra data back, if any. We don't want to receive it yet. + case Data of + <<>> -> ok; + _ -> gen_tcp:unrecv(Socket, Data) + end, {_, "Upgrade"} = lists:keyfind('Connection', 1, Headers), {_, "websocket"} = lists:keyfind('Upgrade', 1, Headers), {_, "s3pPLMBiTxaQ9kYGzzhZRbK+xOo="} = lists:keyfind("sec-websocket-accept", 1, Headers), -- cgit v1.2.3