diff options
author | Loïc Hoguin <[email protected]> | 2018-04-04 17:23:37 +0200 |
---|---|---|
committer | Loïc Hoguin <[email protected]> | 2018-04-04 17:23:37 +0200 |
commit | bbfc1569ccffab060c4c2b402a45119fb1f57495 (patch) | |
tree | a027f5e43ea26e1ceebefba27888c808ac4cd525 /src/cowboy_websocket.erl | |
parent | a7b06f2e138c0c03c2511ed9fe6803fc9ebf3401 (diff) | |
download | cowboy-bbfc1569ccffab060c4c2b402a45119fb1f57495.tar.gz cowboy-bbfc1569ccffab060c4c2b402a45119fb1f57495.tar.bz2 cowboy-bbfc1569ccffab060c4c2b402a45119fb1f57495.zip |
Add initial implementation of Websocket over HTTP/2
Using the current draft:
https://tools.ietf.org/html/draft-ietf-httpbis-h2-websockets-01
Diffstat (limited to 'src/cowboy_websocket.erl')
-rw-r--r-- | src/cowboy_websocket.erl | 188 |
1 files changed, 132 insertions, 56 deletions
diff --git a/src/cowboy_websocket.erl b/src/cowboy_websocket.erl index df2e1a5..992af52 100644 --- a/src/cowboy_websocket.erl +++ b/src/cowboy_websocket.erl @@ -17,6 +17,7 @@ -module(cowboy_websocket). -behaviour(cowboy_sub_protocol). +-export([is_upgrade_request/1]). -export([upgrade/4]). -export([upgrade/5]). -export([takeover/7]). @@ -82,6 +83,25 @@ req = #{} :: map() }). +%% Because the HTTP/1.1 and HTTP/2 handshakes are so different, +%% this function is necessary to figure out whether a request +%% is trying to upgrade to the Websocket protocol. + +-spec is_upgrade_request(cowboy_req:req()) -> boolean(). +is_upgrade_request(#{version := 'HTTP/2', method := <<"CONNECT">>, protocol := Protocol}) -> + <<"websocket">> =:= cowboy_bstr:to_lower(Protocol); +is_upgrade_request(Req=#{version := 'HTTP/1.1', method := <<"GET">>}) -> + ConnTokens = cowboy_req:parse_header(<<"connection">>, Req, []), + case lists:member(<<"upgrade">>, ConnTokens) of + false -> + false; + true -> + UpgradeTokens = cowboy_req:parse_header(<<"upgrade">>, Req), + lists:member(<<"websocket">>, UpgradeTokens) + end; +is_upgrade_request(_) -> + false. + %% Stream process. -spec upgrade(Req, Env, module(), any()) @@ -94,8 +114,7 @@ upgrade(Req, Env, Handler, HandlerState) -> -> {ok, Req, Env} when Req::cowboy_req:req(), Env::cowboy_middleware:env(). %% @todo Immediately crash if a response has already been sent. -%% @todo Error out if HTTP/2. -upgrade(Req0, Env, Handler, HandlerState, Opts) -> +upgrade(Req0=#{version := Version}, Env, Handler, HandlerState, Opts) -> Timeout = maps:get(idle_timeout, Opts, 60000), MaxFrameSize = maps:get(max_frame_size, Opts, infinity), Compress = maps:get(compress, Opts, false), @@ -108,11 +127,15 @@ upgrade(Req0, Env, Handler, HandlerState, Opts) -> try websocket_upgrade(State0, Req0) of {ok, State, Req} -> websocket_handshake(State, Req, HandlerState, Env); - {error, upgrade_required} -> + %% The status code 426 is specific to HTTP/1.1 connections. + {error, upgrade_required} when Version =:= 'HTTP/1.1' -> {ok, cowboy_req:reply(426, #{ <<"connection">> => <<"upgrade">>, <<"upgrade">> => <<"websocket">> - }, Req0), Env} + }, Req0), Env}; + %% Use a generic 400 error for HTTP/2. + {error, upgrade_required} -> + {ok, cowboy_req:reply(400, Req0), Env} catch _:_ -> %% @todo Probably log something here? %% @todo Test that we can have 2 /ws 400 status code in a row on the same connection. @@ -120,27 +143,27 @@ upgrade(Req0, Env, Handler, HandlerState, Opts) -> {ok, cowboy_req:reply(400, Req0), Env} end. -websocket_upgrade(State, Req) -> - ConnTokens = cowboy_req:parse_header(<<"connection">>, Req, []), - case lists:member(<<"upgrade">>, ConnTokens) of +websocket_upgrade(State, Req=#{version := Version}) -> + case is_upgrade_request(Req) of false -> {error, upgrade_required}; + true when Version =:= 'HTTP/1.1' -> + Key = cowboy_req:header(<<"sec-websocket-key">>, Req), + false = Key =:= undefined, + websocket_version(State#state{key=Key}, Req); true -> - UpgradeTokens = cowboy_req:parse_header(<<"upgrade">>, Req, []), - case lists:member(<<"websocket">>, UpgradeTokens) of - false -> - {error, upgrade_required}; - true -> - Version = cowboy_req:header(<<"sec-websocket-version">>, Req), - IntVersion = binary_to_integer(Version), - true = (IntVersion =:= 7) orelse (IntVersion =:= 8) - orelse (IntVersion =:= 13), - Key = cowboy_req:header(<<"sec-websocket-key">>, Req), - false = Key =:= undefined, - websocket_extensions(State#state{key=Key}, Req#{websocket_version => IntVersion}) - end + websocket_version(State, Req) end. +websocket_version(State, Req) -> + WsVersion = cowboy_req:parse_header(<<"sec-websocket-version">>, Req), + case WsVersion of + 7 -> ok; + 8 -> ok; + 13 -> ok + end, + websocket_extensions(State, Req#{websocket_version => WsVersion}). + websocket_extensions(State=#state{compress=Compress}, Req) -> %% @todo We want different options for this. For example %% * compress everything auto @@ -159,11 +182,16 @@ websocket_extensions(State, Req, [], []) -> {ok, State, Req}; websocket_extensions(State, Req, [], [<<", ">>|RespHeader]) -> {ok, State, cowboy_req:set_resp_header(<<"sec-websocket-extensions">>, lists:reverse(RespHeader), Req)}; -websocket_extensions(State=#state{extensions=Extensions}, Req=#{pid := Pid}, +%% For HTTP/2 we ARE on the controlling process and do NOT want to update the owner. +websocket_extensions(State=#state{extensions=Extensions}, Req=#{pid := Pid, version := Version}, [{<<"permessage-deflate">>, Params}|Tail], RespHeader) -> %% @todo Make deflate options configurable. - Opts = #{level => best_compression, mem_level => 8, strategy => default}, - try cow_ws:negotiate_permessage_deflate(Params, Extensions, Opts#{owner => Pid}) of + Opts0 = #{level => best_compression, mem_level => 8, strategy => default}, + Opts = case Version of + 'HTTP/1.1' -> Opts0#{owner => Pid}; + _ -> Opts0 + end, + try cow_ws:negotiate_permessage_deflate(Params, Extensions, Opts) of {ok, RespExt, Extensions2} -> websocket_extensions(State#state{extensions=Extensions2}, Req, Tail, [<<", ">>, RespExt|RespHeader]); @@ -172,11 +200,15 @@ websocket_extensions(State=#state{extensions=Extensions}, Req=#{pid := Pid}, catch exit:{error, incompatible_zlib_version, _} -> websocket_extensions(State, Req, Tail, RespHeader) end; -websocket_extensions(State=#state{extensions=Extensions}, Req=#{pid := Pid}, +websocket_extensions(State=#state{extensions=Extensions}, Req=#{pid := Pid, version := Version}, [{<<"x-webkit-deflate-frame">>, Params}|Tail], RespHeader) -> %% @todo Make deflate options configurable. - Opts = #{level => best_compression, mem_level => 8, strategy => default}, - try cow_ws:negotiate_x_webkit_deflate_frame(Params, Extensions, Opts#{owner => Pid}) of + Opts0 = #{level => best_compression, mem_level => 8, strategy => default}, + Opts = case Version of + 'HTTP/1.1' -> Opts0#{owner => Pid}; + _ -> Opts0 + end, + try cow_ws:negotiate_x_webkit_deflate_frame(Params, Extensions, Opts) of {ok, RespExt, Extensions2} -> websocket_extensions(State#state{extensions=Extensions2}, Req, Tail, [<<", ">>, RespExt|RespHeader]); @@ -192,7 +224,8 @@ websocket_extensions(State, Req, [_|Tail], RespHeader) -> -> {ok, Req, Env} when Req::cowboy_req:req(), Env::cowboy_middleware:env(). websocket_handshake(State=#state{key=Key}, - Req=#{pid := Pid, streamid := StreamID}, HandlerState, Env) -> + Req=#{version := 'HTTP/1.1', pid := Pid, streamid := StreamID}, + HandlerState, Env) -> Challenge = base64:encode(crypto:hash(sha, << Key/binary, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" >>)), %% @todo We don't want date and server headers. @@ -202,7 +235,17 @@ websocket_handshake(State=#state{key=Key}, <<"sec-websocket-accept">> => Challenge }, Req), Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {State, HandlerState}}}, - {ok, Req, Env}. + {ok, Req, Env}; +%% For HTTP/2 we do not let the process die, we instead keep it +%% for the Websocket stream. This is because in HTTP/2 we only +%% have a stream, it doesn't take over the whole connection. +websocket_handshake(State, Req=#{ref := Ref, pid := Pid, streamid := StreamID}, + HandlerState, _Env) -> + %% @todo We don't want date and server headers. + Headers = cowboy_req:response_headers(#{}, Req), + Pid ! {{Pid, StreamID}, {switch_protocol, Headers, ?MODULE, {State, HandlerState}}}, + takeover(Pid, Ref, {Pid, StreamID}, undefined, undefined, <<>>, + {State, HandlerState}). %% Connection process. @@ -223,21 +266,34 @@ websocket_handshake(State=#state{key=Key}, -type parse_state() :: #ps_header{} | #ps_payload{}. --spec takeover(pid(), ranch:ref(), inet:socket(), module(), any(), binary(), +-spec takeover(pid(), ranch:ref(), inet:socket() | {pid(), cowboy_stream:streamid()}, + module() | undefined, any(), binary(), {#state{}, any()}) -> no_return(). takeover(Parent, Ref, Socket, Transport, _Opts, Buffer, {State0=#state{handler=Handler}, HandlerState}) -> %% @todo We should have an option to disable this behavior. ranch:remove_connection(Ref), + Messages = case Transport of + undefined -> undefined; + _ -> Transport:messages() + end, State = loop_timeout(State0#state{parent=Parent, ref=Ref, socket=Socket, transport=Transport, - key=undefined, messages=Transport:messages()}), + key=undefined, messages=Messages}), case erlang:function_exported(Handler, websocket_init, 1) of true -> handler_call(State, HandlerState, #ps_header{buffer=Buffer}, websocket_init, undefined, fun before_loop/3); false -> before_loop(State, HandlerState, #ps_header{buffer=Buffer}) end. +%% @todo We probably shouldn't do the setopts if we have not received a socket message. +%% @todo We need to hibernate when HTTP/2 is used too. +before_loop(State=#state{socket=Stream={Pid, _}, transport=undefined}, + HandlerState, ParseState) -> + %% @todo Keep Ref around. + ReadBodyRef = make_ref(), + Pid ! {Stream, {read_body, ReadBodyRef, auto, infinity}}, + loop(State, HandlerState, ParseState); before_loop(State=#state{socket=Socket, transport=Transport, hibernate=true}, HandlerState, ParseState) -> Transport:setopts(Socket, [{active, once}]), @@ -258,19 +314,32 @@ loop_timeout(State=#state{timeout=Timeout, timeout_ref=PrevRef}) -> State#state{timeout_ref=TRef}. -spec loop(#state{}, any(), parse_state()) -> no_return(). -loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error}, +loop(State=#state{parent=Parent, socket=Socket, messages=Messages, timeout_ref=TRef}, HandlerState, ParseState) -> receive - {OK, Socket, Data} -> + %% Socket messages. (HTTP/1.1) + {OK, Socket, Data} when OK =:= element(1, Messages) -> State2 = loop_timeout(State), parse(State2, HandlerState, ParseState, Data); - {Closed, Socket} -> + {Closed, Socket} when Closed =:= element(2, Messages) -> terminate(State, HandlerState, {error, closed}); - {Error, Socket, Reason} -> + {Error, Socket, Reason} when Error =:= element(3, Messages) -> terminate(State, HandlerState, {error, Reason}); + %% Body reading messages. (HTTP/2) + {request_body, _Ref, nofin, Data} -> + State2 = loop_timeout(State), + parse(State2, HandlerState, ParseState, Data); + %% @todo We need to handle this case as if it was an {error, closed} + %% but not before we finish processing frames. We probably should have + %% a check in before_loop to let us stop looping if a flag is set. + {request_body, _Ref, fin, _, Data} -> + State2 = loop_timeout(State), + parse(State2, HandlerState, ParseState, Data); + %% Timeouts. {timeout, TRef, ?MODULE} -> websocket_close(State, HandlerState, timeout); {timeout, OlderTRef, ?MODULE} when is_reference(OlderTRef) -> + %% @todo This should call before_loop. loop(State, HandlerState, ParseState); %% System messages. {'EXIT', Parent, Reason} -> @@ -282,6 +351,7 @@ loop(State=#state{parent=Parent, socket=Socket, messages={OK, Closed, Error}, %% Calls from supervisor module. {'$gen_call', From, Call} -> cowboy_children:handle_supervisor_call(Call, From, [], ?MODULE), + %% @todo This should call before_loop. loop(State, HandlerState, ParseState); Message -> handler_call(State, HandlerState, ParseState, @@ -341,8 +411,7 @@ parse_payload(State=#state{frag_state=FragState, utf8_state=Incomplete, extensio websocket_close(State, HandlerState, Error) end. -dispatch_frame(State=#state{socket=Socket, transport=Transport, - max_frame_size=MaxFrameSize, frag_state=FragState, +dispatch_frame(State=#state{max_frame_size=MaxFrameSize, frag_state=FragState, frag_buffer=SoFar, extensions=Extensions}, HandlerState, #ps_payload{type=Type0, unmasked=Payload0, close_code=CloseCode0}, RemainingData) -> @@ -363,12 +432,12 @@ dispatch_frame(State=#state{socket=Socket, transport=Transport, {close, CloseCode, Payload} -> websocket_close(State, HandlerState, {remote, CloseCode, Payload}); Frame = ping -> - Transport:send(Socket, cow_ws:frame(pong, Extensions)), + transport_send(State, nofin, cow_ws:frame(pong, Extensions)), handler_call(State, HandlerState, #ps_header{buffer=RemainingData}, websocket_handle, Frame, fun parse_header/3); Frame = {ping, Payload} -> - Transport:send(Socket, cow_ws:frame({pong, Payload}, Extensions)), + transport_send(State, nofin, cow_ws:frame({pong, Payload}, Extensions)), handler_call(State, HandlerState, #ps_header{buffer=RemainingData}, websocket_handle, Frame, fun parse_header/3); @@ -415,24 +484,32 @@ handler_call(State=#state{handler=Handler}, HandlerState, erlang:raise(Class, Reason, erlang:get_stacktrace()) end. +transport_send(#state{socket=Stream={Pid, _}, transport=undefined}, IsFin, Data) -> + Pid ! {Stream, {data, IsFin, Data}}, + ok; +transport_send(#state{socket=Socket, transport=Transport}, _, Data) -> + Transport:send(Socket, Data). + -spec websocket_send(cow_ws:frame(), #state{}) -> ok | stop | {error, atom()}. websocket_send(Frames, State) when is_list(Frames) -> websocket_send_many(Frames, State, []); -websocket_send(Frame, #state{socket=Socket, transport=Transport, extensions=Extensions}) -> - Res = Transport:send(Socket, cow_ws:frame(Frame, Extensions)), +websocket_send(Frame, State=#state{extensions=Extensions}) -> + Data = cow_ws:frame(Frame, Extensions), case is_close_frame(Frame) of - true -> stop; - false -> Res + true -> + _ = transport_send(State, fin, Data), + stop; + false -> + transport_send(State, nofin, Data) end. -websocket_send_many([], #state{socket=Socket, transport=Transport}, Acc) -> - Transport:send(Socket, lists:reverse(Acc)); -websocket_send_many([Frame|Tail], State=#state{socket=Socket, transport=Transport, - extensions=Extensions}, Acc0) -> +websocket_send_many([], State, Acc) -> + transport_send(State, nofin, lists:reverse(Acc)); +websocket_send_many([Frame|Tail], State=#state{extensions=Extensions}, Acc0) -> Acc = [cow_ws:frame(Frame, Extensions)|Acc0], case is_close_frame(Frame) of true -> - _ = Transport:send(Socket, lists:reverse(Acc)), + _ = transport_send(State, fin, lists:reverse(Acc)), stop; false -> websocket_send_many(Tail, State, Acc) @@ -448,23 +525,22 @@ websocket_close(State, HandlerState, Reason) -> websocket_send_close(State, Reason), terminate(State, HandlerState, Reason). -websocket_send_close(#state{socket=Socket, transport=Transport, - extensions=Extensions}, Reason) -> +websocket_send_close(State=#state{extensions=Extensions}, Reason) -> _ = case Reason of Normal when Normal =:= stop; Normal =:= timeout -> - Transport:send(Socket, cow_ws:frame({close, 1000, <<>>}, Extensions)); + transport_send(State, fin, cow_ws:frame({close, 1000, <<>>}, Extensions)); {error, badframe} -> - Transport:send(Socket, cow_ws:frame({close, 1002, <<>>}, Extensions)); + transport_send(State, fin, cow_ws:frame({close, 1002, <<>>}, Extensions)); {error, badencoding} -> - Transport:send(Socket, cow_ws:frame({close, 1007, <<>>}, Extensions)); + transport_send(State, fin, cow_ws:frame({close, 1007, <<>>}, Extensions)); {error, badsize} -> - Transport:send(Socket, cow_ws:frame({close, 1009, <<>>}, Extensions)); + transport_send(State, fin, cow_ws:frame({close, 1009, <<>>}, Extensions)); {crash, _, _} -> - Transport:send(Socket, cow_ws:frame({close, 1011, <<>>}, Extensions)); + transport_send(State, fin, cow_ws:frame({close, 1011, <<>>}, Extensions)); remote -> - Transport:send(Socket, cow_ws:frame(close, Extensions)); + transport_send(State, fin, cow_ws:frame(close, Extensions)); {remote, Code, _} -> - Transport:send(Socket, cow_ws:frame({close, Code, <<>>}, Extensions)) + transport_send(State, fin, cow_ws:frame({close, Code, <<>>}, Extensions)) end, ok. |