diff --git a/CHANGES.md b/CHANGES.md index d3650845..48e141a1 100644 --- a/CHANGES.md +++ b/CHANGES.md @@ -62,6 +62,8 @@ Unreleased ([#37](https://github.com/anmonteiro/httpaf/pull/37)). - httpaf-lwt: Close the communication channel after shutting down the client ([#45](https://github.com/anmonteiro/httpaf/pull/45)) +- httpaf: Fix sending streaming error responses; in particular, allow sending + chunk-encoded responses ([#56](https://github.com/anmonteiro/httpaf/pull/56)) httpaf (upstream) 0.6.5 -------------- diff --git a/lib/body.ml b/lib/body.ml index 456f92d5..6d9aff90 100644 --- a/lib/body.ml +++ b/lib/body.ml @@ -142,6 +142,9 @@ let has_pending_output t = Faraday.has_pending_output t.faraday || (Faraday.is_closed t.faraday && t.write_final_if_chunked) +let requires_output t = + not (is_closed t) || has_pending_output t + let close_reader t = Faraday.close t.faraday; execute_read t diff --git a/lib/reqd.ml b/lib/reqd.ml index 7bb33146..a50e42cb 100644 --- a/lib/reqd.ml +++ b/lib/reqd.ml @@ -34,27 +34,12 @@ type error = [ `Bad_request | `Bad_gateway | `Internal_server_error | `Exn of exn ] -module Response_state = struct - type t = - | Waiting of Optional_thunk.t ref - | Complete of Response.t - | Streaming of Response.t * [`write] Body.t - | Upgrade of Response.t * (unit -> unit) -end - module Input_state = struct type t = | Ready | Complete end -module Output_state = struct - type t = - | Consume - | Wait - | Complete -end - type error_handler = ?request:Request.t -> error -> (Headers.t -> [`write] Body.t) -> unit @@ -180,6 +165,26 @@ let respond_with_streaming ?(flush_headers_immediately=false) t response = failwith "httpaf.Reqd.respond_with_streaming: invalid state, currently handling error"; unsafe_respond_with_streaming ~flush_headers_immediately t response +let unsafe_respond_with_upgrade t headers upgrade_handler = + match t.response_state with + | Waiting when_done_waiting -> + let response = Response.create ~headers `Switching_protocols in + Writer.write_response t.writer response; + if t.persistent then + t.persistent <- Response.persistent_connection response; + t.response_state <- Upgrade (response, upgrade_handler); + Body.close_reader t.request_body; + done_waiting when_done_waiting + | Streaming _ | Upgrade _ -> + failwith "httpaf.Reqd.unsafe_respond_with_upgrade: response already started" + | Complete _ -> + failwith "httpaf.Reqd.unsafe_respond_with_upgrade: response already complete" + +let respond_with_upgrade t response upgrade_handler = + if t.error_code <> `Ok then + failwith "httpaf.Reqd.respond_with_streaming: invalid state, currently handling error"; + unsafe_respond_with_upgrade t response upgrade_handler + let report_error t error = t.persistent <- false; Body.close_reader t.request_body; @@ -225,15 +230,7 @@ let error_code t = | `Ok -> None let on_more_output_available t f = - match t.response_state with - | Waiting when_done_waiting -> - if Optional_thunk.is_some !when_done_waiting - then failwith "httpaf.Reqd.on_more_output_available: only one callback can be registered at a time"; - when_done_waiting := Optional_thunk.some f - | Streaming(_, response_body) -> - Body.when_ready_to_write response_body f - | Complete _ -> - failwith "httpaf.Reqd.on_more_output_available: response already complete" + Response_state.on_more_output_available t.response_state f let persistent_connection t = t.persistent @@ -244,16 +241,7 @@ let input_state t : Input_state.t = else Ready ;; -let output_state t : Output_state.t = - match t.response_state with - | Complete _ -> Complete - | Waiting _ -> Wait - | Streaming(_, response_body) -> - if not (Body.is_closed response_body) - || Body.has_pending_output response_body - then Consume - else Complete - | Upgrade _ -> Consume +let output_state t = Response_state.output_state t.response_state let flush_request_body t = let request_body = request_body t in @@ -262,13 +250,5 @@ let flush_request_body t = with exn -> report_exn t exn let flush_response_body t = - match t.response_state with - | Streaming (response, response_body) -> - let request_method = t.request.Request.meth in - let encoding = - match Response.body_length ~request_method response with - | `Fixed _ | `Close_delimited | `Chunked as encoding -> encoding - | `Error _ -> assert false (* XXX(seliopou): This needs to be handled properly *) - in - Body.transfer_to_writer_with_encoding response_body ~encoding t.writer - | _ -> () + let request_method = t.request.Request.meth in + Response_state.flush_response_body t.response_state ~request_method t.writer diff --git a/lib/respd.ml b/lib/respd.ml index cb221fa0..60d2a4a1 100644 --- a/lib/respd.ml +++ b/lib/respd.ml @@ -125,9 +125,7 @@ let output_state { request_body; state; _ } : Output_state.t = * transition the response descriptor to the `Closed` state. *) Consume | state -> - if state = Uninitialized || - not (Body.is_closed request_body) || - Body.has_pending_output request_body + if state = Uninitialized || Body.requires_output request_body then Consume else Complete diff --git a/lib/response_state.ml b/lib/response_state.ml new file mode 100644 index 00000000..02d0d4a5 --- /dev/null +++ b/lib/response_state.ml @@ -0,0 +1,50 @@ +module Output_state = struct + type t = + | Consume + | Wait + | Complete +end + +type t = + | Waiting of Optional_thunk.t ref + | Complete of Response.t + | Streaming of Response.t * [`write] Body.t + | Upgrade of Response.t * (unit -> unit) + +let on_more_output_available t f = + match t with + | Waiting when_done_waiting -> + if Optional_thunk.is_some !when_done_waiting + then failwith "httpaf.Reqd.on_more_output_available: only one callback can be registered at a time"; + when_done_waiting := Optional_thunk.some f + | Streaming(_, response_body) -> + Body.when_ready_to_write response_body f + | Complete _ -> + failwith "httpaf.Reqd.on_more_output_available: response already complete" + | Upgrade _ -> + (* XXX(anmonteiro): Connections that have been upgraded "require output" + * forever, but outside the HTTP layer, meaning they're permanently + * "yielding". We don't register the wakeup callback because it's not going + * to get called. *) + () + +let output_state t : Output_state.t = + match t with + | Complete _ -> Complete + | Waiting _ -> Wait + | Streaming(_, response_body) -> + if Body.requires_output response_body + then Consume + else Complete + | Upgrade _ -> Consume + +let flush_response_body t ~request_method writer = + match t with + | Streaming (response, response_body) -> + let encoding = + match Response.body_length ~request_method response with + | `Fixed _ | `Close_delimited | `Chunked as encoding -> encoding + | `Error _ -> assert false (* XXX(seliopou): This needs to be handled properly *) + in + Body.transfer_to_writer_with_encoding response_body ~encoding writer + | _ -> () diff --git a/lib/server_connection.ml b/lib/server_connection.ml index 50f3128b..550bbb12 100644 --- a/lib/server_connection.ml +++ b/lib/server_connection.ml @@ -44,6 +44,13 @@ type error = type error_handler = ?request:Request.t -> error -> (Headers.t -> [`write] Body.t) -> unit +type error_code = + | No_error + | Error of + { request: Request.t option + ; mutable response_state: Response_state.t + } + type t = { reader : Reader.request ; writer : Writer.t @@ -58,7 +65,7 @@ type t = (* Represents an unrecoverable error that will cause the connection to * shutdown. Holds on to the response body created by the error handler * that might be streaming to the client. *) - ; mutable error_code : [`Ok | `Error of [`write] Body.t ] + ; mutable error_code : error_code } let is_closed t = @@ -93,8 +100,9 @@ let yield_writer t k = else if Optional_thunk.is_some t.wakeup_writer then failwith "yield_writer: only one callback can be registered at a time" else match t.error_code with - | `Ok -> t.wakeup_writer <- Optional_thunk.some k - | `Error body -> Body.when_ready_to_write body k + | No_error -> t.wakeup_writer <- Optional_thunk.some k + | Error { response_state; _ } -> + Response_state.on_more_output_available response_state k ;; let wakeup_writer t = @@ -146,7 +154,7 @@ let create ?(config=Config.default) ?(error_handler=default_error_handler) reque ; request_queue ; wakeup_writer = Optional_thunk.none ; wakeup_reader = Optional_thunk.none - ; error_code = `Ok + ; error_code = No_error } let shutdown_reader t = @@ -187,19 +195,30 @@ let set_error_and_handle ?request t error = shutdown_reader t; let writer = t.writer in match t.error_code with - | `Ok -> - let body = Body.of_faraday (Writer.faraday writer) in - t.error_code <- `Error body; + | No_error -> + (* The (shared) response body buffer can be used in this case because in + * this conditional branch we're not sending a response + * (is_active t == false), and are therefore not making use of that + * buffer. *) + let response_body = Body.create t.response_body_buffer in + t.error_code <- Error { request; response_state = Waiting (ref Optional_thunk.none) }; t.error_handler ?request error (fun headers -> - Writer.write_response writer (Response.create ~headers status); + let response = Response.create ~headers status in + Writer.write_response writer response; + t.error_code <- Error { request; response_state = Streaming(response, response_body) }; wakeup_writer t; - body) - | `Error _ -> - (* This should not happen. Even if we try to read more, the parser does - * not ingest it, and even if someone attempts to feed more bytes to the - * server when we already told them to [`Close], it's not really our - * problem. *) - assert false + response_body) + | Error _ -> + (* When reading, this should be impossible: even if we try to read more, + * the parser does not ingest it, and even if someone attempts to feed + * more bytes to the parser when we already told them to [`Close], that's + * really their own fault. + * + * We do, however, need to handle this case if any other exception is + * reported (we're already handling an error and e.g. the writing channel + * is closed). Just shut down the connection in that case. + *) + shutdown t end let report_exn t exn = @@ -273,11 +292,42 @@ let read t bs ~off ~len = let read_eof t bs ~off ~len = read_with_more t bs ~off ~len Complete +let flush_response_error_body t ?request response_state = + let request_method = match request with + | Some { Request.meth; _ } -> meth + | None -> + (* XXX(anmonteiro): Error responses may not have a request method if they + * are the result of e.g. an EOF error. Assuming that the request method + * is `GET` smells a little because it's exposing implementation details, + * though the only case where it'd matter would be potentially assuming + * the _successful_ response to a CONNECT request and sending one of the + * forbidden headers according to RFC7231ยง4.3.6: + * + * A server MUST NOT send any Transfer-Encoding or Content-Length + * header fields in a 2xx (Successful) response to CONNECT. + * + * If we're running this code, however, we're not responding with a + * successful status code, which makes us compliant to the above. *) + `GET + in + Response_state.flush_response_body response_state ~request_method t.writer + let rec _next_write_operation t = if not (is_active t) then ( - if Reader.is_closed t.reader && t.error_code = `Ok - then shutdown t; - Writer.next t.writer + match t.error_code with + | No_error -> + if Reader.is_closed t.reader + then shutdown t; + Writer.next t.writer + | Error { request; response_state } -> + match Response_state.output_state response_state with + | Wait -> `Yield + | Consume -> + flush_response_error_body t ?request response_state; + Writer.next t.writer + | Complete -> + shutdown_writer t; + Writer.next t.writer ) else ( let reqd = current_reqd_exn t in match Reqd.output_state reqd with diff --git a/lib_test/test_server_connection.ml b/lib_test/test_server_connection.ml index c186b8b1..9289ac11 100644 --- a/lib_test/test_server_connection.ml +++ b/lib_test/test_server_connection.ml @@ -673,9 +673,6 @@ let streaming_error_handler continue_error ?request:_ _error start_response = ;; let test_malformed_request_streaming_error_response () = - let eof_request_string = - "GET / HTTP/1.1\r\nconnection: close\r\nX-Other-Header: EOF_after_this" - in let writer_woken_up = ref false in let continue_error = ref (fun () -> ()) in let error_handler ?request error start_response = @@ -694,8 +691,8 @@ let test_malformed_request_streaming_error_response () = !continue_error (); Alcotest.(check bool) "Writer woken up" true !writer_woken_up; writer_woken_up := false; - write_string t ~msg:"Error response and first output written" - "HTTP/1.1 400 Bad Request\r\n\r\ngot an error\n"; + write_response t + (Response.create `Bad_request ~headers:Headers.empty); Alcotest.check write_operation "Writer is in a yield state" `Yield (next_write_operation t); yield_writer t (fun () -> writer_woken_up := true); @@ -703,6 +700,95 @@ let test_malformed_request_streaming_error_response () = Alcotest.(check bool) "Writer woken up once more input is available" true !writer_woken_up; write_string t ~msg:"Rest of the error response written" "more output"; + writer_closed t; + Alcotest.(check bool) "Connection is shutdown" true (is_closed t); +;; + +let chunked_error_handler continue_error ?request:_ _error start_response = + let resp_body = + start_response (Headers.of_list ["transfer-encoding", "chunked"]) + in + Body.write_string resp_body "chunk 1\n"; + continue_error := (fun () -> + Body.write_string resp_body "chunk 2\n"; + continue_error := (fun () -> + Body.write_string resp_body "chunk 3\n"; + Body.close_writer resp_body)) +;; + +let test_malformed_request_chunked_error_response () = + let writer_woken_up = ref false in + let continue_error = ref (fun () -> ()) in + let error_handler ?request error start_response = + continue_error := (fun () -> + chunked_error_handler continue_error ?request error start_response) + in + let t = create ~error_handler (basic_handler "") in + Alcotest.check write_operation "Writer is in a yield state" + `Yield (next_write_operation t); + yield_writer t (fun () -> writer_woken_up := true); + let c = read_string_eof t eof_request_string in + Alcotest.(check int) "read consumes all input" + (String.length eof_request_string) c; + Alcotest.check read_operation "Error shuts down the reader" + `Close (next_read_operation t); + Alcotest.(check bool) "Writer hasn't woken up yet" false !writer_woken_up; + !continue_error (); + Alcotest.(check bool) "Writer woken up" true !writer_woken_up; + writer_woken_up := false; + write_response t + ~msg:"First chunk written" + ~body:"8\r\nchunk 1\n\r\n" + (Response.create `Bad_request + ~headers:(Headers.of_list ["transfer-encoding", "chunked"])); + Alcotest.check write_operation "Writer is in a yield state" + `Yield (next_write_operation t); + yield_writer t (fun () -> writer_woken_up := true); + !continue_error (); + write_string t + ~msg:"Second chunk" + "8\r\nchunk 2\n\r\n"; + !continue_error (); + write_string t + ~msg:"Second chunk" + "8\r\nchunk 3\n\r\n"; + write_string t + ~msg:"Final chunk written" + "0\r\n\r\n"; + Alcotest.(check bool) "Writer woken up once more input is available" + true !writer_woken_up; + writer_closed t; + Alcotest.(check bool) "Connection is shutdown" true (is_closed t); +;; + +(* This may happen when writing an asynchronous error response on a broken + * pipe. *) +let test_malformed_request_double_report_exn () = + let writer_woken_up = ref false in + let continue_error = ref (fun () -> ()) in + let error_handler ?request error start_response = + continue_error := (fun () -> + streaming_error_handler continue_error ?request error start_response) + in + let t = create ~error_handler (basic_handler "") in + Alcotest.check write_operation "Writer is in a yield state" + `Yield (next_write_operation t); + yield_writer t (fun () -> writer_woken_up := true); + let c = read_string_eof t eof_request_string in + Alcotest.(check int) "read consumes all input" + (String.length eof_request_string) c; + Alcotest.check read_operation "Error shuts down the reader" + `Close (next_read_operation t); + Alcotest.(check bool) "Writer hasn't woken up yet" false !writer_woken_up; + !continue_error (); + Alcotest.(check bool) "Writer woken up" true !writer_woken_up; + writer_woken_up := false; + write_response t + ~body:"got an error\n" + (Response.create `Bad_request ~headers:Headers.empty); + Alcotest.check write_operation "Writer is in a yield state" + `Yield (next_write_operation t); + report_exn t (Failure "broken pipe"); Alcotest.(check bool) "Connection is shutdown" true (is_closed t); ;; @@ -792,6 +878,8 @@ let tests = ; "malformed request", `Quick, test_malformed_request ; "malformed request (async)", `Quick, test_malformed_request_async ; "multiple malformed requests?", `Quick, test_malformed_request_async_multiple_errors + ; "malformed request, chunked error response", `Quick, test_malformed_request_chunked_error_response + ; "malformed request, double report_exn", `Quick, test_malformed_request_double_report_exn ; "malformed request (EOF)", `Quick, test_malformed_request_eof ; "malformed request, streaming response", `Quick, test_malformed_request_streaming_error_response ; "`flush_headers_immediately` with empty body", `Quick, test_immediate_flush_empty_body diff --git a/nix/sources.nix b/nix/sources.nix index d18cb9a7..5fc5fb25 100644 --- a/nix/sources.nix +++ b/nix/sources.nix @@ -2,8 +2,8 @@ let overlays = builtins.fetchTarball { - url = https://github.com/anmonteiro/nix-overlays/archive/878e180.tar.gz; - sha256 = "1zbdnsaxj2xxsy3kz7rc5ziqlmwf6ba9pvs6pwvhf77xgr08blm0"; + url = https://github.com/anmonteiro/nix-overlays/archive/e155620.tar.gz; + sha256 = "04jkwih5waknpqla928kjr5v6xp1q38p6yf04rcdgwj10cghr0ax"; }; in