From 7eb8821effdd37e9fcca0493962fae70dc8b5e98 Mon Sep 17 00:00:00 2001
From: Rob Hoes <rob.hoes@citrix.com>
Date: Tue, 26 Jul 2022 14:57:42 +0000
Subject: [PATCH 5/6] Total timeout for receiving HTTP headers

Signed-off-by: Rob Hoes <rob.hoes@citrix.com>
---
 ocaml/libs/http-svr/http.ml      | 42 ++++++++++++++++++++++++++------
 ocaml/libs/http-svr/http.mli     |  7 +++++-
 ocaml/libs/http-svr/http_svr.ml  | 33 ++++++++++++++++---------
 ocaml/libs/http-svr/http_svr.mli |  7 +++++-
 ocaml/libs/http-svr/http_test.ml |  3 ++-
 ocaml/xapi/xapi_globs.ml         |  4 +++
 ocaml/xapi/xapi_mgmt_iface.ml    |  1 +
 7 files changed, 74 insertions(+), 23 deletions(-)

diff --git a/ocaml/libs/http-svr/http.ml b/ocaml/libs/http-svr/http.ml
index a4d528d8c..af8a56ee2 100644
--- a/ocaml/libs/http-svr/http.ml
+++ b/ocaml/libs/http-svr/http.ml
@@ -26,6 +26,8 @@ exception Method_not_implemented
 
 exception Malformed_url of string
 
+exception Timeout
+
 module D = Debug.Make (struct let name = "http" end)
 
 open D
@@ -281,7 +283,7 @@ let header_len_header = Printf.sprintf "\r\n%s:" Hdr.header_len
 
 let header_len_value_len = 5
 
-let read_up_to buf already_read marker fd =
+let read_up_to ?deadline buf already_read marker fd =
   let marker = Scanner.make marker in
   let hl_marker = Scanner.make header_len_header in
   let b = ref 0 in
@@ -289,6 +291,12 @@ let read_up_to buf already_read marker fd =
   let header_len = ref None in
   let header_len_value_at = ref None in
   while not (Scanner.matched marker) do
+    Option.iter
+      (fun d ->
+        if Mtime.Span.compare (Mtime_clock.elapsed ()) d > 0 then
+          raise Timeout
+      )
+      deadline ;
     let safe_to_read =
       match (!header_len_value_at, !header_len) with
       | None, None ->
@@ -369,29 +377,47 @@ let set_socket_timeout fd t =
     (* In the unit tests, the fd comes from a pipe... ignore *)
     ()
 
-let read_http_request_header ~read_timeout fd =
+let read_http_request_header ~read_timeout ~total_timeout fd =
   Option.iter (fun t -> set_socket_timeout fd t) read_timeout ;
   let buf = Bytes.create 1024 in
-  Unixext.really_read fd buf 0 6 ;
+  let deadline =
+    Option.map
+      (fun t ->
+        let start = Mtime_clock.elapsed () in
+        let timeout_ns = int_of_float (t *. 1e9) in
+        Mtime.Span.(add start (timeout_ns * ns))
+      )
+      total_timeout
+  in
+  let check_timeout_and_read x y =
+    Option.iter
+      (fun d ->
+        if Mtime.Span.compare (Mtime_clock.elapsed ()) d > 0 then
+          raise Timeout
+      )
+      deadline ;
+    Unixext.really_read fd buf x y
+  in
+  check_timeout_and_read 0 6 ;
   (* return PROXY header if it exists, and then read up to FRAME header length (which also may not exist) *)
   let proxy =
     match Bytes.sub_string buf 0 6 with
     | "PROXY " ->
-        let proxy_header_length = read_up_to buf 6 "\r\n" fd in
+        let proxy_header_length = read_up_to ?deadline buf 6 "\r\n" fd in
         (* chop 'PROXY ' from the beginning, and '\r\n' from the end *)
         let proxy = Bytes.sub_string buf 6 (proxy_header_length - 6 - 2) in
-        Unixext.really_read fd buf 0 frame_header_length ;
+        check_timeout_and_read 0 frame_header_length ;
         Some proxy
     | _ ->
-        Unixext.really_read fd buf 6 (frame_header_length - 6) ;
+        check_timeout_and_read 6 (frame_header_length - 6) ;
         None
   in
   let frame, headers_length =
     match read_frame_header buf with
     | None ->
-        (false, read_up_to buf frame_header_length end_of_headers fd)
+        (false, read_up_to ?deadline buf frame_header_length end_of_headers fd)
     | Some length ->
-        Unixext.really_read fd buf 0 length ;
+        check_timeout_and_read 0 length ;
         (true, length)
   in
   set_socket_timeout fd 0. ;
