aboutsummaryrefslogtreecommitdiffstats
path: root/src/cow_ws.erl
diff options
context:
space:
mode:
authorLoïc Hoguin <[email protected]>2015-03-12 16:02:06 +0100
committerLoïc Hoguin <[email protected]>2015-03-12 16:02:06 +0100
commita8db5d9f7a21f36d4582dd083001757513e5990e (patch)
tree525a3cd83162f7833fdca6bc7df815a4c3b7ee4b /src/cow_ws.erl
parent0bd62c8920cd7c0f901207320f07229452cc5681 (diff)
downloadcowlib-a8db5d9f7a21f36d4582dd083001757513e5990e.tar.gz
cowlib-a8db5d9f7a21f36d4582dd083001757513e5990e.tar.bz2
cowlib-a8db5d9f7a21f36d4582dd083001757513e5990e.zip
Add missing client functionality to Websocket code
Diffstat (limited to 'src/cow_ws.erl')
-rw-r--r--src/cow_ws.erl190
1 files changed, 156 insertions, 34 deletions
diff --git a/src/cow_ws.erl b/src/cow_ws.erl
index 0858b68..c89c17a 100644
--- a/src/cow_ws.erl
+++ b/src/cow_ws.erl
@@ -14,33 +14,59 @@
-module(cow_ws).
+-export([key/0]).
+-export([encode_key/1]).
+
-export([negotiate_permessage_deflate/3]).
-export([negotiate_x_webkit_deflate_frame/3]).
+-export([validate_permessage_deflate/3]).
+
-export([parse_header/3]).
-export([parse_payload/9]).
-export([make_frame/4]).
+
-export([frame/2]).
+-export([masked_frame/2]).
-type close_code() :: 1000..1003 | 1006..1011 | 3000..4999.
-export_type([close_code/0]).
+-type extensions() :: map().
+-export_type([extensions/0]).
+
-type frag_state() :: undefined | {fin | nofin, text | binary, rsv()}.
-export_type([frag_state/0]).
-type frame() :: close | ping | pong
| {text | binary | close | ping | pong, iodata()}
| {close, close_code(), iodata()}
- | {fragment, fin | nofin, text | binary, iodata()}.
+ | {fragment, fin | nofin, text | binary | continuation, iodata()}.
-export_type([frame/0]).
--type utf8_state() :: 0..8.
--export_type([utf8_state/0]).
-
--type extensions() :: map().
-type frame_type() :: fragment | text | binary | close | ping | pong.
+-export_type([frame_type/0]).
+
-type mask_key() :: undefined | 0..16#ffffffff.
+-export_type([mask_key/0]).
+
-type rsv() :: <<_:3>>.
+-export_type([rsv/0]).
+
+-type utf8_state() :: 0..8.
+-export_type([utf8_state/0]).
+
+%% @doc Generate a key for the Websocket handshake request.
+
+-spec key() -> binary().
+key() ->
+ base64:encode(crypto:rand_bytes(16)).
+
+%% @doc Encode the key into the accept value for the Websocket handshake response.
+
+-spec encode_key(binary()) -> binary().
+encode_key(Key) ->
+ base64:encode(crypto:hash(sha, [Key, "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"])).
%% @doc Negotiate the permessage-deflate extension.
@@ -54,7 +80,7 @@ negotiate_permessage_deflate(Params, Extensions, Opts) ->
ignore;
Params2 ->
%% @todo Might want to make these configurable defaults.
- case parse_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of
+ case parse_request_permessage_deflate_params(Params2, 15, takeover, 15, takeover, []) of
ignore ->
ignore;
{ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver, RespParams} ->
@@ -68,33 +94,33 @@ negotiate_permessage_deflate(Params, Extensions, Opts) ->
end
end.
-parse_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) ->
+parse_request_permessage_deflate_params([], CB, CTO, SB, STO, RespParams) ->
{CB, CTO, SB, STO, RespParams};
-parse_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) ->
- parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
+parse_request_permessage_deflate_params([<<"client_max_window_bits">>|Tail], CB, CTO, SB, STO, RespParams) ->
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
[<<"; ">>, <<"client_max_window_bits=">>, integer_to_binary(CB)|RespParams]);
-parse_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) ->
+parse_request_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO, RespParams) ->
case parse_max_window_bits(Max) of
error ->
ignore;
CB ->
- parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
[<<"; ">>, <<"client_max_window_bits=">>, Max|RespParams])
end;
-parse_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) ->
- parse_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]);
-parse_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) ->
+parse_request_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO, RespParams) ->
+ parse_request_permessage_deflate_params(Tail, CB, no_takeover, SB, STO, [<<"; ">>, <<"client_no_context_takeover">>|RespParams]);
+parse_request_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO, RespParams) ->
case parse_max_window_bits(Max) of
error ->
ignore;
SB ->
- parse_permessage_deflate_params(Tail, CB, CTO, SB, STO,
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, STO,
[<<"; ">>, <<"server_max_window_bits=">>, Max|RespParams])
end;
-parse_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) ->
- parse_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]);
-%% Ignore if unknown parameter; ignore if parameter with invalid value.
-parse_permessage_deflate_params(_, _, _, _, _, _) ->
+parse_request_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _, RespParams) ->
+ parse_request_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover, [<<"; ">>, <<"server_no_context_takeover">>|RespParams]);
+%% Ignore if unknown parameter; ignore if parameter with invalid or missing value.
+parse_request_permessage_deflate_params(_, _, _, _, _, _) ->
ignore.
parse_max_window_bits(<<"8">>) -> 8;
@@ -108,19 +134,19 @@ parse_max_window_bits(<<"15">>) -> 15;
parse_max_window_bits(_) -> error.
% A negative WindowBits value indicates that zlib headers are not used.
-init_permessage_deflate(ClientWindowBits, ServerWindowBits, Opts) ->
+init_permessage_deflate(InflateWindowBits, DeflateWindowBits, Opts) ->
Inflate = zlib:open(),
- ok = zlib:inflateInit(Inflate, -ClientWindowBits),
+ ok = zlib:inflateInit(Inflate, -InflateWindowBits),
Deflate = zlib:open(),
%% @todo Remove this case .. of for OTP 18+ if PR https://github.com/erlang/otp/pull/633 gets merged.
- ServerWindowBits2 = case ServerWindowBits of
+ DeflateWindowBits2 = case DeflateWindowBits of
8 -> 9;
- _ -> ServerWindowBits
+ _ -> DeflateWindowBits
end,
ok = zlib:deflateInit(Deflate,
maps:get(level, Opts, best_compression),
deflated,
- -ServerWindowBits2,
+ -DeflateWindowBits2,
maps:get(mem_level, Opts, 8),
maps:get(strategy, Opts, default)),
{Inflate, Deflate}.
@@ -144,6 +170,51 @@ negotiate_x_webkit_deflate_frame(_Params, Extensions, Opts) ->
inflate => Inflate,
inflate_takeover => takeover}}.
+%% @doc Validate the negotiated permessage-deflate extension.
+
+%% Error when more than one deflate extension was negotiated.
+validate_permessage_deflate(_, #{deflate := _}, _) ->
+ error;
+validate_permessage_deflate(Params, Extensions, Opts) ->
+ case lists:usort(Params) of
+ %% Error if multiple parameters with the same name.
+ Params2 when length(Params) =/= length(Params2) ->
+ error;
+ Params2 ->
+ %% @todo Might want to make some of these configurable defaults if at all possible.
+ case parse_response_permessage_deflate_params(Params2, 15, takeover, 15, takeover) of
+ error ->
+ error;
+ {ClientWindowBits, ClientTakeOver, ServerWindowBits, ServerTakeOver} ->
+ {Inflate, Deflate} = init_permessage_deflate(ServerWindowBits, ClientWindowBits, Opts),
+ {ok, Extensions#{
+ deflate => Deflate,
+ deflate_takeover => ClientTakeOver,
+ inflate => Inflate,
+ inflate_takeover => ServerTakeOver}}
+ end
+ end.
+
+parse_response_permessage_deflate_params([], CB, CTO, SB, STO) ->
+ {CB, CTO, SB, STO};
+parse_response_permessage_deflate_params([{<<"client_max_window_bits">>, Max}|Tail], _, CTO, SB, STO) ->
+ case parse_max_window_bits(Max) of
+ error -> error;
+ CB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO)
+ end;
+parse_response_permessage_deflate_params([<<"client_no_context_takeover">>|Tail], CB, _, SB, STO) ->
+ parse_response_permessage_deflate_params(Tail, CB, no_takeover, SB, STO);
+parse_response_permessage_deflate_params([{<<"server_max_window_bits">>, Max}|Tail], CB, CTO, _, STO) ->
+ case parse_max_window_bits(Max) of
+ error -> error;
+ SB -> parse_response_permessage_deflate_params(Tail, CB, CTO, SB, STO)
+ end;
+parse_response_permessage_deflate_params([<<"server_no_context_takeover">>|Tail], CB, CTO, SB, _) ->
+ parse_response_permessage_deflate_params(Tail, CB, CTO, SB, no_takeover);
+%% Error if unknown parameter; error if parameter with invalid or missing value.
+parse_response_permessage_deflate_params(_, _, _, _, _) ->
+ error.
+
%% @doc Parse and validate the Websocket frame header.
%%
%% This function also updates the fragmentation state according to
@@ -244,6 +315,7 @@ frag_state(_, 1, _, FragState) -> FragState.
%% Empty last frame of compressed message.
parse_payload(Data, _, Utf8State, _, _, 0, {fin, _, << 1:1, 0:2 >>},
#{inflate := Inflate, inflate_takeover := TakeOver}, _) ->
+ zlib:inflate(Inflate, << 0, 0, 255, 255 >>),
case TakeOver of
no_takeover -> zlib:inflateReset(Inflate);
takeover -> ok
@@ -307,34 +379,35 @@ validate_close_code(Code) ->
true -> ok
end.
+unmask(Data, undefined, _) ->
+ Data;
unmask(Data, MaskKey, 0) ->
- do_unmask(Data, MaskKey, <<>>);
+ mask(Data, MaskKey, <<>>);
%% We unmask on the fly so we need to continue from the right mask byte.
unmask(Data, MaskKey, UnmaskedLen) ->
Left = UnmaskedLen rem 4,
Right = 4 - Left,
MaskKey2 = (MaskKey bsl (Left * 8)) + (MaskKey bsr (Right * 8)),
- do_unmask(Data, MaskKey2, <<>>).
+ mask(Data, MaskKey2, <<>>).
-do_unmask(<<>>, _, Unmasked) ->
+mask(<<>>, _, Unmasked) ->
Unmasked;
-do_unmask(<< O:32, Rest/bits >>, MaskKey, Acc) ->
+mask(<< O:32, Rest/bits >>, MaskKey, Acc) ->
T = O bxor MaskKey,
- do_unmask(Rest, MaskKey, << Acc/binary, T:32 >>);
-do_unmask(<< O:24 >>, MaskKey, Acc) ->
+ mask(Rest, MaskKey, << Acc/binary, T:32 >>);
+mask(<< O:24 >>, MaskKey, Acc) ->
<< MaskKey2:24, _:8 >> = << MaskKey:32 >>,
T = O bxor MaskKey2,
<< Acc/binary, T:24 >>;
-do_unmask(<< O:16 >>, MaskKey, Acc) ->
+mask(<< O:16 >>, MaskKey, Acc) ->
<< MaskKey2:16, _:16 >> = << MaskKey:32 >>,
T = O bxor MaskKey2,
<< Acc/binary, T:16 >>;
-do_unmask(<< O:8 >>, MaskKey, Acc) ->
+mask(<< O:8 >>, MaskKey, Acc) ->
<< MaskKey2:8, _:24 >> = << MaskKey:32 >>,
T = O bxor MaskKey2,
<< Acc/binary, T:8 >>.
-%% @todo Try using iodata() and see if it improves anything.
inflate_frame(Data, Inflate, TakeOver, FragState, true)
when FragState =:= undefined; element(1, FragState) =:= fin ->
Data2 = zlib:inflate(Inflate, << Data/binary, 0, 0, 255, 255 >>),
@@ -416,7 +489,6 @@ make_frame(pong, <<>>, _, _) -> pong;
make_frame(pong, Payload, _, _) -> {pong, Payload}.
%% @doc Construct an unmasked Websocket frame.
-%% @todo Add fragments support.
-spec frame(frame(), extensions()) -> iodata().
%% Control frames. Control packets must not be > 125 in length.
@@ -457,6 +529,56 @@ frame({binary, Payload}, _) ->
Len = payload_length(Payload),
[<< 1:1, 0:3, 2:4, 0:1, Len/bits >>, Payload].
+%% @doc Construct a masked Websocket frame.
+%%
+%% We use a mask key of 0 if there is no payload for close, ping and pong frames.
+
+-spec masked_frame(frame(), extensions()) -> iodata().
+%% Control frames. Control packets must not be > 125 in length.
+masked_frame(close, _) ->
+ << 1:1, 0:3, 8:4, 1:1, 0:39 >>;
+masked_frame(ping, _) ->
+ << 1:1, 0:3, 9:4, 1:1, 0:39 >>;
+masked_frame(pong, _) ->
+ << 1:1, 0:3, 10:4, 1:1, 0:39 >>;
+masked_frame({close, Payload}, Extensions) ->
+ frame({close, 1000, Payload}, Extensions);
+masked_frame({close, StatusCode, Payload}, _) ->
+ Len = 2 + iolist_size(Payload),
+ true = Len =< 125,
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+ [<< 1:1, 0:3, 8:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary([<< StatusCode:16 >>, Payload]), MaskKey, <<>>)];
+masked_frame({ping, Payload}, _) ->
+ Len = iolist_size(Payload),
+ true = Len =< 125,
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+ [<< 1:1, 0:3, 9:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
+masked_frame({pong, Payload}, _) ->
+ Len = iolist_size(Payload),
+ true = Len =< 125,
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+ [<< 1:1, 0:3, 10:4, 1:1, Len:7 >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
+%% Data frames, deflate-frame extension.
+masked_frame({text, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) ->
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+ Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>),
+ Len = payload_length(Payload2),
+ [<< 1:1, 1:1, 0:2, 1:4, 1:1, Len/bits >>, MaskKeyBin, Payload2];
+masked_frame({binary, Payload}, #{deflate := Deflate, deflate_takeover := TakeOver}) ->
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+ Payload2 = mask(deflate_frame(Payload, Deflate, TakeOver), MaskKey, <<>>),
+ Len = payload_length(Payload2),
+ [<< 1:1, 1:1, 0:2, 2:4, 1:1, Len/bits >>, MaskKeyBin, Payload2];
+%% Data frames.
+masked_frame({text, Payload}, _) ->
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+ Len = payload_length(Payload),
+ [<< 1:1, 0:3, 1:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)];
+masked_frame({binary, Payload}, _) ->
+ MaskKeyBin = << MaskKey:32 >> = crypto:rand_bytes(4),
+ Len = payload_length(Payload),
+ [<< 1:1, 0:3, 2:4, 1:1, Len/bits >>, MaskKeyBin, mask(iolist_to_binary(Payload), MaskKey, <<>>)].
+
payload_length(Payload) ->
case byte_size(Payload) of
N when N =< 125 -> << N:7 >>;