From 1a47417a14151469ae29143e14131cb2f0be04da Mon Sep 17 00:00:00 2001
From: Rob Hoes <rob.hoes@citrix.com>
Date: Tue, 26 Jul 2022 16:20:19 +0000
Subject: [PATCH 6/6] Maximum header length

Signed-off-by: Rob Hoes <rob.hoes@citrix.com>
---
 ocaml/libs/http-svr/http.ml      | 14 ++++++++++----
 ocaml/libs/http-svr/http.mli     |  3 +++
 ocaml/libs/http-svr/http_svr.ml  | 31 ++++++++++++++++++++++---------
 ocaml/libs/http-svr/http_svr.mli |  1 +
 ocaml/libs/http-svr/http_test.ml |  2 +-
 ocaml/xapi/xapi_globs.ml         |  4 ++++
 ocaml/xapi/xapi_mgmt_iface.ml    |  1 +
 7 files changed, 42 insertions(+), 14 deletions(-)

diff --git a/ocaml/libs/http-svr/http.ml b/ocaml/libs/http-svr/http.ml
index af8a56ee2..08ac0c683 100644
--- a/ocaml/libs/http-svr/http.ml
+++ b/ocaml/libs/http-svr/http.ml
@@ -28,6 +28,8 @@ exception Malformed_url of string
 
 exception Timeout
 
+exception Too_large
+
 module D = Debug.Make (struct let name = "http" end)
 
 open D
@@ -283,7 +285,7 @@ let header_len_header = Printf.sprintf "\r\n%s:" Hdr.header_len
 
 let header_len_value_len = 5
 
-let read_up_to ?deadline buf already_read marker fd =
+let read_up_to ?deadline ?max buf already_read marker fd =
   let marker = Scanner.make marker in
   let hl_marker = Scanner.make header_len_header in
   let b = ref 0 in
@@ -310,6 +312,7 @@ let read_up_to ?deadline buf already_read marker fd =
 		Printf.fprintf stderr "b = %d; safe_to_read = %d\n" !b safe_to_read;
 		flush stderr;
 *)
+    Option.iter (fun m -> if !b + safe_to_read > m then raise Too_large) max ;
     let n =
       if !b < already_read then
         min safe_to_read (already_read - !b)
@@ -377,9 +380,9 @@ let set_socket_timeout fd t =
     (* In the unit tests, the fd comes from a pipe... ignore *)
     ()
 
-let read_http_request_header ~read_timeout ~total_timeout fd =
+let read_http_request_header ~read_timeout ~total_timeout ~max_length fd =
   Option.iter (fun t -> set_socket_timeout fd t) read_timeout ;
