From 21c9c669719b864f6cc091125bc766183b43bd87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= Date: Fri, 23 Mar 2018 16:32:53 +0100 Subject: Merge the two separate receive loops in cowboy_websocket Also rename a bunch of functions to make the code easier to read. --- src/cowboy_websocket.erl | 209 ++++++++++++++++++++++++----------------------- test/sys_SUITE.erl | 3 - 2 files changed, 109 insertions(+), 103 deletions(-) diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 4f8117a..725d7ec 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -20,7 +20,7 @@ -export([upgrade/4]). -export([upgrade/5]). -export([takeover/7]). --export([handler_loop/3]). +-export([loop/3]). -export([system_continue/3]). -export([system_terminate/4]). @@ -202,53 +202,64 @@ websocket_handshake(State=#state{key=Key}, %% Connection process. -%% @todo Keep parent and handle system messages. +-record(ps_header, { + buffer = <<>> :: binary() +}). + +-record(ps_payload, { + type :: cow_ws:frame_type(), + len :: non_neg_integer(), + mask_key :: cow_ws:mask_key(), + rsv :: cow_ws:rsv(), + close_code = undefined :: undefined | cow_ws:close_code(), + unmasked = <<>> :: binary(), + unmasked_len = 0 :: non_neg_integer(), + buffer = <<>> :: binary() +}). + +-type parse_state() :: #ps_header{} | #ps_payload{}. + -spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(), - {#state{}, any()}) -> ok. + {#state{}, any()}) -> no_return(). takeover(Parent, Ref, Socket, Transport, _Opts, Buffer, {State0=#state{handler=Handler}, HandlerState}) -> %% @todo We should have an option to disable this behavior. ranch:remove_connection(Ref), - State1 = handler_loop_timeout(State0#state{parent=Parent, - ref=Ref, socket=Socket, transport=Transport}), - State = State1#state{key=undefined, messages=Transport:messages()}, + State = loop_timeout(State0#state{parent=Parent, + ref=Ref, socket=Socket, transport=Transport, + key=undefined, messages=Transport:messages()}), case erlang:function_exported(Handler, websocket_init, 1) of - true -> handler_call(State, HandlerState, Buffer, websocket_init, undefined, fun handler_before_loop/3); - false -> handler_before_loop(State, HandlerState, Buffer) + true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer}, + websocket_init, undefined, fun before_loop/3); + false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer}) end. --spec handler_before_loop(#state{}, any(), binary()) -%% @todo Yeah not env. - -> {ok, cowboy_middleware:env()}. -handler_before_loop(State=#state{ - socket=Socket, transport=Transport, hibernate=true}, - HandlerState, SoFar) -> +before_loop(State=#state{socket=Socket, transport=Transport, hibernate=true}, + HandlerState, ParseState) -> Transport:setopts(Socket, [{active, once}]), - proc_lib:hibernate(?MODULE, handler_loop, - [State#state{hibernate=false}, HandlerState, SoFar]); -handler_before_loop(State=#state{socket=Socket, transport=Transport}, - HandlerState, SoFar) -> + proc_lib:hibernate(?MODULE, loop, + [State#state{hibernate=false}, HandlerState, ParseState]); +before_loop(State=#state{socket=Socket, transport=Transport}, + HandlerState, ParseState) -> Transport:setopts(Socket, [{active, once}]), - handler_loop(State, HandlerState, SoFar). + loop(State, HandlerState, ParseState). --spec handler_loop_timeout(#state{}) -> #state{}. -handler_loop_timeout(State=#state{timeout=infinity}) -> +-spec loop_timeout(#state{}) -> #state{}. +loop_timeout(State=#state{timeout=infinity}) -> State#state{timeout_ref=undefined}; -handler_loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) -> +loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) -> _ = case PrevRef of undefined -> ignore; PrevRef -> erlang:cancel_timer(PrevRef) end, TRef = erlang:start_timer(Timeout, self(), ?MODULE), State#state{timeout_ref=TRef}. --spec handler_loop(#state{}, any(), binary()) - -> {ok, cowboy_middleware:env()}. -handler_loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error}, - timeout_ref=TRef}, HandlerState, SoFar) -> +-spec loop(#state{}, any(), parse_state()) -> no_return(). +loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error}, + timeout_ref=TRef}, HandlerState, ParseState) -> receive {OK, Socket, Data} -> - State2 = handler_loop_timeout(State), - websocket_data(State2, HandlerState, - << SoFar/binary, Data/binary >>); + State2 = loop_timeout(State), + parse(State2, HandlerState, ParseState, Data); {Closed, Socket} -> terminate(State, HandlerState, {error, closed}); {Error, Socket, Reason} -> @@ -256,124 +267,121 @@ handler_loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Er {timeout, TRef, ?MODULE} -> websocket_close(State, HandlerState, timeout); {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> - handler_loop(State, HandlerState, SoFar); + loop(State, HandlerState, ParseState); %% System messages. {'EXIT', Parent, Reason} -> %% @todo We should exit gracefully. exit(Reason); {system, From, Request} -> sys:handle_system_msg(Request, From, Parent, ?MODULE, [], - {State, HandlerState, SoFar}); + {State, HandlerState, ParseState}); %% Calls from supervisor module. {'$gen_call', From, Call} -> cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE), - handler_loop(State, HandlerState, SoFar); + loop(State, HandlerState, ParseState); Message -> - handler_call(State, HandlerState, - SoFar, websocket_info, Message, fun handler_before_loop/3) + handler_call(State, HandlerState, ParseState, + websocket_info, Message, fun before_loop/3) end. --spec websocket_data(#state{}, any(), binary()) - -> {ok, cowboy_middleware:env()}. -websocket_data(State=#state{frag_state=FragState, extensions=Extensions}, HandlerState, Data) -> +parse(State, HandlerState, PS=#ps_header{buffer=Buffer}, Data) -> + parse_header(State, HandlerState, PS#ps_header{ + buffer= <>}); +parse(State, HandlerState, PS=#ps_payload{buffer=Buffer}, Data) -> + parse_payload(State, HandlerState, PS#ps_payload{buffer= <<>>}, + <>). + +parse_header(State=#state{frag_state=FragState, extensions=Extensions}, HandlerState, + ParseState=#ps_header{buffer=Data}) -> case cow_ws:parse_header(Data, Extensions, FragState) of %% All frames sent from the client to the server are masked. {_, _, _, _, undefined, _} -> websocket_close(State, HandlerState, {error, badframe}); {Type, FragState2, Rsv, Len, MaskKey, Rest} -> - websocket_payload(State#state{frag_state=FragState2}, HandlerState, Type, Len, MaskKey, Rsv, undefined, <<>>, 0, Rest); + parse_payload(State#state{frag_state=FragState2}, HandlerState, + #ps_payload{type=Type, len=Len, mask_key=MaskKey, rsv=Rsv}, Rest); more -> - handler_before_loop(State, HandlerState, Data); + before_loop(State, HandlerState, ParseState); error -> websocket_close(State, HandlerState, {error, badframe}) end. -websocket_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions}, - HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data) -> - case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, Type, Len, FragState, Extensions, Rsv) of - {ok, CloseCode2, Payload, Utf8State, Rest} -> - websocket_dispatch(State#state{utf8_state=Utf8State}, - HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode2, Rest); +parse_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensions=Extensions}, + HandlerState, ParseState=#ps_payload{ + type=Type, len=Len, mask_key=MaskKey, rsv=Rsv, + unmasked=Unmasked, unmasked_len=UnmaskedLen}, Data) -> + case cow_ws:parse_payload(Data, MaskKey, Incomplete, UnmaskedLen, + Type, Len, FragState, Extensions, Rsv) of + {ok, CloseCode, Payload, Utf8State, Rest} -> + dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState, + ParseState#ps_payload{unmasked= <>, + close_code=CloseCode}, Rest); {ok, Payload, Utf8State, Rest} -> - websocket_dispatch(State#state{utf8_state=Utf8State}, - HandlerState, Type, << Unmasked/binary, Payload/binary >>, CloseCode, Rest); - {more, CloseCode2, Payload, Utf8State} -> - websocket_payload_loop(State#state{utf8_state=Utf8State}, - HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode2, - << Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data)); + dispatch_frame(State#state{utf8_state=Utf8State}, HandlerState, + ParseState#ps_payload{unmasked= <>}, + Rest); + {more, CloseCode, Payload, Utf8State} -> + before_loop(State#state{utf8_state=Utf8State}, HandlerState, + ParseState#ps_payload{len=Len - byte_size(Data), close_code=CloseCode, + unmasked= <>, + unmasked_len=UnmaskedLen + byte_size(Data)}); {more, Payload, Utf8State} -> - websocket_payload_loop(State#state{utf8_state=Utf8State}, - HandlerState, Type, Len - byte_size(Data), MaskKey, Rsv, CloseCode, - << Unmasked/binary, Payload/binary >>, UnmaskedLen + byte_size(Data)); + before_loop(State#state{utf8_state=Utf8State}, HandlerState, + ParseState#ps_payload{len=Len - byte_size(Data), + unmasked= <>, + unmasked_len=UnmaskedLen + byte_size(Data)}); Error = {error, _Reason} -> websocket_close(State, HandlerState, Error) end. -websocket_payload_loop(State=#state{socket=Socket, transport=Transport, - messages={OK, Closed, Error}, timeout_ref=TRef}, - HandlerState, Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) -> - Transport:setopts(Socket, [{active, once}]), - receive - {OK, Socket, Data} -> - State2 = handler_loop_timeout(State), - websocket_payload(State2, HandlerState, - Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen, Data); - {Closed, Socket} -> - terminate(State, HandlerState, {error, closed}); - {Error, Socket, Reason} -> - terminate(State, HandlerState, {error, Reason}); - {timeout, TRef, ?MODULE} -> - websocket_close(State, HandlerState, timeout); - {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> - websocket_payload_loop(State, HandlerState, - Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen); - Message -> - handler_call(State, HandlerState, - <<>>, websocket_info, Message, - fun (State2, HandlerState2, _) -> - websocket_payload_loop(State2, HandlerState2, - Type, Len, MaskKey, Rsv, CloseCode, Unmasked, UnmaskedLen) - end) - end. - -websocket_dispatch(State=#state{socket=Socket, transport=Transport, frag_state=FragState, frag_buffer=SoFar, extensions=Extensions}, - HandlerState, Type0, Payload0, CloseCode0, RemainingData) -> +dispatch_frame(State=#state{socket=Socket, transport=Transport, + frag_state=FragState, frag_buffer=SoFar, extensions=Extensions}, + HandlerState, #ps_payload{type=Type0, unmasked=Payload0, close_code=CloseCode0}, + RemainingData) -> case cow_ws:make_frame(Type0, Payload0, CloseCode0, FragState) of %% @todo Allow receiving fragments. {fragment, nofin, _, Payload} -> - websocket_data(State#state{frag_buffer= << SoFar/binary, Payload/binary >>}, HandlerState, RemainingData); + parse_header(State#state{frag_buffer= << SoFar/binary, Payload/binary >>}, + HandlerState, #ps_header{buffer=RemainingData}); {fragment, fin, Type, Payload} -> - handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState, RemainingData, - websocket_handle, {Type, << SoFar/binary, Payload/binary >>}, fun websocket_data/3); + handler_call(State#state{frag_state=undefined, frag_buffer= <<>>}, HandlerState, + #ps_header{buffer=RemainingData}, + websocket_handle, {Type, << SoFar/binary, Payload/binary >>}, + fun parse_header/3); close -> websocket_close(State, HandlerState, remote); {close, CloseCode, Payload} -> websocket_close(State, HandlerState, {remote, CloseCode, Payload}); Frame = ping -> Transport:send(Socket, cow_ws:frame(pong, Extensions)), - handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3); + handler_call(State, HandlerState, + #ps_header{buffer=RemainingData}, + websocket_handle, Frame, fun parse_header/3); Frame = {ping, Payload} -> Transport:send(Socket, cow_ws:frame({pong, Payload}, Extensions)), - handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3); + handler_call(State, HandlerState, + #ps_header{buffer=RemainingData}, + websocket_handle, Frame, fun parse_header/3); Frame -> - handler_call(State, HandlerState, RemainingData, websocket_handle, Frame, fun websocket_data/3) + handler_call(State, HandlerState, + #ps_header{buffer=RemainingData}, + websocket_handle, Frame, fun parse_header/3) end. --spec handler_call(#state{}, any(), binary(), atom(), any(), fun()) -> no_return(). handler_call(State=#state{handler=Handler}, HandlerState, - RemainingData, Callback, Message, NextState) -> + ParseState, Callback, Message, NextState) -> try case Callback of websocket_init -> Handler:websocket_init(HandlerState); _ -> Handler:Callback(Message, HandlerState) end of {ok, HandlerState2} -> - NextState(State, HandlerState2, RemainingData); + NextState(State, HandlerState2, ParseState); {ok, HandlerState2, hibernate} -> - NextState(State#state{hibernate=true}, HandlerState2, RemainingData); + NextState(State#state{hibernate=true}, HandlerState2, ParseState); {reply, Payload, HandlerState2} -> case websocket_send(Payload, State) of ok -> - NextState(State, HandlerState2, RemainingData); + NextState(State, HandlerState2, ParseState); stop -> terminate(State, HandlerState2, stop); Error = {error, _} -> @@ -383,7 +391,7 @@ handler_call(State=#state{handler=Handler}, HandlerState, case websocket_send(Payload, State) of ok -> NextState(State#state{hibernate=true}, - HandlerState2, RemainingData); + HandlerState2, ParseState); stop -> terminate(State, HandlerState2, stop); Error = {error, _} -> @@ -458,15 +466,16 @@ handler_terminate(#state{handler=Handler, req=Req}, HandlerState, Reason) -> %% System callbacks. --spec system_continue(_, _, {#state{}, any(), binary()}) -> ok. -system_continue(_, _, {State, HandlerState, SoFar}) -> - handler_loop(State, HandlerState, SoFar). +-spec system_continue(_, _, {#state{}, any(), parse_state()}) -> no_return(). +system_continue(_, _, {State, HandlerState, ParseState}) -> + loop(State, HandlerState, ParseState). --spec system_terminate(any(), _, _, {#state{}, any(), binary()}) -> no_return(). +-spec system_terminate(any(), _, _, {#state{}, any(), parse_state()}) -> no_return(). system_terminate(Reason, _, _, {State, HandlerState, _}) -> %% @todo We should exit gracefully, if possible. terminate(State, HandlerState, Reason). --spec system_code_change(Misc, _, _, _) -> {ok, Misc} when Misc::{#state{}, any(), binary()}. +-spec system_code_change(Misc, _, _, _) + -> {ok, Misc} when Misc::{#state{}, any(), parse_state()}. system_code_change(Misc, _, _, _) -> {ok, Misc}. diff --git a/test/sys_SUITE.erl b/test/sys_SUITE.erl index 3850796..6a460bf 100644 --- a/test/sys_SUITE.erl +++ b/test/sys_SUITE.erl @@ -112,9 +112,6 @@ proc_lib_initial_call_tls(Config) -> %% so that it doesn't eat up system messages. It should only %% flush messages that are specific to cowboy_http. -%% @todo The cowboy_websocket module needs to have the functions -%% handler_loop and websocket_payload_loop merged into one. - bad_system_from_h1(Config) -> doc("h1: Sending a system message with a bad From value results in a process crash."), {ok, Socket} = gen_tcp:connect("localhost", config(clear_port, Config), [{active, false}]), -- cgit v1.2.3