diff --git a/ocaml/libs/http-svr/http.mli b/ocaml/libs/http-svr/http.mli
index b06ad105f..23e636a50 100644
--- a/ocaml/libs/http-svr/http.mli
+++ b/ocaml/libs/http-svr/http.mli
@@ -28,12 +28,17 @@ exception Method_not_implemented
 
 exception Forbidden
 
+exception Timeout
+
 type authorization = Basic of string * string | UnknownAuth of string
 
 val make_frame_header : string -> string
 
 val read_http_request_header :
-  read_timeout:float option -> Unix.file_descr -> bool * string * string option
+     read_timeout:float option
+  -> total_timeout:float option
+  -> Unix.file_descr
+  -> bool * string * string option
 
 val read_http_response_header : bytes -> Unix.file_descr -> int
 
diff --git a/ocaml/libs/http-svr/http_svr.ml b/ocaml/libs/http-svr/http_svr.ml
index 77dea08bd..155462d33 100644
--- a/ocaml/libs/http-svr/http_svr.ml
+++ b/ocaml/libs/http-svr/http_svr.ml
@@ -322,9 +322,11 @@ exception Generic_error of string
 
 (** [request_of_bio_exn ic] reads a single Http.req from [ic] and returns it. On error
     	it simply throws an exception and doesn't touch the output stream. *)
-let request_of_bio_exn ~proxy_seen ~read_timeout bio =
+let request_of_bio_exn ~proxy_seen ~read_timeout ~total_timeout bio =
   let fd = Buf_io.fd_of bio in
-  let frame, headers, proxy' = Http.read_http_request_header ~read_timeout fd in
+  let frame, headers, proxy' =
+    Http.read_http_request_header ~read_timeout ~total_timeout fd
+  in
   let proxy = match proxy' with None -> proxy_seen | x -> x in
   let additional_headers =
     proxy |> Option.fold ~none:[] ~some:(fun p -> [("STUNNEL_PROXY", p)])
@@ -400,9 +402,11 @@ let request_of_bio_exn ~proxy_seen ~read_timeout bio =
 
 (** [request_of_bio ic] returns [Some req] read from [ic], or [None]. If [None] it will have
     	already sent back a suitable error code and response to the client. *)
-let request_of_bio ?proxy_seen ~read_timeout ic =
+let request_of_bio ?proxy_seen ~read_timeout ~total_timeout ic =
   try
-    let r, proxy = request_of_bio_exn ~proxy_seen ~read_timeout ic in
+    let r, proxy =
+      request_of_bio_exn ~proxy_seen ~read_timeout ~total_timeout ic
+    in
     (Some r, proxy)
   with e ->
     D.warn "%s (%s)" (Printexc.to_string e) __LOC__ ;
@@ -426,7 +430,7 @@ let request_of_bio ?proxy_seen ~read_timeout ic =
         (* Generic errors thrown during parsing *)
         | End_of_file ->
             ()
-        | Unix.Unix_error (Unix.EAGAIN, _, _) ->
+        | Unix.Unix_error (Unix.EAGAIN, _, _) | Http.Timeout ->
             response_request_timeout ss
         (* Premature termination of connection! *)
         | Unix.Unix_error (a, b, c) ->
@@ -496,7 +500,8 @@ let handle_one (x : 'a Server.t) ss context req =
     ) ;
     !finished
 
-let handle_connection ~header_read_timeout (x : 'a Server.t) caller ss =
+let handle_connection ~header_read_timeout ~header_total_timeout
+    (x : 'a Server.t) caller ss =
   ( match caller with
   | Unix.ADDR_UNIX _ ->
       debug "Accepted unix connection"
@@ -511,9 +516,11 @@ let handle_connection ~header_read_timeout (x : 'a Server.t) caller ss =
      just once per connection. To allow for the PROXY metadata (including e.g. the
      client IP) to be added to all request records on a connection, it must be passed
      along in the loop below. *)
-  let rec loop ~read_timeout proxy_seen =
+  let rec loop ~read_timeout ~total_timeout proxy_seen =
     (* 1. we must successfully parse a request *)
-    let req, proxy = request_of_bio ?proxy_seen ~read_timeout ic in
+    let req, proxy =
+      request_of_bio ?proxy_seen ~read_timeout ~total_timeout ic
+    in
     (* 2. now we attempt to process the request *)
     let finished =
       Option.fold ~none:true
@@ -522,9 +529,10 @@ let handle_connection ~header_read_timeout (x : 'a Server.t) caller ss =
     in
     (* 3. do it again if the connection is kept open, but without timeouts *)
     if not finished then
-      loop ~read_timeout:None proxy
+      loop ~read_timeout:None ~total_timeout:None proxy
   in
-  loop ~read_timeout:header_read_timeout None ;
+  loop ~read_timeout:header_read_timeout ~total_timeout:header_total_timeout
+    None ;
   debug "Closing connection" ;
   Unix.close ss
 
@@ -592,11 +600,12 @@ let socket_table = Hashtbl.create 10
 type socket = Unix.file_descr * string
 
 (* Start an HTTP server on a new socket *)
-let start ?header_read_timeout ~conn_limit (x : 'a Server.t) (socket, name) =
+let start ?header_read_timeout ?header_total_timeout ~conn_limit
+    (x : 'a Server.t) (socket, name) =
   let handler =
     {
       Server_io.name
-    ; body= handle_connection ~header_read_timeout x
+    ; body= handle_connection ~header_read_timeout ~header_total_timeout x
     ; lock= Xapi_stdext_threads.Semaphore.create conn_limit
     }
   in
diff --git a/ocaml/libs/http-svr/http_svr.mli b/ocaml/libs/http-svr/http_svr.mli
index 40a5074ea..761e39436 100644
--- a/ocaml/libs/http-svr/http_svr.mli
+++ b/ocaml/libs/http-svr/http_svr.mli
@@ -60,7 +60,12 @@ val bind : ?listen_backlog:int -> Unix.sockaddr -> string -> socket
 val bind_retry : ?listen_backlog:int -> Unix.sockaddr -> socket
 
 val start :
-  ?header_read_timeout:float -> conn_limit:int -> 'a Server.t -> socket -> unit
+     ?header_read_timeout:float
+  -> ?header_total_timeout:float
+  -> conn_limit:int
+  -> 'a Server.t
+  -> socket
+  -> unit
 
 val handle_one : 'a Server.t -> Unix.file_descr -> 'a -> Http.Request.t -> bool
 
diff --git a/ocaml/libs/http-svr/http_test.ml b/ocaml/libs/http-svr/http_test.ml
index e067a8b8a..462f46066 100644
--- a/ocaml/libs/http-svr/http_test.ml
+++ b/ocaml/libs/http-svr/http_test.ml
@@ -200,7 +200,8 @@ let test_read_http_request_header _ =
   |> List.iter (fun (frame, proxy, header) ->
          with_fd (mk_header_string ~frame ~proxy ~header) (fun fd ->
              let actual_frame, actual_header, actual_proxy =
-               Http.read_http_request_header ~read_timeout:None fd
+               Http.read_http_request_header ~read_timeout:None
+                 ~total_timeout:None fd
              in
              assert (actual_frame = frame) ;
              assert (actual_header = header) ;
diff --git a/ocaml/xapi/xapi_globs.ml b/ocaml/xapi/xapi_globs.ml
index 4c874ff65..69a4fae68 100644
--- a/ocaml/xapi/xapi_globs.ml
+++ b/ocaml/xapi/xapi_globs.ml
@@ -960,6 +960,9 @@ let samba_dir = "/var/lib/samba"
 let header_read_timeout_tcp = ref 10.
 (* Timeout in seconds for every read while reading HTTP headers (on TCP only) *)
 
+let header_total_timeout_tcp = ref 60.
+(* Timeout in seconds to receive all HTTP headers (on TCP only) *)
+
 let conn_limit_tcp = ref 800
 
 let conn_limit_unix = ref 1024
@@ -1040,6 +1043,7 @@ let xapi_globs_spec =
     , Float winbind_update_closest_kdc_interval
     )
   ; ("header_read_timeout_tcp", Float header_read_timeout_tcp)
+  ; ("header_total_timeout_tcp", Float header_total_timeout_tcp)
   ; ("conn_limit_tcp", Int conn_limit_tcp)
   ; ("conn_limit_unix", Int conn_limit_unix)
   ; ("conn_limit_clientcert", Int conn_limit_clientcert)
diff --git a/ocaml/xapi/xapi_mgmt_iface.ml b/ocaml/xapi/xapi_mgmt_iface.ml
index 80a4852aa..084b43531 100644
--- a/ocaml/xapi/xapi_mgmt_iface.ml
+++ b/ocaml/xapi/xapi_mgmt_iface.ml
@@ -83,6 +83,7 @@ end = struct
     in
     Http_svr.start
       ~header_read_timeout:!Xapi_globs.header_read_timeout_tcp
+      ~header_total_timeout:!Xapi_globs.header_total_timeout_tcp
       ~conn_limit:!Xapi_globs.conn_limit_tcp Xapi_http.server socket ;
     management_servers := socket :: !management_servers ;
     if Pool_role.is_master () && addr = None then
-- 
2.31.1

