From 0a4b14ed1ab626c9eb70f88398780b735fbfad45 Mon Sep 17 00:00:00 2001 From: Ujin Date: Mon, 27 Jun 2022 05:17:57 +0300 Subject: [PATCH 1/3] permessage-deflate support --- src/WebSockets.jl | 57 +++++++++++++++++++++++++++++++++++------------ 1 file changed, 43 insertions(+), 14 deletions(-) diff --git a/src/WebSockets.jl b/src/WebSockets.jl index 14342f49d..c830d6100 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -1,6 +1,6 @@ module WebSockets -using Base64, LoggingExtras, UUIDs, Sockets, Random +using Base64, LoggingExtras, UUIDs, Sockets, Random, CodecZlib using MbedTLS: digest, MD_SHA1, SSLContext using ..IOExtras, ..Streams, ..ConnectionPool, ..Messages, ..Conditions, ..Servers import ..open @@ -55,7 +55,7 @@ FrameFlags(final::Bool, opcode::OpCode, masked::Bool, len::Integer; rsv1::Bool=f ) Base.show(io::IO, x::FrameFlags) = - print(io, "FrameFlags(", "final=", x.final, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")") + print(io, "FrameFlags(", "final=", x.final, ", isdeflate=", x.rsv1, ", ", "opcode=", x.opcode, ", ", "masked=", x.masked, ", ", "len=", x.len, ")") primitive type Mask 32 end Base.UInt32(x::Mask) = Base.bitcast(UInt32, x) @@ -91,6 +91,27 @@ function mask!(bytes::Vector{UInt8}, mask) return end +function compress(data::T) where T <: AbstractVector{UInt8} + compressed = transcode(DeflateCompressor, data) + return vcat(compressed, 0x00) +end + +function compress(data::String) + compressed = transcode(DeflateCompressor, data) + return String(vcat(compressed, 0x00)) +end + +function decompress(data::T) where T <: AbstractVector{UInt8} + decompressed = transcode(DeflateDecompressor, vcat(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return decompressed +end + +function decompress(data::String) + decompressed = transcode(DeflateDecompressor, vcat(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return String(decompressed) +end + + # send method Frame constructor function Frame(final::Bool, opcode::OpCode, client::Bool, payload::AbstractVector{UInt8}; rsv1::Bool=false, rsv2::Bool=false, rsv3::Bool=false) len, extlen = wslength(length(payload)) @@ -293,12 +314,13 @@ mutable struct WebSocket writebuffer::Vector{UInt8} readclosed::Bool writeclosed::Bool + isdeflate::Bool end const DEFAULT_MAX_FRAG = 1024 -WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG) = - WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false) +WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate::Bool=false) = + WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate) """ WebSockets.isclosed(ws) -> Bool @@ -347,7 +369,7 @@ WebSockets.open(url) do ws end ``` """ -function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) +function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, headers=[], maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...) key = base64encode(rand(Random.RandomDevice(), UInt8, 16)) headers = [ "Upgrade" => "websocket", @@ -363,13 +385,14 @@ function open(f::Function, url; suppress_close_error::Bool=false, verbose=false, if header(http, "Sec-WebSocket-Accept") != hashedkey(key) throw(WebSocketError("Invalid Sec-WebSocket-Accept\n" * "$(http.message)")) end + isdeflate = occursin("permessage-deflate", header(http, "Sec-Websocket-Extensions")) # later stream logic checks to see if the HTTP message is "complete" # by seeing if ntoread is 0, which is typemax(Int) for websockets by default # so set it to 0 so it's correctly viewed as "complete" once we're done # doing websocket things http.ntoread = 0 io = http.stream - ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation) + ws = WebSocket(io, http.message.request, http.message; maxframesize, maxfragmentation, isdeflate) @debugv 2 "$(ws.id): WebSocket opened" try f(ws) @@ -416,7 +439,8 @@ function listen end listen(f, args...; kw...) = Servers.listen(http -> upgrade(f, http; kw...), args...; kw...) listen!(f, args...; kw...) = Servers.listen!(http -> upgrade(f, http; kw...), args...; kw...) -function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, kw...) +function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=false, maxframesize::Integer=typemax(Int), + maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate=false, kw...) @debugv 2 "Server websocket upgrade requested" isupgrade(http.message) || handshakeerror() if !hasheader(http, "Sec-WebSocket-Version", "13") @@ -430,10 +454,11 @@ function upgrade(f::Function, http::Streams.Stream; suppress_close_error::Bool=f setheader(http, "Connection" => "Upgrade") key = header(http, "Sec-WebSocket-Key") setheader(http, "Sec-WebSocket-Accept" => hashedkey(key)) + isdeflate && setheader(http, "Sec-Websocket-Extensions" => "permessage-deflate; client_no_context_takeover") startwrite(http) io = http.stream req = http.message - ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation) + ws = WebSocket(io, req, req.response; client=false, maxframesize, maxfragmentation, isdeflate) @debugv 2 "$(ws.id): WebSocket upgraded; connection established" try f(ws) @@ -507,7 +532,7 @@ function Sockets.send(ws::WebSocket, x) # so we can appropriately set the FIN bit for the last fragmented frame nextstate = iterate(x, st) while true - n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item))) + n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? ws.isdeflate : false)) first = false nextstate === nothing && break item, st = nextstate @@ -516,7 +541,8 @@ function Sockets.send(ws::WebSocket, x) else # single binary or text frame for message @label write_single_frame - return writeframe(ws.io, Frame(true, opcode(x), ws.client, payload(ws, x))) + pl = ws.isdeflate ? compress(payload(ws, x)) : payload(ws, x) + return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=ws.isdeflate)) end end @@ -603,7 +629,7 @@ end @noinline utf8check(x) = isvalid(x) || throw(WebSocketError(CloseFrameBody(1007, "Invalid UTF-8"))) function checkreadframe!(ws::WebSocket, frame::Frame) - if frame.flags.rsv1 || frame.flags.rsv2 || frame.flags.rsv3 + if frame.flags.rsv2 || frame.flags.rsv3 throw(WebSocketError(CloseFrameBody(1002, "Reserved bits set in control frame"))) end opcode = frame.flags.opcode @@ -624,8 +650,6 @@ function checkreadframe!(ws::WebSocket, frame::Frame) elseif opcode == PONG control_len_check(frame.flags.len) return false - elseif frame.flags.final && frame.flags.opcode == TEXT && frame.payload isa String - utf8check(frame.payload) end return frame.flags.final end @@ -659,7 +683,11 @@ function receive(ws::WebSocket) @debugv 2 "$(ws.id): Received frame: $frame" done = checkreadframe!(ws, frame) # common case of reading single non-control frame - done && return frame.payload + if done + payload = ws.isdeflate ? decompress(frame.payload) : frame.payload + payload isa String && utf8check(payload) + return payload + end opcode = frame.flags.opcode iscontrol(opcode) && return receive(ws) # if we're here, we're reading a fragmented message @@ -674,6 +702,7 @@ function receive(ws::WebSocket) end done && break end + payload = ws.isdeflate ? decompress(payload) : payload payload isa String && utf8check(payload) @debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])" return payload From 544e283e3c9f7d9fa0ff5d48dd215e0efd8bbd2b Mon Sep 17 00:00:00 2001 From: Ujin Date: Sun, 3 Jul 2022 06:18:21 +0300 Subject: [PATCH 2/3] fix array allocation --- src/WebSockets.jl | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/WebSockets.jl b/src/WebSockets.jl index c830d6100..c5b180ed1 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -93,21 +93,23 @@ end function compress(data::T) where T <: AbstractVector{UInt8} compressed = transcode(DeflateCompressor, data) - return vcat(compressed, 0x00) + push!(compressed, 0x00) + return compressed end function compress(data::String) compressed = transcode(DeflateCompressor, data) - return String(vcat(compressed, 0x00)) + push!(compressed, 0x00) + return String(compressed) end function decompress(data::T) where T <: AbstractVector{UInt8} - decompressed = transcode(DeflateDecompressor, vcat(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + decompressed = transcode(DeflateDecompressor, append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) return decompressed end function decompress(data::String) - decompressed = transcode(DeflateDecompressor, vcat(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + decompressed = transcode(DeflateDecompressor, append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) return String(decompressed) end From c3a3b7c79e23f4de61cfdaa311e2f0912d9ce9e4 Mon Sep 17 00:00:00 2001 From: Ujin Date: Fri, 5 Aug 2022 15:19:20 +0300 Subject: [PATCH 3/3] fix codec initializing --- src/WebSockets.jl | 69 ++++++++++++++++++++++++++++------------------- 1 file changed, 42 insertions(+), 27 deletions(-) diff --git a/src/WebSockets.jl b/src/WebSockets.jl index c5b180ed1..7631c9eae 100644 --- a/src/WebSockets.jl +++ b/src/WebSockets.jl @@ -90,27 +90,18 @@ function mask!(bytes::Vector{UInt8}, mask) end return end - -function compress(data::T) where T <: AbstractVector{UInt8} - compressed = transcode(DeflateCompressor, data) - push!(compressed, 0x00) - return compressed -end - -function compress(data::String) - compressed = transcode(DeflateCompressor, data) - push!(compressed, 0x00) - return String(compressed) +function final_deflate_codecs(t::Tuple) + CodecZlib.TranscodingStreams.finalize(t[1]) + CodecZlib.TranscodingStreams.finalize(t[2]) end -function decompress(data::T) where T <: AbstractVector{UInt8} - decompressed = transcode(DeflateDecompressor, append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) - return decompressed -end +function init_deflate_codecs() + codecco = DeflateCompressor() + CodecZlib.TranscodingStreams.initialize(codecco) + codecde = DeflateDecompressor() + CodecZlib.TranscodingStreams.initialize(codecde) -function decompress(data::String) - decompressed = transcode(DeflateDecompressor, append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) - return String(decompressed) + return (codecco, codecde) end @@ -316,13 +307,13 @@ mutable struct WebSocket writebuffer::Vector{UInt8} readclosed::Bool writeclosed::Bool - isdeflate::Bool + deflate::Union{Nothing, Tuple{CodecZlib.CompressorCodec, CodecZlib.DecompressorCodec}} end const DEFAULT_MAX_FRAG = 1024 WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, maxframesize::Integer=typemax(Int), maxfragmentation::Integer=DEFAULT_MAX_FRAG, isdeflate::Bool=false) = - WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate) + WebSocket(uuid4(), io, req, resp, maxframesize, maxfragmentation, client, UInt8[], UInt8[], false, false, isdeflate ? init_deflate_codecs() : nothing) """ WebSockets.isclosed(ws) -> Bool @@ -330,6 +321,7 @@ WebSocket(io::Connection, req=Request(), resp=Response(); client::Bool=true, max Check whether a `WebSocket` has sent and received CLOSE frames. """ isclosed(ws::WebSocket) = ws.readclosed && ws.writeclosed +isdeflate(ws::WebSocket) = !isnothing(ws.deflate) # Handshake "Check whether a HTTP.Request or HTTP.Response is a websocket upgrade request/response" @@ -534,7 +526,7 @@ function Sockets.send(ws::WebSocket, x) # so we can appropriately set the FIN bit for the last fragmented frame nextstate = iterate(x, st) while true - n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? ws.isdeflate : false)) + n += writeframe(ws.io, Frame(nextstate === nothing, first ? opcode(item) : CONTINUATION, ws.client, payload(ws, item); rsv1 = first ? isdeflate(ws) : false)) first = false nextstate === nothing && break item, st = nextstate @@ -543,8 +535,8 @@ function Sockets.send(ws::WebSocket, x) else # single binary or text frame for message @label write_single_frame - pl = ws.isdeflate ? compress(payload(ws, x)) : payload(ws, x) - return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=ws.isdeflate)) + pl = isdeflate(ws) ? compress(ws, payload(ws, x)) : payload(ws, x) + return writeframe(ws.io, Frame(true, opcode(x), ws.client, pl; rsv1=isdeflate(ws))) end end @@ -559,7 +551,7 @@ to when a PING message is received by a websocket connection. function ping(ws::WebSocket, data=UInt8[]) @require !ws.writeclosed @debugv 2 "$(ws.id): sending ping" - return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, data))) + return writeframe(ws.io, Frame(true, PING, ws.client, payload(ws, isdeflate(ws) ? compress(ws, data) : data))) end """ @@ -620,11 +612,34 @@ function Base.close(ws::WebSocket, body::CloseFrameBody=CloseFrameBody(1000, "") @assert ws.readclosed # if we're the server, it's our job to close the underlying socket !ws.client && isopen(ws.io) && close(ws.io) + final_deflate_codecs(ws.deflate) return end # Receiving messages +function compress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8} + compressed = transcode(ws.deflate[1], data) + push!(compressed, 0x00) + return compressed +end + +function compress(ws::WebSocket, data::String) + compressed = transcode(ws.deflate[1], data) + push!(compressed, 0x00) + return String(compressed) +end + +function decompress(ws::WebSocket, data::T) where T <: AbstractVector{UInt8} + decompressed = transcode(ws.deflate[2], append!(data, [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return decompressed +end + +function decompress(ws::WebSocket, data::String) + decompressed = transcode(ws.deflate[2], append!(Vector{UInt8}(data), [0x00, 0x00, 0xff, 0xff, 0x03, 0x00])) + return String(decompressed) +end + # returns whether additional frames should be read # true if fragmented message or a ping/pong frame was handled @noinline control_len_check(len) = len > 125 && throw(WebSocketError(CloseFrameBody(1002, "Invalid length for control frame"))) @@ -644,7 +659,7 @@ function checkreadframe!(ws::WebSocket, frame::Frame) if !ws.writeclosed close(ws) end - throw(WebSocketError(frame.payload)) + throw(WebSocketError(isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload)) elseif opcode == PING control_len_check(frame.flags.len) pong(ws, frame.payload) @@ -686,7 +701,7 @@ function receive(ws::WebSocket) done = checkreadframe!(ws, frame) # common case of reading single non-control frame if done - payload = ws.isdeflate ? decompress(frame.payload) : frame.payload + payload = isdeflate(ws) ? decompress(ws, frame.payload) : frame.payload payload isa String && utf8check(payload) return payload end @@ -704,7 +719,7 @@ function receive(ws::WebSocket) end done && break end - payload = ws.isdeflate ? decompress(payload) : payload + payload = isdeflate(ws) ? decompress(ws, payload) : payload payload isa String && utf8check(payload) @debugv 2 "Read message: $(payload[1:min(1024, sizeof(payload))])" return payload