From a8db5d9f7a21f36d4582dd083001757513e5990e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Lo=C3=AFc=20Hoguin?= Date: Thu, 12 Mar 2015 16:02:06 +0100 Subject: Add missing client functionality to Websocket code --- src/cow_ws.erl | 190 ++++++++++++++++++++++++++++++++++++++++++++++----------- 1 file changed, 156 insertions(+), 34 deletions(-) diff --git a/src/cow_ws.erl b/src/cow_ws.erl index 0858b68..c89c17a 100644 --- a/src/cow_ws.erl +++ b/src/cow_ws.erl @@ -14,33 +14,59 @@ -module(cow_ws). +-export([key/0]). +-export([encode_key/1]). + -export([negotiate_permessage_deflate/3]). -export([negotiate_x_webkit_deflate_frame/3]). +-export([validate_permessage_deflate/3]). + -export([parse_header/3]). -export([parse_payload/9]). -export([make_frame/4]). + -export([frame/2]). +-export([masked_frame/2]). -type close_code() :: 1000..1003 | 1006..1011 | 3000..4999. -export_type([close_code/0]). +-type extensions() :: map(). +-export_type([extensions/0]). + -type frag_state() :: undefined | {fin | nofin, text | binary, rsv()}. -export_type([frag_state/0]). -type frame() :: close | ping | pong | {text | binary | close | ping | pong, iodata()} | {close, close_code(), iodata()} - | {fragment, fin | nofin, text | binary, iodata()}. + | {fragment, fin | nofin, text | binary | continuation, iodata()}. -export_type([frame/0]). --type utf8_state() :: 0..8. --export_type([utf8_state/0]). - --type extensions() :: map(). -type frame_type() :: fragment | text | binary | close | ping | pong. +-export_type([frame_type/0]). + -type mask_key() :: undefined | 0..16#ffffffff. +-export_type([mask_key/0]). + -type rsv() :: <<_:3>>. +-export_type([rsv/0]). + +-type utf8_state() :: 0..8. +-export_type([utf8_state/0]). + +%% @doc Generate a key for the Websocket handshake request. + +-spec key() -> binary(). +key() -> + base64:encode(crypto:rand_bytes(16)). + +%% @doc Encode the key into the accept value for the Websocket handshake response. + +-spec encode_key(binary()) -> binary(). +encode_key(Key) -> + base64:encode(crypto:hash(sha, [Key, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"])). %% @doc Negotiate the permessage-deflate extension. @@ -54,7 +80,7 @@ negotiate_permessage_deflate(Params, Extensions, Opts) -> ignore; Params2 -> %% @todo Might want to make these configurable defaults. - case parse_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of + case parse_request_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of ignore -> ignore; {ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver, RespParams} -> @@ -68,33 +94,33 @@ negotiate_permessage_deflate(Params, Extensions, Opts) -> end end. -parse_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) -> +parse_request_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) -> {CB, CTO, SB, STO, RespParams}; -parse_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) -> - parse_permessage_deflate_params(Tail, CB, CTO, SB, STO, +parse_request_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) -> + parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO, [<<"; ">>, <<"client_max_window_bits=">>, integer_to_binary(CB)|RespParams]); -parse_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) -> +parse_request_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) -> case parse_max_window_bits(Max) of error -> ignore; CB -> - parse_permessage_deflate_params(Tail, CB, CTO, SB, STO, + parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO, [<<"; ">>, <<"client_max_window_bits=">>, Max|RespParams]) end; -parse_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) -> - parse_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]); -parse_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) -> +parse_request_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) -> + parse_request_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]); +parse_request_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) -> case parse_max_window_bits(Max) of error -> ignore; SB -> - parse_permessage_deflate_params(Tail, CB, CTO, SB, STO, + parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO, [<<"; ">>, <<"server_max_window_bits=">>, Max|RespParams]) end; -parse_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) -> - parse_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]); -%% Ignore if unknown parameter; ignore if parameter with invalid value. -parse_permessage_deflate_params(_, _, _, _, _, _) -> +parse_request_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) -> + parse_request_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]); +%% Ignore if unknown parameter; ignore if parameter with invalid or missing value. +parse_request_permessage_deflate_params(_, _, _, _, _, _) -> ignore. parse_max_window_bits(<<"8">>) -> 8; @@ -108,19 +134,19 @@ parse_max_window_bits(<<"15">>) -> 15; parse_max_window_bits(_) -> error. % A negative WindowBits value indicates that zlib headers are not used. -init_permessage_deflate(ClientWindowBits, ServerWindowBits, Opts) -> +init_permessage_deflate(InflateWindowBits, DeflateWindowBits, Opts) -> Inflate = zlib:open(), - ok = zlib:inflateInit(Inflate, -ClientWindowBits), + ok = zlib:inflateInit(Inflate, -InflateWindowBits), Deflate = zlib:open(), %% @todo Remove this case .. of for OTP 18+ if PR https://github.com/erlang/otp/pull/633 gets merged. - ServerWindowBits2 = case ServerWindowBits of + DeflateWindowBits2 = case DeflateWindowBits of 8 -> 9; - _ -> ServerWindowBits + _ -> DeflateWindowBits end, ok = zlib:deflateInit(Deflate, maps:get(level, Opts, best_compression), deflated, - -ServerWindowBits2, + -DeflateWindowBits2, maps:get(mem_level, Opts, 8), maps:get(strategy, Opts, default)), {Inflate, Deflate}. @@ -144,6 +170,51 @@ negotiate_x_webkit_deflate_frame(_Params, Extensions, Opts) -> inflate => Inflate, inflate_takeover => takeover}}. +%% @doc Validate the negotiated permessage-deflate extension. + +%% Error when more than one deflate extension was negotiated. +validate_permessage_deflate(_, #{deflate := _}, _) -> + error; +validate_permessage_deflate(Params, Extensions, Opts) -> + case lists:usort(Params) of + %% Error if multiple parameters with the same name. + Params2 when length(Params) =/= length(Params2) -> + error; + Params2 -> + %% @todo Might want to make some of these configurable defaults if at all possible. + case parse_response_permessage_deflate_params(Params2, 15, takeover, 15, takeover) of + error -> + error; + {ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver} -> + {Inflate, Deflate} = init_permessage_deflate(ServerWindowBits, ClientWindowBits, Opts), + {ok, Extensions#{ + deflate => Deflate, + deflate_takeover => ClientTakeOver, + inflate => Inflate, + inflate_takeover => ServerTakeOver}} + end + end. + +parse_response_permessage_deflate_params([], CB, CTO, SB, STO) -> + {CB, CTO, SB, STO}; +parse_response_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO) -> + case parse_max_window_bits(Max) of + error -> error; + CB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO) + end; +parse_response_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO) -> + parse_response_permessage_deflate_params(Tail, CB, no_takeover, SB, STO); +parse_response_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO) -> + case parse_max_window_bits(Max) of + error -> error; + SB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO) + end; +parse_response_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _) -> + parse_response_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover); +%% Error if unknown parameter; error if parameter with invalid or missing value. +parse_response_permessage_deflate_params(_, _, _, _, _) -> + error. + %% @doc Parse and validate the Websocket frame header. %% %% This function also updates the fragmentation state according to @@ -244,6 +315,7 @@ frag_state(_, 1, _, FragState) -> FragState. %% Empty last frame of compressed message. parse_payload(Data, _, Utf8State, _, _, 0, {fin, _, << 1:1, 0:2 >>}, #{inflate := Inflate, inflate_takeover := TakeOver}, _) -> + zlib:inflate(Inflate, << 0, 0, 255, 255 >>), case TakeOver of no_takeover -> zlib:inflateReset(Inflate); takeover -> ok @@ -307,34 +379,35 @@ validate_close_code(Code) -> true -> ok end. +unmask(Data, undefined, _) -> + Data; unmask(Data, MaskKey, 0) -> - do_unmask(Data, MaskKey, <<>>); + mask(Data, MaskKey, <<>>); %% We unmask on the fly so we need to continue from the right mask byte. unmask(Data, MaskKey, UnmaskedLen) -> Left = UnmaskedLen rem 4, Right = 4 - Left, MaskKey2 = (MaskKey bsl (Left * 8)) + (MaskKey bsr (Right * 8)), - do_unmask(Data, MaskKey2, <<>>). + mask(Data, MaskKey2, <<>>). -do_unmask(<<>>, _, Unmasked) -> +mask(<<>>, _, Unmasked) -> Unmasked; -do_unmask(<< O:32, Rest/bits >>, MaskKey, Acc) -> +mask(<< O:32, Rest/bits >>, MaskKey, Acc) -> T = O bxor MaskKey, - do_unmask(Rest, MaskKey, << Acc/binary, T:32 >>); -do_unmask(<< O:24 >>, MaskKey, Acc) -> + mask(Rest, MaskKey, << Acc/binary, T:32 >>); +mask(<< O:24 >>, MaskKey, Acc) -> << MaskKey2:24, _:8 >> = << MaskKey:32 >>, T = O bxor MaskKey2, << Acc/binary, T:24 >>; -do_unmask(<< O:16 >>, MaskKey, Acc) -> +mask(<< O:16 >>, MaskKey, Acc) -> << MaskKey2:16, _:16 >> = << MaskKey:32 >>, T = O bxor MaskKey2, << Acc/binary, T:16 >>; -do_unmask(<< O:8 >>, MaskKey, Acc) -> +mask(<< O:8 >>, MaskKey, Acc) -> << MaskKey2:8, _:24 >> = << MaskKey:32 >>, T = O bxor MaskKey2, << Acc/binary, T:8 >>. -%% @todo Try using iodata() and see if it improves anything. inflate_frame(Data, Inflate, TakeOver, FragState, true) when FragState =:= undefined; element(1, FragState) =:= fin -> Data2 = zlib:inflate(Inflate, << Data/binary, 0, 0, 255, 255 >>), @@ -416,7 +489,6 @@ make_frame(pong, <<>>, _, _) -> pong; make_frame(pong, Payload, _, _) -> {pong, Payload}. %% @doc Construct an unmasked Websocket frame. -%% @todo Add fragments support. -spec frame(frame(), extensions()) -> iodata(). %% Control frames. Control packets must not be > 125 in length. @@ -457,6 +529,56 @@ frame({binary, Payload}, _) -> Len = payload_length(Payload), [<< 1:1, 0:3, 2:4, 0:1, Len/bits >>, Payload]. +%% @doc Construct a masked Websocket frame. +%% +%% We use a mask key of 0 if there is no payload for close, ping and pong frames. + +-spec masked_frame(frame(), extensions()) -> iodata(). +%% Control frames. Control packets must not be > 125 in length. +masked_frame(close, _) -> + << 1:1, 0:3, 8:4, 1:1, 0:39 >>; +masked_frame(ping, _) -> + << 1:1, 0:3, 9:4, 1:1, 0:39 >>; +masked_frame(pong, _) -> + << 1:1, 0:3, 10:4, 1:1, 0:39 >>; +masked_frame({close, Payload}, Extensions) -> + frame({close, 1000, Payload}, Extensions); +masked_frame({close, StatusCode, Payload}, _) -> + Len = 2 + iolist_size(Payload), + true = Len =< 125, + MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4), + [<< 1:1, 0:3, 8:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary([<< StatusCode:16 >>, Payload]), MaskKey, <<>>)]; +masked_frame({ping, Payload}, _) -> + Len = iolist_size(Payload), + true = Len =< 125, + MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4), + [<< 1:1, 0:3, 9:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)]; +masked_frame({pong, Payload}, _) -> + Len = iolist_size(Payload), + true = Len =< 125, + MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4), + [<< 1:1, 0:3, 10:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)]; +%% Data frames, deflate-frame extension. +masked_frame({text, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) -> + MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4), + Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>), + Len = payload_length(Payload2), + [<< 1:1, 1:1, 0:2, 1:4, 1:1, Len/bits >>, MaskKeyBin, Payload2]; +masked_frame({binary, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) -> + MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4), + Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>), + Len = payload_length(Payload2), + [<< 1:1, 1:1, 0:2, 2:4, 1:1, Len/bits >>, MaskKeyBin, Payload2]; +%% Data frames. +masked_frame({text, Payload}, _) -> + MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4), + Len = payload_length(Payload), + [<< 1:1, 0:3, 1:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)]; +masked_frame({binary, Payload}, _) -> + MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4), + Len = payload_length(Payload), + [<< 1:1, 0:3, 2:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)]. + payload_length(Payload) -> case byte_size(Payload) of N when N =< 125 -> << N:7 >>; -- cgit v1.2.3