-  let buf = Bytes.create 1024 in
+  let buf = Bytes.create (Option.value ~default:1024 max_length) in
   let deadline =
     Option.map
       (fun t ->
@@ -415,7 +418,10 @@ let read_http_request_header ~read_timeout ~total_timeout fd =
   let frame, headers_length =
     match read_frame_header buf with
     | None ->
-        (false, read_up_to ?deadline buf frame_header_length end_of_headers fd)
+        let max = Option.map (fun m -> m - frame_header_length) max_length in
+        ( false
+        , read_up_to ?deadline ?max buf frame_header_length end_of_headers fd
+        )
     | Some length ->
         check_timeout_and_read 0 length ;
         (true, length)
diff --git a/ocaml/libs/http-svr/http.mli b/ocaml/libs/http-svr/http.mli
index 23e636a50..53dd5d96f 100644
--- a/ocaml/libs/http-svr/http.mli
+++ b/ocaml/libs/http-svr/http.mli
@@ -30,6 +30,8 @@ exception Forbidden
 
 exception Timeout
 
+exception Too_large
+
 type authorization = Basic of string * string | UnknownAuth of string
 
 val make_frame_header : string -> string
@@ -37,6 +39,7 @@ val make_frame_header : string -> string
 val read_http_request_header :
      read_timeout:float option
   -> total_timeout:float option
+  -> max_length:int option
   -> Unix.file_descr
   -> bool * string * string option
 
diff --git a/ocaml/libs/http-svr/http_svr.ml b/ocaml/libs/http-svr/http_svr.ml
index 155462d33..112c26a1e 100644
--- a/ocaml/libs/http-svr/http_svr.ml
+++ b/ocaml/libs/http-svr/http_svr.ml
@@ -170,6 +170,13 @@ let response_request_timeout s =
   in
   response_error_html s "408" "Request Timeout" [] body
 
+let response_request_header_fields_too_large s =
+  let body =
+    "<html><body><h1>HTTP 431 request header fields too large</h1>Exceeded the \
+     maximum header size.</body></html>"
+  in
+  response_error_html s "431" "Request Header Fields Too Large" [] body
+
 let response_internal_error ?req ?extra s =
   let version = Option.map get_return_version req in
   let extra =
@@ -322,10 +329,11 @@ exception Generic_error of string
 
 (** [request_of_bio_exn ic] reads a single Http.req from [ic] and returns it. On error
     	it simply throws an exception and doesn't touch the output stream. *)
-let request_of_bio_exn ~proxy_seen ~read_timeout ~total_timeout bio =
+let request_of_bio_exn ~proxy_seen ~read_timeout ~total_timeout ~max_length bio
+    =
   let fd = Buf_io.fd_of bio in
   let frame, headers, proxy' =
-    Http.read_http_request_header ~read_timeout ~total_timeout fd
+    Http.read_http_request_header ~read_timeout ~total_timeout ~max_length fd
   in
   let proxy = match proxy' with None -> proxy_seen | x -> x in
   let additional_headers =
@@ -402,10 +410,10 @@ let request_of_bio_exn ~proxy_seen ~read_timeout ~total_timeout bio =
 
 (** [request_of_bio ic] returns [Some req] read from [ic], or [None]. If [None] it will have
     	already sent back a suitable error code and response to the client. *)
-let request_of_bio ?proxy_seen ~read_timeout ~total_timeout ic =
+let request_of_bio ?proxy_seen ~read_timeout ~total_timeout ~max_length ic =
   try
     let r, proxy =
-      request_of_bio_exn ~proxy_seen ~read_timeout ~total_timeout ic
+      request_of_bio_exn ~proxy_seen ~read_timeout ~total_timeout ~max_length ic
     in
     (Some r, proxy)
   with e ->
@@ -432,6 +440,8 @@ let request_of_bio ?proxy_seen ~read_timeout ~total_timeout ic =
             ()
         | Unix.Unix_error (Unix.EAGAIN, _, _) | Http.Timeout ->
             response_request_timeout ss
+        | Http.Too_large ->
+            response_request_header_fields_too_large ss
         (* Premature termination of connection! *)
         | Unix.Unix_error (a, b, c) ->
             response_internal_error ss
@@ -501,7 +511,7 @@ let handle_one (x : 'a Server.t) ss context req =
     !finished
 
 let handle_connection ~header_read_timeout ~header_total_timeout
-    (x : 'a Server.t) caller ss =
+    ~max_header_length (x : 'a Server.t) caller ss =
   ( match caller with
   | Unix.ADDR_UNIX _ ->
       debug "Accepted unix connection"
@@ -519,7 +529,8 @@ let handle_connection ~header_read_timeout ~header_total_timeout
   let rec loop ~read_timeout ~total_timeout proxy_seen =
     (* 1. we must successfully parse a request *)
     let req, proxy =
-      request_of_bio ?proxy_seen ~read_timeout ~total_timeout ic
+      request_of_bio ?proxy_seen ~read_timeout ~total_timeout
+        ~max_length:max_header_length ic
     in
     (* 2. now we attempt to process the request *)
     let finished =
@@ -600,12 +611,14 @@ let socket_table = Hashtbl.create 10
 type socket = Unix.file_descr * string
 
 (* Start an HTTP server on a new socket *)
-let start ?header_read_timeout ?header_total_timeout ~conn_limit
-    (x : 'a Server.t) (socket, name) =
+let start ?header_read_timeout ?header_total_timeout ?max_header_length
+    ~conn_limit (x : 'a Server.t) (socket, name) =
   let handler =
     {
       Server_io.name
-    ; body= handle_connection ~header_read_timeout ~header_total_timeout x
+    ; body=
+        handle_connection ~header_read_timeout ~header_total_timeout
+          ~max_header_length x
     ; lock= Xapi_stdext_threads.Semaphore.create conn_limit
     }
   in
diff --git a/ocaml/libs/http-svr/http_svr.mli b/ocaml/libs/http-svr/http_svr.mli
index 761e39436..323511bf4 100644
--- a/ocaml/libs/http-svr/http_svr.mli
+++ b/ocaml/libs/http-svr/http_svr.mli
@@ -62,6 +62,7 @@ val bind_retry : ?listen_backlog:int -> Unix.sockaddr -> socket
 val start :
      ?header_read_timeout:float
   -> ?header_total_timeout:float
+  -> ?max_header_length:int
   -> conn_limit:int
   -> 'a Server.t
   -> socket
diff --git a/ocaml/libs/http-svr/http_test.ml b/ocaml/libs/http-svr/http_test.ml
index 462f46066..4dad98a36 100644
--- a/ocaml/libs/http-svr/http_test.ml
+++ b/ocaml/libs/http-svr/http_test.ml
@@ -201,7 +201,7 @@ let test_read_http_request_header _ =
          with_fd (mk_header_string ~frame ~proxy ~header) (fun fd ->
              let actual_frame, actual_header, actual_proxy =
                Http.read_http_request_header ~read_timeout:None
-                 ~total_timeout:None fd
+                 ~total_timeout:None ~max_length:None fd
              in
              assert (actual_frame = frame) ;
              assert (actual_header = header) ;
diff --git a/ocaml/xapi/xapi_globs.ml b/ocaml/xapi/xapi_globs.ml
index 69a4fae68..505db0262 100644
--- a/ocaml/xapi/xapi_globs.ml
+++ b/ocaml/xapi/xapi_globs.ml
@@ -963,6 +963,9 @@ let header_read_timeout_tcp = ref 10.
 let header_total_timeout_tcp = ref 60.
 (* Timeout in seconds to receive all HTTP headers (on TCP only) *)
 
+let max_header_length_tcp = ref 1024
+(* Maximum accepted size of HTTP headers in bytes (on TCP only) *)
+
 let conn_limit_tcp = ref 800
 
 let conn_limit_unix = ref 1024
@@ -1044,6 +1047,7 @@ let xapi_globs_spec =
     )
   ; ("header_read_timeout_tcp", Float header_read_timeout_tcp)
   ; ("header_total_timeout_tcp", Float header_total_timeout_tcp)
+  ; ("max_header_length_tcp", Int max_header_length_tcp)
   ; ("conn_limit_tcp", Int conn_limit_tcp)
   ; ("conn_limit_unix", Int conn_limit_unix)
   ; ("conn_limit_clientcert", Int conn_limit_clientcert)
diff --git a/ocaml/xapi/xapi_mgmt_iface.ml b/ocaml/xapi/xapi_mgmt_iface.ml
index 084b43531..3e82cc8eb 100644
--- a/ocaml/xapi/xapi_mgmt_iface.ml
+++ b/ocaml/xapi/xapi_mgmt_iface.ml
@@ -84,6 +84,7 @@ end = struct
     Http_svr.start
       ~header_read_timeout:!Xapi_globs.header_read_timeout_tcp
       ~header_total_timeout:!Xapi_globs.header_total_timeout_tcp
+      ~max_header_length:!Xapi_globs.max_header_length_tcp
       ~conn_limit:!Xapi_globs.conn_limit_tcp Xapi_http.server socket ;
     management_servers := socket :: !management_servers ;
     if Pool_role.is_master () && addr = None then
-- 
2.31.1

