From 0401e4eaca2c31c8500e1e2d9db68ea6b25fc5e9 Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Tue, 6 Aug 2024 01:21:54 -0700 Subject: [PATCH 1/5] feat: yield the reader if reads not scheduled --- lib/client_connection.ml | 4 +-- lib/parse.ml | 68 +++++++++++++++++++++++++------------ lib/payload.ml | 47 +++++++++++++++++-------- lib/server_connection.ml | 2 +- lib/websocket_connection.ml | 49 ++++++++++++++++++++++---- 5 files changed, 125 insertions(+), 45 deletions(-) diff --git a/lib/client_connection.ml b/lib/client_connection.ml index eb9a6c2f..ffc3e678 100644 --- a/lib/client_connection.ml +++ b/lib/client_connection.ml @@ -121,7 +121,7 @@ let next_read_operation t = (* TODO(anmonteiro): handle this *) assert false (* set_error_and_handle t (`Exn (Failure message)); `Close *) - | (`Read | `Close) as operation -> operation + | (`Read | `Yield | `Close) as operation -> operation let read t bs ~off ~len = match t.state with @@ -152,7 +152,7 @@ let report_exn t exn = let yield_reader t f = match t.state with | Handshake handshake -> Client_handshake.yield_reader handshake f - | Websocket _websocket -> assert false + | Websocket websocket -> Websocket_connection.yield_reader websocket f let yield_writer t f = match t.state with diff --git a/lib/parse.ml b/lib/parse.ml index 3a5016a2..e955cdcf 100644 --- a/lib/parse.ml +++ b/lib/parse.ml @@ -2,7 +2,6 @@ type t = { payload_length: int ; is_fin: bool ; mask: int32 option - ; payload: Payload.t ; opcode: Websocket.Opcode.t } @@ -87,7 +86,7 @@ let parse_headers = >>= fun headers_len -> Unsafe.take headers_len Bigstringaf.sub ;; -let payload_parser t = +let payload_parser t payload = let open Angstrom in let unmask t bs ~src_off = match t.mask with @@ -120,16 +119,16 @@ let payload_parser t = available >>= fun m -> let m' = (min m n) in let n' = n - m' in - schedule_size ~src_off t.payload m' + schedule_size ~src_off payload m' >>= fun () -> read_exact (src_off + m') n' in fun n -> read_exact 0 n in 
read_exact t.payload_length - >>= fun () -> finish t.payload + >>= fun () -> finish payload ;; -let frame ~buf = +let frame = let open Angstrom in parse_headers >>| fun headers -> @@ -137,11 +136,7 @@ let frame ~buf = and is_fin = is_fin headers and opcode = opcode headers and mask = mask headers in - let payload = match payload_length with - | 0 -> Payload.empty - | _ -> Payload.create buf - in - { is_fin; opcode; mask; payload_length; payload } + { is_fin; opcode; mask; payload_length } ;; module Reader = struct @@ -155,25 +150,54 @@ module Reader = struct type 'error t = { parser : unit Angstrom.t ; mutable parse_state : 'error parse_state - ; mutable closed : bool } + ; mutable closed : bool + ; mutable wakeup : Optional_thunk.t + } + + let wakeup t = + let f = t.wakeup in + t.wakeup <- Optional_thunk.none; + Optional_thunk.call_if_some f - let create frame_handler = - let parser = + let create frame_handler ~on_payload_eof = + let rec parser t = let open Angstrom in let buf = Bigstringaf.create 0x1000 in skip_many - (frame ~buf <* commit >>= fun frame -> - let payload = frame.payload in - let { is_fin; opcode; payload_length = len; _ } = frame in - frame_handler ~opcode ~is_fin ~len payload; - payload_parser frame) + (frame <* commit >>= fun frame -> + let { is_fin; opcode; payload_length; _ } = frame in + let payload = + match payload_length with + | 0 -> Payload.create_empty ~on_payload_eof + | _ -> + Payload.create buf + ~when_ready_to_read:(Optional_thunk.some (fun () -> + wakeup (Lazy.force t))) + ~on_payload_eof + in + frame_handler ~opcode ~is_fin ~len:payload_length payload; + payload_parser frame payload) + and t = lazy ( + { parser = parser t + ; parse_state = Done + ; closed = false + ; wakeup = Optional_thunk.none + } + ) in - { parser - ; parse_state = Done - ; closed = false - } + Lazy.force t ;; + let is_closed t = + t.closed + + let on_wakeup t k = + if is_closed t + then failwith "on_wakeup on closed reader" + else if Optional_thunk.is_some 
t.wakeup + then failwith "on_wakeup: only one callback can be registered at a time" + else t.wakeup <- Optional_thunk.some k + let transition t state = match state with | AU.Done(consumed, ()) diff --git a/lib/payload.ml b/lib/payload.ml index dc4e8853..f75317ce 100644 --- a/lib/payload.ml +++ b/lib/payload.ml @@ -38,35 +38,44 @@ module IOVec = Httpun.IOVec { faraday : Faraday.t ; mutable read_scheduled : bool ; mutable on_eof : unit -> unit + ; mutable eof_has_been_called : bool ; mutable on_read : Bigstringaf.t -> off:int -> len:int -> unit + ; when_ready_to_read : Optional_thunk.t + ; on_payload_eof : unit -> unit } let default_on_eof = Sys.opaque_identity (fun () -> ()) let default_on_read = Sys.opaque_identity (fun _ ~off:_ ~len:_ -> ()) - let of_faraday faraday = - { faraday + let create buffer ~when_ready_to_read ~on_payload_eof = + { faraday = Faraday.of_bigstring buffer ; read_scheduled = false + ; eof_has_been_called = false ; on_eof = default_on_eof ; on_read = default_on_read + ; when_ready_to_read + ; on_payload_eof } - let create buffer = - of_faraday (Faraday.of_bigstring buffer) - - let create_empty () = - let t = create Bigstringaf.empty in + let create_empty ~on_payload_eof = + let t = + create + Bigstringaf.empty + ~when_ready_to_read:Optional_thunk.none + ~on_payload_eof + in Faraday.close t.faraday; + t.on_payload_eof (); t - let empty = create_empty () - let is_closed t = Faraday.is_closed t.faraday let unsafe_faraday t = t.faraday + let ready_to_read t = Optional_thunk.call_if_some t.when_ready_to_read + let rec do_execute_read t on_eof on_read = match Faraday.operation t.faraday with | `Yield -> () @@ -74,7 +83,12 @@ module IOVec = Httpun.IOVec t.read_scheduled <- false; t.on_eof <- default_on_eof; t.on_read <- default_on_read; - on_eof () + if not t.eof_has_been_called then begin + t.eof_has_been_called <- true; + t.on_payload_eof (); + on_eof (); + end + (* [Faraday.operation] never returns an empty list of iovecs *) | `Writev [] -> assert 
false | `Writev (iovec::_) -> t.read_scheduled <- false; @@ -96,10 +110,15 @@ module IOVec = Httpun.IOVec t.on_eof <- on_eof; t.on_read <- on_read; end; - do_execute_read t on_eof on_read - - let is_read_scheduled t = t.read_scheduled + do_execute_read t on_eof on_read; + ready_to_read t let close t = Faraday.close t.faraday; - execute_read t + execute_read t; + ready_to_read t + ;; + + let has_pending_output t = Faraday.has_pending_output t.faraday + + let is_read_scheduled t = t.read_scheduled diff --git a/lib/server_connection.ml b/lib/server_connection.ml index 0b46809e..2a09e2ff 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -103,7 +103,7 @@ let read_eof t bs ~off ~len = let yield_reader t f = match t.state with | Handshake handshake -> Server_handshake.yield_reader handshake f - | Websocket _ -> assert false + | Websocket websocket -> Websocket_connection.yield_reader websocket f let next_write_operation t = match t.state with diff --git a/lib/websocket_connection.ml b/lib/websocket_connection.ml index 78bb9041..fdd114cf 100644 --- a/lib/websocket_connection.ml +++ b/lib/websocket_connection.ml @@ -5,10 +5,14 @@ type error = [ `Exn of exn ] type error_handler = Wsd.t -> error -> unit +let default_payload = + Sys.opaque_identity (Payload.create_empty ~on_payload_eof:(fun () -> ())) + type t = { reader : [`Parse of string list * string] Reader.t ; wsd : Wsd.t ; eof : unit -> unit + ; mutable current_payload: Payload.t } type input_handlers = @@ -34,13 +38,33 @@ let default_error_handler wsd (`Exn exn) = Wsd.close wsd ;; +let wakeup_reader t = Reader.wakeup t.reader + let create ~mode ?(error_handler = default_error_handler) websocket_handler = let wsd = Wsd.create ~error_handler mode in let { frame; eof } = websocket_handler wsd in - { reader = Reader.create frame - ; wsd - ; eof - } + let handler t = fun ~opcode ~is_fin ~len payload -> + let t = Lazy.force t in + t.current_payload <- payload; + frame ~opcode ~is_fin ~len payload + in + 
let rec reader = + lazy ( + Reader.create + (fun ~opcode ~is_fin ~len payload -> (handler t ~opcode ~is_fin ~len payload)) + ~on_payload_eof:(fun () -> + let t = Lazy.force t in + t.current_payload <- default_payload; + ) + ) + and t = lazy + { reader = Lazy.force reader + ; wsd + ; eof + ; current_payload = default_payload + } + in + Lazy.force t let shutdown { wsd; _ } = Wsd.close wsd @@ -53,7 +77,17 @@ let next_read_operation t = match Reader.next t.reader with | `Error (`Parse (_, message)) -> set_error_and_handle t (`Exn (Failure message)); `Close - | (`Read | `Close) as operation -> operation + | `Read -> + begin match t.current_payload == default_payload with + | true -> `Read + | false -> + if Payload.is_read_scheduled t.current_payload + then `Read + else begin + `Yield + end + end + | `Close -> `Close let next_write_operation t = Wsd.next t.wsd @@ -83,4 +117,7 @@ let is_closed { wsd; _ } = let report_exn t exn = set_error_and_handle t (`Exn exn) -let yield_reader _t _f = () +let yield_reader t k = + if Reader.is_closed t.reader + then k () + else Reader.on_wakeup t.reader k From a20035109941d402a8cc5f364022333926caf760 Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Wed, 7 Aug 2024 00:13:11 -0700 Subject: [PATCH 2/5] fix tests --- examples/eio/echo_server.ml | 15 +++++-- lib_test/test_httpun_ws.ml | 87 +++++++++++++++++++++++++------------ 2 files changed, 71 insertions(+), 31 deletions(-) diff --git a/examples/eio/echo_server.ml b/examples/eio/echo_server.ml index 98c2eed9..56ec01d5 100644 --- a/examples/eio/echo_server.ml +++ b/examples/eio/echo_server.ml @@ -6,15 +6,22 @@ let connection_handler ~sw : Eio.Net.Sockaddr.stream -> _ Eio.Net.stream_socket let module Status = Httpun.Status in let websocket_handler _client_address wsd = - let frame ~opcode ~is_fin:_ ~len:_ payload = + let frame ~opcode ~is_fin ~len payload = + Format.eprintf "FRAME %a %d %B@." 
Httpun_ws.Websocket.Opcode.pp_hum opcode len is_fin; match (opcode: Httpun_ws.Websocket.Opcode.t) with | #Httpun_ws.Websocket.Opcode.standard_non_control as opcode -> + let rec on_read bs ~off ~len = + Format.eprintf "do it %d %S@." len (Bigstringaf.substring bs ~off ~len); + Httpun_ws.Wsd.schedule wsd bs ~kind:opcode ~off ~len; + Httpun_ws.Payload.schedule_read payload + ~on_eof:ignore + ~on_read + in Httpun_ws.Payload.schedule_read payload ~on_eof:ignore - ~on_read:(fun bs ~off ~len -> - Httpun_ws.Wsd.schedule wsd bs ~kind:opcode ~off ~len) + ~on_read | `Connection_close -> - Httpun_ws.Wsd.close wsd + Httpun_ws.Wsd.close ~code:(`Other 1005) wsd | `Ping -> Httpun_ws.Wsd.send_pong wsd | `Pong diff --git a/lib_test/test_httpun_ws.ml b/lib_test/test_httpun_ws.ml index 87586cd3..a169c9ad 100644 --- a/lib_test/test_httpun_ws.ml +++ b/lib_test/test_httpun_ws.ml @@ -11,11 +11,22 @@ module Websocket = struct module Parser = struct open Httpun_ws__ - let parse_frame serialized_frame = + let parse_frame ~handler serialized_frame = let parser = let open Angstrom in - (Parse.frame ~buf:Bigstringaf.empty) >>= fun frame -> - lift (fun () -> frame) (Parse.payload_parser frame) + Parse.frame >>= fun frame -> + let { Parse.payload_length; _ } = frame in + let payload = + match payload_length with + | 0 -> Payload.create_empty ~on_payload_eof:(fun () -> ()) + | _ -> + Payload.create (Bigstringaf.create 0x100) + ~when_ready_to_read:(Optional_thunk.some (fun () -> ())) + ~on_payload_eof:(fun () -> ()) + in + let payload_parser = Parse.payload_parser frame payload in + handler frame payload; + payload_parser in match Angstrom.parse_string ~consume:All parser serialized_frame with | Ok frame -> frame @@ -35,45 +46,67 @@ module Websocket = struct Faraday.serialize_to_string f let test_parsing_ping_frame () = - let frame = parse_frame "\137\128\000\000\046\216" in - Alcotest.check Testable.opcode "opcode" `Ping frame.opcode; - Alcotest.(check (option int32)) "mask" (Some 11992l) 
frame.mask; - Alcotest.(check int) "payload_length" 0 frame.payload_length + let parsed = ref false in + parse_frame "\137\128\000\000\046\216" ~handler:(fun frame _payload -> + parsed := true; + Alcotest.check Testable.opcode "opcode" `Ping frame.opcode; + Alcotest.(check (option int32)) "mask" (Some 11992l) frame.mask; + Alcotest.(check int) "payload_length" 0 frame.payload_length); + Alcotest.(check bool) "parsed" true !parsed let test_parsing_close_frame () = - let frame = parse_frame "\136\000" in - Alcotest.check Testable.opcode "opcode" `Connection_close frame.opcode; - Alcotest.(check int) "payload_length" 0 frame.payload_length; - Alcotest.(check bool) "is_fin" true frame.is_fin + let parsed = ref false in + parse_frame "\136\000" ~handler:(fun frame _payload -> + parsed := true; + Alcotest.check Testable.opcode "opcode" `Connection_close frame.opcode; + Alcotest.(check int) "payload_length" 0 frame.payload_length; + Alcotest.(check bool) "is_fin" true frame.is_fin); + Alcotest.(check bool) "parsed" true !parsed - let read_payload frame = + let read_payload payload = let rev_payload_chunks = ref [] in - let payload = frame.Parse.payload in Payload.schedule_read payload ~on_eof:ignore ~on_read:(fun bs ~off ~len -> - rev_payload_chunks := Bigstringaf.substring bs ~off ~len :: !rev_payload_chunks + rev_payload_chunks := + Bigstringaf.substring bs ~off ~len :: !rev_payload_chunks ); !rev_payload_chunks let test_parsing_text_frame () = - let frame = parse_frame "\129\139\086\057\046\216\103\011\029\236\099\015\025\224\111\009\036" in - Alcotest.check Testable.opcode "opcode" `Text frame.opcode; - Alcotest.(check (option int32)) "mask" (Some 1446588120l) frame.mask; - Alcotest.(check int) "payload_length" 11 frame.payload_length; - let rev_payload_chunks = read_payload frame in - Alcotest.(check bool) "is_fin" true frame.is_fin; + let parsed = ref false in + let payload = ref None in + parse_frame + 
"\129\139\086\057\046\216\103\011\029\236\099\015\025\224\111\009\036" + ~handler:(fun frame pload -> + parsed := true; + Alcotest.check Testable.opcode "opcode" `Text frame.opcode; + Alcotest.(check (option int32)) "mask" (Some 1446588120l) frame.mask; + Alcotest.(check int) "payload_length" 11 frame.payload_length; + Alcotest.(check bool) "is_fin" true frame.is_fin; + payload := Some pload; + ); + Alcotest.(check bool) "parsed" true !parsed; + let rev_payload_chunks = read_payload (Option.get !payload) in Alcotest.(check (list string)) "payload" ["1234567890\n"] rev_payload_chunks let test_parsing_fin_bit () = - let frame = parse_frame (serialize_frame ~is_fin:false "hello") in - Alcotest.check Testable.opcode "opcode" `Text frame.opcode; - Alcotest.(check bool) "is_fin" false frame.is_fin; - let frame = parse_frame (serialize_frame ~is_fin:true "hello") in - Alcotest.check Testable.opcode "opcode" `Text frame.opcode; - Alcotest.(check bool) "is_fin" true frame.is_fin; - let rev_payload_chunks = read_payload frame in + let parsed = ref false in + parse_frame (serialize_frame ~is_fin:false "hello") ~handler:(fun frame _payload -> + parsed := true; + Alcotest.check Testable.opcode "opcode" `Text frame.opcode; + Alcotest.(check bool) "is_fin" false frame.is_fin); + Alcotest.(check bool) "parsed" true !parsed; + parsed := false; + let payload = ref None in + parse_frame (serialize_frame ~is_fin:true "hello") ~handler:(fun frame pload -> + parsed := true; + Alcotest.check Testable.opcode "opcode" `Text frame.opcode; + Alcotest.(check bool) "is_fin" true frame.is_fin; + payload := Some pload); + Alcotest.(check bool) "parsed" true !parsed; + let rev_payload_chunks = read_payload (Option.get !payload) in Alcotest.(check (list string)) "payload" ["hello"] rev_payload_chunks let test_parsing_multiple_frames () = From 4bf38e770ebb52922099a51149b4876c677cea24 Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Sat, 24 Aug 2024 15:15:14 -0700 Subject: [PATCH 3/5] wip 
--- lib/parse.ml | 30 ++++----- lib/payload.ml | 21 ++++--- lib/websocket_connection.ml | 118 +++++++++++++++++++++++++----------- 3 files changed, 113 insertions(+), 56 deletions(-) diff --git a/lib/parse.ml b/lib/parse.ml index e955cdcf..220b9227 100644 --- a/lib/parse.ml +++ b/lib/parse.ml @@ -1,9 +1,9 @@ type t = -{ payload_length: int - ; is_fin: bool - ; mask: int32 option - ; opcode: Websocket.Opcode.t -} + { payload_length: int + ; is_fin: bool + ; mask: int32 option + ; opcode: Websocket.Opcode.t + } let is_fin headers = let bits = Bigstringaf.unsafe_get headers 0 |> Char.code in @@ -159,23 +159,22 @@ module Reader = struct t.wakeup <- Optional_thunk.none; Optional_thunk.call_if_some f - let create frame_handler ~on_payload_eof = + let create frame_handler = let rec parser t = let open Angstrom in let buf = Bigstringaf.create 0x1000 in skip_many (frame <* commit >>= fun frame -> - let { is_fin; opcode; payload_length; _ } = frame in + let { payload_length; _ } = frame in let payload = match payload_length with - | 0 -> Payload.create_empty ~on_payload_eof + | 0 -> Payload.create_empty () | _ -> Payload.create buf ~when_ready_to_read:(Optional_thunk.some (fun () -> wakeup (Lazy.force t))) - ~on_payload_eof in - frame_handler ~opcode ~is_fin ~len:payload_length payload; + frame_handler frame payload; payload_parser frame payload) and t = lazy ( { parser = parser t @@ -235,12 +234,15 @@ module Reader = struct end; consumed + let force_close t = + t.closed <- true; + ;; + let next t = match t.parse_state with - | Done -> - if t.closed - then `Close - else `Read | Fail failure -> `Error failure + | _ when t.closed -> `Close + | Done -> `Read | Partial _ -> `Read + ;; end diff --git a/lib/payload.ml b/lib/payload.ml index f75317ce..2586bd86 100644 --- a/lib/payload.ml +++ b/lib/payload.ml @@ -41,31 +41,27 @@ module IOVec = Httpun.IOVec ; mutable eof_has_been_called : bool ; mutable on_read : Bigstringaf.t -> off:int -> len:int -> unit ; when_ready_to_read : 
Optional_thunk.t - ; on_payload_eof : unit -> unit } let default_on_eof = Sys.opaque_identity (fun () -> ()) let default_on_read = Sys.opaque_identity (fun _ ~off:_ ~len:_ -> ()) - let create buffer ~when_ready_to_read ~on_payload_eof = + let create buffer ~when_ready_to_read = { faraday = Faraday.of_bigstring buffer ; read_scheduled = false ; eof_has_been_called = false ; on_eof = default_on_eof ; on_read = default_on_read ; when_ready_to_read - ; on_payload_eof } - let create_empty ~on_payload_eof = + let create_empty () = let t = create Bigstringaf.empty ~when_ready_to_read:Optional_thunk.none - ~on_payload_eof in Faraday.close t.faraday; - t.on_payload_eof (); t let is_closed t = @@ -85,7 +81,6 @@ module IOVec = Httpun.IOVec t.on_read <- default_on_read; if not t.eof_has_been_called then begin t.eof_has_been_called <- true; - t.on_payload_eof (); on_eof (); end (* [Faraday.operation] never returns an empty list of iovecs *) @@ -122,3 +117,15 @@ module IOVec = Httpun.IOVec let has_pending_output t = Faraday.has_pending_output t.faraday let is_read_scheduled t = t.read_scheduled + +type input_state = + | Ready + | Wait + | Complete + +let input_state t : input_state = + if is_closed t + then Complete + else if is_read_scheduled t + then Ready + else Wait diff --git a/lib/websocket_connection.ml b/lib/websocket_connection.ml index fdd114cf..497fda46 100644 --- a/lib/websocket_connection.ml +++ b/lib/websocket_connection.ml @@ -5,19 +5,24 @@ type error = [ `Exn of exn ] type error_handler = Wsd.t -> error -> unit -let default_payload = - Sys.opaque_identity (Payload.create_empty ~on_payload_eof:(fun () -> ())) +type frame_handler = + opcode:Websocket.Opcode.t + -> is_fin:bool + -> len:int + -> Payload.t + -> unit type t = { reader : [`Parse of string list * string] Reader.t ; wsd : Wsd.t + ; frame_handler : frame_handler ; eof : unit -> unit - ; mutable current_payload: Payload.t + ; frame_queue: (Parse.t * Payload.t) Queue.t } type input_handlers = - { frame : 
opcode:Websocket.Opcode.t -> is_fin:bool -> len:int -> Payload.t -> unit - ; eof : unit -> unit } + { frame : frame_handler + ; eof : unit -> unit } (* TODO: this should be passed as an argument from the runtime, to allow for * cryptographically secure random number generation. *) @@ -42,52 +47,95 @@ let wakeup_reader t = Reader.wakeup t.reader let create ~mode ?(error_handler = default_error_handler) websocket_handler = let wsd = Wsd.create ~error_handler mode in - let { frame; eof } = websocket_handler wsd in - let handler t = fun ~opcode ~is_fin ~len payload -> - let t = Lazy.force t in - t.current_payload <- payload; - frame ~opcode ~is_fin ~len payload + let { frame = frame_handler; eof } = websocket_handler wsd in + let frame_queue = Queue.create () in + let handler frame payload = + let call_handler = Queue.is_empty frame_queue in + + Queue.push (frame, payload) frame_queue; + if call_handler + then + let { Parse.opcode; is_fin; payload_length; _ } = frame in + frame_handler ~opcode ~is_fin ~len:payload_length payload in - let rec reader = - lazy ( - Reader.create - (fun ~opcode ~is_fin ~len payload -> (handler t ~opcode ~is_fin ~len payload)) - ~on_payload_eof:(fun () -> - let t = Lazy.force t in - t.current_payload <- default_payload; - ) - ) + let rec reader = lazy (Reader.create handler) and t = lazy { reader = Lazy.force reader ; wsd + ; frame_handler ; eof - ; current_payload = default_payload + ; frame_queue } in Lazy.force t -let shutdown { wsd; _ } = - Wsd.close wsd +let shutdown_reader t = + Reader.force_close t.reader; + wakeup_reader t + +let shutdown t = + shutdown_reader t; + Wsd.close t.wsd let set_error_and_handle t error = Wsd.report_error t.wsd error; shutdown t +let advance_frame_queue t = + ignore (Queue.take t.frame_queue); + if not (Queue.is_empty t.frame_queue) + then + let { Parse.opcode; is_fin; payload_length; _ }, payload = Queue.peek t.frame_queue in + t.frame_handler ~opcode ~is_fin ~len:payload_length payload +;; + +let rec 
_next_read_operation t = + begin match Queue.peek t.frame_queue with + | _, payload -> + begin match Payload.input_state payload with + | Wait -> + begin match Reader.next t.reader with + | (`Error _ | `Close) as operation -> operation + | _ -> `Yield + end + | Ready -> Reader.next t.reader + | Complete -> + (* Don't advance the request queue if in an error state. *) + begin match Reader.next t.reader with + | `Error _ as op -> + (* we just don't advance the request queue in the case of a parser + error. *) + op + | `Read as op -> + (* Keep reading when in a "partial" state (`Read). *) + advance_frame_queue t; + op + | `Close -> + advance_frame_queue t; + _next_read_operation t + end + end; + | exception Queue.Empty -> + let next = Reader.next t.reader in + begin match next with + | `Error _ -> + (* Don't tear down the whole connection if we saw an unrecoverable + * parsing error, as we might be in the process of streaming back the + * error response body to the client. *) + shutdown_reader t + | `Close -> () + | _ -> () + end; + next + end + let next_read_operation t = - match Reader.next t.reader with + match _next_read_operation t with | `Error (`Parse (_, message)) -> - set_error_and_handle t (`Exn (Failure message)); `Close - | `Read -> - begin match t.current_payload == default_payload with - | true -> `Read - | false -> - if Payload.is_read_scheduled t.current_payload - then `Read - else begin - `Yield - end - end - | `Close -> `Close + set_error_and_handle t (`Exn (Failure message)); + `Close + | `Read -> `Read + | (`Yield | `Close) as operation -> operation let next_write_operation t = Wsd.next t.wsd From 8e67fb12ba2ac49cfdd8364c8ff6a44f64b8005c Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Sat, 24 Aug 2024 15:18:15 -0700 Subject: [PATCH 4/5] fix tests --- lib_test/test_httpun_ws.ml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/lib_test/test_httpun_ws.ml b/lib_test/test_httpun_ws.ml index a169c9ad..649de1fa 100644 --- 
a/lib_test/test_httpun_ws.ml +++ b/lib_test/test_httpun_ws.ml @@ -18,11 +18,10 @@ module Websocket = struct let { Parse.payload_length; _ } = frame in let payload = match payload_length with - | 0 -> Payload.create_empty ~on_payload_eof:(fun () -> ()) + | 0 -> Payload.create_empty () | _ -> Payload.create (Bigstringaf.create 0x100) ~when_ready_to_read:(Optional_thunk.some (fun () -> ())) - ~on_payload_eof:(fun () -> ()) in let payload_parser = Parse.payload_parser frame payload in handler frame payload; From 6d3865fd43c2a364520dae849482fae2d386c0cc Mon Sep 17 00:00:00 2001 From: Antonio Nuno Monteiro Date: Sat, 24 Aug 2024 15:49:24 -0700 Subject: [PATCH 5/5] fix more --- lib/parse.ml | 29 +++++++++++++++++++++++------ lib/websocket_connection.ml | 22 +++++++++++++++++----- lib_test/test_httpun_ws.ml | 10 ++++++---- 3 files changed, 46 insertions(+), 15 deletions(-) diff --git a/lib/parse.ml b/lib/parse.ml index 220b9227..8b743c8b 100644 --- a/lib/parse.ml +++ b/lib/parse.ml @@ -218,20 +218,37 @@ module Reader = struct t.parse_state <- Partial continue | _ -> assert false - let rec read_with_more t bs ~off ~len more = + let rec _read_with_more t bs ~off ~len more = + let initial = match t.parse_state with Done -> true | _ -> false in let consumed = match t.parse_state with | Fail _ -> 0 + (* Don't feed empty input when we're at a request boundary *) + | Done when len = 0 -> 0 | Done -> start t (AU.parse t.parser); - read_with_more t bs ~off ~len more; + _read_with_more t bs ~off ~len more; | Partial continue -> transition t (continue bs more ~off ~len) in - begin match more with - | Complete -> t.closed <- true; - | Incomplete -> () - end; + (* Special case where the parser just started and was fed a zero-length + * bigstring. Avoid putting the parser in an error state in this scenario. + * If we were already in a `Partial` state, return the error.
*) + if initial && len = 0 then t.parse_state <- Done; + match t.parse_state with + | Done when consumed < len -> + let off = off + consumed + and len = len - consumed in + consumed + _read_with_more t bs ~off ~len more + | _ -> consumed + ;; + + let read_with_more t bs ~off ~len more = + let consumed = _read_with_more t bs ~off ~len more in + (match more with + | Complete -> + t.closed <- true + | Incomplete -> ()); consumed let force_close t = diff --git a/lib/websocket_connection.ml b/lib/websocket_connection.ml index 497fda46..34cc9c69 100644 --- a/lib/websocket_connection.ml +++ b/lib/websocket_connection.ml @@ -140,11 +140,26 @@ let next_read_operation t = let next_write_operation t = Wsd.next t.wsd +let report_exn t exn = + set_error_and_handle t (`Exn exn) + +let read_with_more t bs ~off ~len more = + let consumed = Reader.read_with_more t.reader bs ~off ~len more in + if not (Queue.is_empty t.frame_queue) + then ( + let _, payload = Queue.peek t.frame_queue in + if Payload.has_pending_output payload + then try Payload.execute_read payload + with exn -> report_exn t exn + ); + consumed +;; + let read t bs ~off ~len = - Reader.read_with_more t.reader bs ~off ~len Incomplete + read_with_more t bs ~off ~len Incomplete let read_eof t bs ~off ~len = - let r = Reader.read_with_more t.reader bs ~off ~len Complete in + let r = read_with_more t bs ~off ~len Complete in t.eof (); r @@ -162,9 +177,6 @@ let yield_writer t k = let is_closed { wsd; _ } = Wsd.is_closed wsd -let report_exn t exn = - set_error_and_handle t (`Exn exn) - let yield_reader t k = if Reader.is_closed t.reader then k () diff --git a/lib_test/test_httpun_ws.ml b/lib_test/test_httpun_ws.ml index 649de1fa..a73d468f 100644 --- a/lib_test/test_httpun_ws.ml +++ b/lib_test/test_httpun_ws.ml @@ -115,10 +115,11 @@ module Websocket = struct match opcode with | `Text -> incr frames_parsed; - Payload.schedule_read payload - ~on_eof:ignore - ~on_read:(fun bs ~off ~len -> - Wsd.schedule wsd bs ~kind:`Text ~off 
~len) + let rec on_read bs ~off ~len = + Wsd.schedule wsd ~kind:`Text bs ~off ~len; + Payload.schedule_read payload ~on_eof:ignore ~on_read + in + Payload.schedule_read payload ~on_eof:ignore ~on_read | `Binary | `Continuation | `Connection_close @@ -144,6 +145,7 @@ module Websocket = struct let len = String.length frames in let bs = Bigstringaf.of_string ~off:0 ~len frames in let read = Server_connection.read t bs ~off:0 ~len in + ignore @@ Server_connection.next_read_operation t; Alcotest.(check int) "Reads both frames" len read; Alcotest.(check int) "Both frames parsed and handled" 2 !frames_parsed; ;;