diff options
Diffstat (limited to 'src')
-rw-r--r-- | src/cowboy_websocket.erl | 46 |
1 files changed, 42 insertions, 4 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index 067019c..2aa56b9 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -31,7 +31,12 @@ -export([system_terminate/4]). -export([system_code_change/4]). --type call_result(State) :: {ok, State} +-type commands() :: [cow_ws:frame()]. +-export_type([commands/0]). + +-type call_result(State) :: {commands(), State} | {commands(), State, hibernate}. + +-type deprecated_call_result(State) :: {ok, State} | {ok, State, hibernate} | {reply, cow_ws:frame() | [cow_ws:frame()], State} | {reply, cow_ws:frame() | [cow_ws:frame()], State, hibernate} @@ -48,13 +53,13 @@ when Req::cowboy_req:req(). -callback websocket_init(State) - -> call_result(State) when State::any(). + -> call_result(State) | deprecated_call_result(State) when State::any(). -optional_callbacks([websocket_init/1]). -callback websocket_handle(ping | pong | {text | binary | ping | pong, binary()}, State) - -> call_result(State) when State::any(). + -> call_result(State) | deprecated_call_result(State) when State::any(). -callback websocket_info(any(), State) - -> call_result(State) when State::any(). + -> call_result(State) | deprecated_call_result(State) when State::any(). -callback terminate(any(), cowboy_req:req(), any()) -> ok. -optional_callbacks([terminate/3]). @@ -457,6 +462,13 @@ handler_call(State=#state{handler=Handler}, HandlerState, websocket_init -> Handler:websocket_init(HandlerState); _ -> Handler:Callback(Message, HandlerState) end of + {Commands, HandlerState2} when is_list(Commands) -> + handler_call_result(State, + HandlerState2, ParseState, NextState, Commands); + {Commands, HandlerState2, hibernate} when is_list(Commands) -> + handler_call_result(State#state{hibernate=true}, + HandlerState2, ParseState, NextState, Commands); + %% The following call results are deprecated. {ok, HandlerState2} -> NextState(State, HandlerState2, ParseState); {ok, HandlerState2, hibernate} -> @@ -488,6 +500,32 @@ handler_call(State=#state{handler=Handler}, HandlerState, erlang:raise(Class, Reason, erlang:get_stacktrace()) end. +-spec handler_call_result(#state{}, any(), parse_state(), fun(), commands()) -> no_return(). +handler_call_result(State0, HandlerState, ParseState, NextState, Commands) -> + case commands(Commands, State0, []) of + {ok, State} -> + NextState(State, HandlerState, ParseState); + {stop, State} -> + terminate(State, HandlerState, stop); + {Error = {error, _}, State} -> + terminate(State, HandlerState, Error) + end. + +commands([], State, []) -> + {ok, State}; +commands([], State, Data) -> + Result = transport_send(State, nofin, lists:reverse(Data)), + {Result, State}; +commands([Frame|Tail], State=#state{extensions=Extensions}, Data0) -> + Data = [cow_ws:frame(Frame, Extensions)|Data0], + case is_close_frame(Frame) of + true -> + _ = transport_send(State, fin, lists:reverse(Data)), + {stop, State}; + false -> + commands(Tail, State, Data) + end. + transport_send(#state{socket=Stream={Pid, _}, transport=undefined}, IsFin, Data) -> Pid ! {Stream, {data, IsFin, Data}}, ok; |