diff options
Diffstat (limited to 'src/cowboy_websocket.erl')
-rw-r--r-- | src/cowboy_websocket.erl | 179 |
1 files changed, 132 insertions, 47 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index e7d8f31..65289cd 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -1,4 +1,4 @@ -%% Copyright (c) 2011-2017, Loïc Hoguin <[email protected]> +%% Copyright (c) Loïc Hoguin <[email protected]> %% %% Permission to use, copy, modify, and/or distribute this software for any %% purpose with or without fee is hereby granted, provided that the above @@ -68,7 +68,12 @@ -type opts() :: #{ active_n => pos_integer(), compress => boolean(), + data_delivery => stream_handlers | relay, + data_delivery_flow => pos_integer(), deflate_opts => cow_ws:deflate_opts(), + dynamic_buffer => false | {pos_integer(), pos_integer()}, + dynamic_buffer_initial_average => non_neg_integer(), + dynamic_buffer_initial_size => pos_integer(), idle_timeout => timeout(), max_frame_size => non_neg_integer() | infinity, req_filter => fun((cowboy_req:req()) -> map()), @@ -76,18 +81,32 @@ }. -export_type([opts/0]). +%% We don't want to reset the idle timeout too often, +%% so we don't reset it on data. Instead we reset the +%% number of ticks we have observed. We divide the +%% timeout value by a value and that value becomes +%% the number of ticks at which point we can drop +%% the connection. This value is the number of ticks. +-define(IDLE_TIMEOUT_TICKS, 10). + -record(state, { parent :: undefined | pid(), ref :: ranch:ref(), socket = undefined :: inet:socket() | {pid(), cowboy_stream:streamid()} | undefined, - transport = undefined :: module() | undefined, + transport :: module() | {data_delivery, stream_handlers | relay}, opts = #{} :: opts(), active = true :: boolean(), handler :: module(), key = undefined :: undefined | binary(), timeout_ref = undefined :: undefined | reference(), + timeout_num = 0 :: 0..?IDLE_TIMEOUT_TICKS, messages = undefined :: undefined | {atom(), atom(), atom()} | {atom(), atom(), atom(), atom()}, + + %% Dynamic buffer moving average and current buffer size. + dynamic_buffer_size = false :: pos_integer() | false, + dynamic_buffer_moving_average = 0 :: non_neg_integer(), + hibernate = false :: boolean(), frag_state = undefined :: cow_ws:frag_state(), frag_buffer = <<>> :: binary(), @@ -103,7 +122,8 @@ %% is trying to upgrade to the Websocket protocol. -spec is_upgrade_request(cowboy_req:req()) -> boolean(). -is_upgrade_request(#{version := 'HTTP/2', method := <<"CONNECT">>, protocol := Protocol}) -> +is_upgrade_request(#{version := Version, method := <<"CONNECT">>, protocol := Protocol}) + when Version =:= 'HTTP/2'; Version =:= 'HTTP/3' -> <<"websocket">> =:= cowboy_bstr:to_lower(Protocol); is_upgrade_request(Req=#{version := 'HTTP/1.1', method := <<"GET">>}) -> ConnTokens = cowboy_req:parse_header(<<"connection">>, Req, []), @@ -131,7 +151,7 @@ upgrade(Req, Env, Handler, HandlerState) -> %% @todo Immediately crash if a response has already been sent. upgrade(Req0=#{version := Version}, Env, Handler, HandlerState, Opts) -> FilteredReq = case maps:get(req_filter, Opts, undefined) of - undefined -> maps:with([method, version, scheme, host, port, path, qs, peer], Req0); + undefined -> maps:with([method, version, scheme, host, port, path, qs, peer, streamid], Req0); FilterFun -> FilterFun(Req0) end, Utf8State = case maps:get(validate_utf8, Opts, true) of @@ -148,13 +168,13 @@ upgrade(Req0=#{version := Version}, Env, Handler, HandlerState, Opts) -> <<"connection">> => <<"upgrade">>, <<"upgrade">> => <<"websocket">> }, Req0), Env}; - %% Use a generic 400 error for HTTP/2. + %% Use 501 Not Implemented for HTTP/2 and HTTP/3 as recommended + %% by RFC9220 3 (WebSockets Upgrade over HTTP/3). {error, upgrade_required} -> - {ok, cowboy_req:reply(400, Req0), Env} + {ok, cowboy_req:reply(501, Req0), Env} catch _:_ -> %% @todo Probably log something here? %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection. - %% @todo Does this even work? {ok, cowboy_req:reply(400, Req0), Env} end. @@ -255,12 +275,27 @@ websocket_handshake(State=#state{key=Key}, %% For HTTP/2 we do not let the process die, we instead keep it %% for the Websocket stream. This is because in HTTP/2 we only %% have a stream, it doesn't take over the whole connection. -websocket_handshake(State, Req=#{ref := Ref, pid := Pid, streamid := StreamID}, +%% +%% There are two methods of delivering data to the Websocket session: +%% - 'stream_handlers' is the default and makes the data go +%% through stream handlers just like when reading a request body; +%% - 'relay' is a new method where data is sent as a message as +%% soon as it is received from the socket in a DATA frame. +websocket_handshake(State=#state{opts=Opts}, + Req=#{ref := Ref, pid := Pid, streamid := StreamID}, HandlerState, _Env) -> %% @todo We don't want date and server headers. Headers = cowboy_req:response_headers(#{}, Req), - Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {State, HandlerState}}}, - takeover(Pid, Ref, {Pid, StreamID}, undefined, undefined, <<>>, + DataDelivery = maps:get(data_delivery, Opts, stream_handlers), + ModState = #{ + data_delivery => DataDelivery, + %% For relay data_delivery. The flow is a hint and may + %% not be used by the underlying protocol. + data_delivery_pid => self(), + data_delivery_flow => maps:get(data_delivery_flow, Opts, 131072) + }, + Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, ModState}}, + takeover(Pid, Ref, {Pid, StreamID}, {data_delivery, DataDelivery}, #{}, <<>>, {State, HandlerState}). %% Connection process. @@ -283,19 +318,26 @@ websocket_handshake(State, Req=#{ref := Ref, pid := Pid, streamid := StreamID}, -type parse_state() :: #ps_header{} | #ps_payload{}. -spec takeover(pid(), ranch:ref(), inet:socket() | {pid(), cowboy_stream:streamid()}, - module() | undefined, any(), binary(), + module() | {data_delivery, stream_handlers | relay}, any(), binary(), {#state{}, any()}) -> no_return(). -takeover(Parent, Ref, Socket, Transport, _Opts, Buffer, - {State0=#state{handler=Handler}, HandlerState}) -> - %% @todo We should have an option to disable this behavior. - ranch:remove_connection(Ref), +takeover(Parent, Ref, Socket, Transport, Opts, Buffer, + {State0=#state{opts=WsOpts, handler=Handler, req=Req}, HandlerState}) -> + case Req of + #{version := 'HTTP/3'} -> ok; + %% @todo We should have an option to disable this behavior. + _ -> ranch:remove_connection(Ref) + end, Messages = case Transport of - undefined -> undefined; + {data_delivery, _} -> undefined; _ -> Transport:messages() end, - State = loop_timeout(State0#state{parent=Parent, + State = set_idle_timeout(State0#state{parent=Parent, ref=Ref, socket=Socket, transport=Transport, - key=undefined, messages=Messages}), + opts=WsOpts#{dynamic_buffer => maps:get(dynamic_buffer, Opts, false)}, + key=undefined, messages=Messages, + %% Dynamic buffer only applies to HTTP/1.1 Websocket. + dynamic_buffer_size=init_dynamic_buffer_size(Opts), + dynamic_buffer_moving_average=maps:get(dynamic_buffer_initial_average, Opts, 0)}, 0), %% We call parse_header/3 immediately because there might be %% some data in the buffer that was sent along with the handshake. %% While it is not allowed by the protocol to send frames immediately, @@ -306,6 +348,12 @@ takeover(Parent, Ref, Socket, Transport, _Opts, Buffer, false -> after_init(State, HandlerState, #ps_header{buffer=Buffer}) end. +-include("cowboy_dynamic_buffer.hrl"). + +%% @todo Implement early socket error detection. +maybe_socket_error(_, _) -> + ok. + after_init(State=#state{active=true}, HandlerState, ParseState) -> %% Enable active,N for HTTP/1.1, and auto read_body for HTTP/2. %% We must do this only after calling websocket_init/1 (if any) @@ -324,13 +372,14 @@ after_init(State, HandlerState, ParseState) -> %% immediately but there might still be data to be processed in %% the message queue. -setopts_active(#state{transport=undefined}) -> +setopts_active(#state{transport={data_delivery, _}}) -> ok; setopts_active(#state{socket=Socket, transport=Transport, opts=Opts}) -> - N = maps:get(active_n, Opts, 100), + N = maps:get(active_n, Opts, 1), Transport:setopts(Socket, [{active, N}]). -maybe_read_body(#state{socket=Stream={Pid, _}, transport=undefined, active=true}) -> +maybe_read_body(#state{transport={data_delivery, stream_handlers}, + socket=Stream={Pid, _}, active=true}) -> %% @todo Keep Ref around. ReadBodyRef = make_ref(), Pid ! {Stream, {read_body, self(), ReadBodyRef, auto, infinity}}, @@ -338,16 +387,25 @@ maybe_read_body(#state{socket=Stream={Pid, _}, transport=undefined, active=true} maybe_read_body(_) -> ok. -active(State) -> +active(State=#state{transport={data_delivery, relay}, + socket=Stream={Pid, _}}) -> + Pid ! {'$cowboy_relay_command', Stream, active}, + State#state{active=true}; +active(State0) -> + State = State0#state{active=true}, setopts_active(State), maybe_read_body(State), - State#state{active=true}. + State. -passive(State=#state{transport=undefined}) -> +passive(State=#state{transport={data_delivery, stream_handlers}}) -> %% Unfortunately we cannot currently cancel read_body. %% But that's OK, we will just stop reading the body %% after the next message. State#state{active=false}; +passive(State=#state{transport={data_delivery, relay}, + socket=Stream={Pid, _}}) -> + Pid ! {'$cowboy_relay_command', Stream, passive}, + State#state{active=false}; passive(State=#state{socket=Socket, transport=Transport, messages=Messages}) -> Transport:setopts(Socket, [{active, false}]), flush_passive(Socket, Messages), @@ -369,28 +427,41 @@ before_loop(State=#state{hibernate=true}, HandlerState, ParseState) -> before_loop(State, HandlerState, ParseState) -> loop(State, HandlerState, ParseState). --spec loop_timeout(#state{}) -> #state{}. -loop_timeout(State=#state{opts=Opts, timeout_ref=PrevRef}) -> +-spec set_idle_timeout(#state{}, 0..?IDLE_TIMEOUT_TICKS) -> #state{}. + +%% @todo Do we really need this for HTTP/2? +set_idle_timeout(State=#state{opts=Opts, timeout_ref=PrevRef}, TimeoutNum) -> + %% Most of the time we don't need to cancel the timer since it + %% will have triggered already. But this call is harmless so + %% it is kept to simplify the code as we do need to cancel when + %% options are changed dynamically. _ = case PrevRef of undefined -> ignore; - PrevRef -> erlang:cancel_timer(PrevRef) + PrevRef -> erlang:cancel_timer(PrevRef, [{async, true}, {info, false}]) end, case maps:get(idle_timeout, Opts, 60000) of infinity -> - State#state{timeout_ref=undefined}; + State#state{timeout_ref=undefined, timeout_num=TimeoutNum}; Timeout -> - TRef = erlang:start_timer(Timeout, self(), ?MODULE), - State#state{timeout_ref=TRef} + TRef = erlang:start_timer(Timeout div ?IDLE_TIMEOUT_TICKS, self(), ?MODULE), + State#state{timeout_ref=TRef, timeout_num=TimeoutNum} end. +-define(reset_idle_timeout(State), State#state{timeout_num=0}). + +tick_idle_timeout(State=#state{timeout_num=?IDLE_TIMEOUT_TICKS}, HandlerState, _) -> + websocket_close(State, HandlerState, timeout); +tick_idle_timeout(State=#state{timeout_num=TimeoutNum}, HandlerState, ParseState) -> + before_loop(set_idle_timeout(State, TimeoutNum + 1), HandlerState, ParseState). + -spec loop(#state{}, any(), parse_state()) -> no_return(). loop(State=#state{parent=Parent, socket=Socket, messages=Messages, timeout_ref=TRef}, HandlerState, ParseState) -> receive %% Socket messages. (HTTP/1.1) {OK, Socket, Data} when OK =:= element(1, Messages) -> - State2 = loop_timeout(State), - parse(State2, HandlerState, ParseState, Data); + State1 = maybe_resize_buffer(State, Data), + parse(?reset_idle_timeout(State1), HandlerState, ParseState, Data); {Closed, Socket} when Closed =:= element(2, Messages) -> terminate(State, HandlerState, {error, closed}); {Error, Socket, Reason} when Error =:= element(3, Messages) -> @@ -403,18 +474,20 @@ loop(State=#state{parent=Parent, socket=Socket, messages=Messages, %% Body reading messages. (HTTP/2) {request_body, _Ref, nofin, Data} -> maybe_read_body(State), - State2 = loop_timeout(State), - parse(State2, HandlerState, ParseState, Data); + parse(?reset_idle_timeout(State), HandlerState, ParseState, Data); %% @todo We need to handle this case as if it was an {error, closed} %% but not before we finish processing frames. We probably should have %% a check in before_loop to let us stop looping if a flag is set. {request_body, _Ref, fin, _, Data} -> maybe_read_body(State), - State2 = loop_timeout(State), - parse(State2, HandlerState, ParseState, Data); + parse(?reset_idle_timeout(State), HandlerState, ParseState, Data); + %% @todo It would be better to check StreamID. + %% @todo We must ensure that IsFin=fin is handled like a socket close? + {'$cowboy_relay_data', {Pid, _StreamID}, _IsFin, Data} when Pid =:= Parent -> + parse(?reset_idle_timeout(State), HandlerState, ParseState, Data); %% Timeouts. {timeout, TRef, ?MODULE} -> - websocket_close(State, HandlerState, timeout); + tick_idle_timeout(State, HandlerState, ParseState); {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> before_loop(State, HandlerState, ParseState); %% System messages. @@ -458,12 +531,16 @@ parse_header(State=#state{opts=Opts, frag_state=FragState, extensions=Extensions websocket_close(State, HandlerState, {error, badframe}) end. -parse_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions}, +parse_payload(State=#state{opts=Opts, frag_state=FragState, utf8_state=Incomplete, extensions=Extensions}, HandlerState, ParseState=#ps_payload{ type=Type, len=Len, mask_key=MaskKey, rsv=Rsv, unmasked=Unmasked, unmasked_len=UnmaskedLen}, Data) -> + MaxFrameSize = case maps:get(max_frame_size, Opts, infinity) of + infinity -> infinity; + MaxFrameSize0 -> MaxFrameSize0 - UnmaskedLen + end, case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, - Type, Len, FragState, Extensions, Rsv) of + Type, Len, FragState, Extensions#{max_inflate_size => MaxFrameSize}, Rsv) of {ok, CloseCode, Payload, Utf8State, Rest} -> dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState, ParseState#ps_payload{unmasked= <<Unmasked/binary, Payload/binary>>, @@ -593,13 +670,16 @@ commands([{active, Active}|Tail], State0=#state{active=Active0}, Data) when is_b commands(Tail, State#state{active=Active}, Data); commands([{deflate, Deflate}|Tail], State, Data) when is_boolean(Deflate) -> commands(Tail, State#state{deflate=Deflate}, Data); -commands([{set_options, SetOpts}|Tail], State0=#state{opts=Opts}, Data) -> - State = case SetOpts of - #{idle_timeout := IdleTimeout} -> - loop_timeout(State0#state{opts=Opts#{idle_timeout => IdleTimeout}}); - _ -> - State0 - end, +commands([{set_options, SetOpts}|Tail], State0, Data) -> + State = maps:fold(fun + (idle_timeout, IdleTimeout, StateF=#state{opts=Opts}) -> + %% We reset the number of ticks when changing the idle_timeout option. + set_idle_timeout(StateF#state{opts=Opts#{idle_timeout => IdleTimeout}}, 0); + (max_frame_size, MaxFrameSize, StateF=#state{opts=Opts}) -> + StateF#state{opts=Opts#{max_frame_size => MaxFrameSize}}; + (_, _, StateF) -> + StateF + end, State0, SetOpts), commands(Tail, State, Data); commands([{shutdown_reason, ShutdownReason}|Tail], State, Data) -> commands(Tail, State#state{shutdown_reason=ShutdownReason}, Data); @@ -613,9 +693,14 @@ commands([Frame|Tail], State, Data0) -> commands(Tail, State, Data) end. -transport_send(#state{socket=Stream={Pid, _}, transport=undefined}, IsFin, Data) -> +transport_send(#state{transport={data_delivery, stream_handlers}, + socket=Stream={Pid, _}}, IsFin, Data) -> Pid ! {Stream, {data, IsFin, Data}}, ok; +transport_send(#state{transport={data_delivery, relay}, + socket=Stream={Pid, _}}, IsFin, Data) -> + Pid ! {'$cowboy_relay_command', Stream, {data, IsFin, Data}}, + ok; transport_send(#state{socket=Socket, transport=Transport}, _, Data) -> Transport:send(Socket, Data). |