Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement shutdown, as required by mirage-flow 4.0.0 #512

Merged
merged 5 commits into from
Mar 26, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions src/stack-unix/tcp_socket.ml
Original file line number Diff line number Diff line change
Expand Up @@ -65,4 +65,12 @@ let close fd =
| Unix.Unix_error (Unix.EBADF, _, _) -> Lwt.return_unit
| e -> Lwt.fail e)

let shutdown fd mode =
let cmd = match mode with
| `read -> Lwt_unix.SHUTDOWN_RECEIVE
| `write -> Lwt_unix.SHUTDOWN_SEND
| `read_write -> Lwt_unix.SHUTDOWN_ALL
in
Lwt.return (Lwt_unix.shutdown fd cmd)

let input _t ~src:_ ~dst:_ _buf = Lwt.return_unit
28 changes: 20 additions & 8 deletions src/tcp/flow.ml
Original file line number Diff line number Diff line change
Expand Up @@ -125,18 +125,22 @@ struct
(Cstruct.create 0)

(* Queue up an immediate close segment *)
let close pcb =
Log.debug (fun f -> f "Closing connection %a" WIRE.pp pcb.id);
let shutdown ctx pcb =
Log.debug (fun f -> f "%s connection %a" (match ctx with `Close -> "Closing" | `Shutdown -> "Shutting down") WIRE.pp pcb.id);
match State.state pcb.state with
| State.Established | State.Close_wait ->
UTX.wait_for_flushed pcb.utx >>= fun () ->
(let { wnd; _ } = pcb in
STATE.tick pcb.state (State.Send_fin (Window.tx_nxt wnd));
TXS.output ~flags:Segment.Fin pcb.txq (Cstruct.create 0)
TXS.output ~flags:Segment.Fin pcb.txq Cstruct.empty
)
| State.Closed | State.Syn_rcvd _ | State.Syn_sent _ when ctx = `Close ->
State.on_close pcb.state;
Lwt.return_unit
| _ ->
Log.debug (fun fmt ->
fmt "TX.close: close requested but no action needed, state=%a" State.pp pcb.state);
let msg = match ctx with `Close -> "close" | `Shutdown -> "shutdown" in
fmt "TX.%s: %s requested but no action needed, state=%a" msg msg State.pp pcb.state);
Lwt.return_unit

(* Thread that transmits ACKs in response to received packets,
Expand Down Expand Up @@ -179,6 +183,10 @@ struct
(* Coalesce any outstanding segments and retrieve ready segments *)
RXS.input rxq parsed

let shutdown pcb =
User_buffer.Rx.remove_all pcb.urx;
User_buffer.Rx.add_r pcb.urx None

(* Thread that spools the data into an application receive buffer,
and notifies the ACK subsystem that new data is here *)
let thread pcb ~rx_data =
Expand All @@ -199,8 +207,7 @@ struct
| None ->
(* don't send an ACK in this case; this already happened *)
STATE.tick pcb.state State.Recv_fin;
User_buffer.Rx.add_r urx None >>= fun () ->
Lwt.return_unit
User_buffer.Rx.add_r urx None
| Some data ->
signal_ack winadv >>= fun () ->
let rec queue = function
Expand Down Expand Up @@ -632,8 +639,13 @@ struct
let write_nodelay pcb data = writefn pcb (UTX.write_nodelay pcb.utx) data |> cast
let writev_nodelay pcb data = iter_s (write_nodelay pcb) data |> cast

(* Close - no more will be written *)
let close pcb = Tx.close pcb
(* Close *)
let close pcb = Tx.shutdown `Close pcb

let shutdown pcb mode =
let wr, rd = match mode with | `read -> false, true | `write -> true, false | `read_write -> true, true in
(if wr then Tx.shutdown `Shutdown pcb else Lwt.return_unit) >>= fun () ->
(if rd then Rx.shutdown pcb else Lwt.return_unit)

let dst pcb = WIRE.dst pcb.id, WIRE.dst_port pcb.id

Expand Down
3 changes: 2 additions & 1 deletion src/tcp/state.ml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ type t = {
let t ~id ~on_close =
{ on_close; id; state=Closed }

let on_close t = t.on_close ()

let state t = t.state

let pf = Format.fprintf
Expand Down Expand Up @@ -174,5 +176,4 @@ module Make(Time:Mirage_time.S) = struct
Log.debug (fun fmt -> fmt "%d %a - %a -> %a" t.id
pp_tcpstate old_state pp_action i pp_tcpstate new_state);
t.state <- new_state;

end
2 changes: 2 additions & 0 deletions src/tcp/state.mli
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ type t
val state : t -> tcpstate
val t : id:int -> on_close:close_cb -> t

val on_close : t -> unit

val pp: Format.formatter -> t -> unit

module Make(Time : Mirage_time.S) : sig
Expand Down
7 changes: 7 additions & 0 deletions src/tcp/user_buffer.ml
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,13 @@ module Rx = struct
| None -> 0
| Some b -> Cstruct.length b

let remove_all t =
let rec rm = function
| 0 -> ()
| n -> ignore (Lwt_dllist.take_l t.q); rm (pred n)
in
rm (Lwt_dllist.length t.q)

let add_r t s =
if t.cur_size > t.max_size then
let th,u = Lwt.wait () in
Expand Down
1 change: 1 addition & 0 deletions src/tcp/user_buffer.mli
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ module Rx : sig
type t

val create : max_size:int32 -> wnd:Window.t -> t
val remove_all : t -> unit
val add_r : t -> Cstruct.t option -> unit Lwt.t
val take_l : t -> Cstruct.t option Lwt.t
val cur_size : t -> int32
Expand Down
2 changes: 1 addition & 1 deletion tcpip.opam
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ depends: [
"randomconv"
"ethernet" {>= "3.0.0"}
"arp" {>= "3.0.0"}
"mirage-flow" {>= "2.0.0"}
"mirage-flow" {>= "4.0.0"}
"mirage-vnetif" {with-test & >= "0.5.0"}
"alcotest" {with-test & >="1.5.0"}
"pcap-format" {with-test}
Expand Down
Loading