(*
 * ocamlcgi - Objective Caml library for writing CGIs
 * Copyright (C) 2003-2004 Merjis Ltd. (http://www.merjis.com/)
 * Copyright (C) 1997 Daniel de Rauglaudre, INRIA
 * Copyright (C) 1998 Jean-Christophe FILLIATRE
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the Free
 * Software Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 *
 * $Id: cgi.ml,v 1.33 2005/01/25 18:35:38 ChriS Exp $
 *)

open Apache
open Mod_caml_config
open Printf

(* [may f opt] applies [f] to the payload of [opt] when there is one and
   does nothing for [None]. *)
let may f = function
  | Some x -> f x
  | None -> ()

(* [split separator text] cuts [text] at every occurrence of the character
   [separator] and returns the pieces in order.

   - The empty string yields [[]].
   - Empty pieces between (or before) separators are kept as [""].
   - A trailing separator does not produce a final empty piece, so
     [split '&' "a&"] is [["a"]].

   The loop is tail-recursive and keeps no exception handler on the stack
   across iterations: the text comes from request data, so the number of
   separators is controlled by the client. *)
let split separator text =
  let len = String.length text in
  let rec collect acc pos =
    if pos >= len then List.rev acc
    else
      let next =
        try Some (String.index_from text pos separator)
        with Not_found -> None in
      match next with
      | Some stop ->
          collect (String.sub text pos (stop - pos) :: acc) (stop + 1)
      | None ->
          (* No more separators: the rest of the text is the last piece. *)
          List.rev (String.sub text pos (len - pos) :: acc)
  in
  collect [] 0



(* Generic expires headers: dates of the form
   "Wdy, DD Mon YYYY HH:MM:SS GMT" computed relative to the current time. *)
module Expires =
struct
  (* [short_weekday i] is the three-letter English name of weekday [i],
     with 0 = Sunday as in [Unix.tm_wday]; index 7 is accepted as Sunday
     too.
     @raise Invalid_argument if [i] is outside 0..7 (the lookup is
     bounds-checked rather than reading past the table). *)
  let short_weekday =
    let wd = [| "Sun"; "Mon"; "Tue"; "Wed"; "Thu"; "Fri"; "Sat"; "Sun" |] in
    fun i -> wd.(i)

  (* [short_month i] is the three-letter English name of month [i], with
     0 = January as in [Unix.tm_mon].
     @raise Invalid_argument if [i] is outside 0..11. *)
  let short_month =
    let month = [| "Jan"; "Feb"; "Mar"; "Apr"; "May"; "Jun"; "Jul"; "Aug";
                   "Sep"; "Oct"; "Nov"; "Dec" |] in
    fun i -> month.(i)

  (* [make offset] formats the GMT date that lies [offset] seconds from
     now (negative offsets give dates in the past). *)
  let make offset =
    let tm = Unix.gmtime (Unix.time () +. float offset) in
    Printf.sprintf "%s, %02d %s %04d %02d:%02d:%02d GMT"
      (short_weekday tm.Unix.tm_wday)
      tm.Unix.tm_mday
      (short_month tm.Unix.tm_mon)
      (tm.Unix.tm_year + 1900)
      tm.Unix.tm_hour
      tm.Unix.tm_min
      tm.Unix.tm_sec

  (* Five minutes ago. *)
  let past () = make ((-5) * 60)
  (* Five minutes from now. *)
  let short () = make (5 * 60)
  (* One day (86400 seconds) from now. *)
  let medium () = make 86400
  (* Two years (2 * 365 days) from now. *)
  let long () = make (365 * 2 * 86400)

end


(* HTTP cookies: building [Set-Cookie] header values and parsing the
   [Cookie] request header. *)
module Cookie =
struct
  (* [special_rfc2068 c] is true for the characters that [encode] must
     escape in cookie names and values. *)
  let special_rfc2068 = function
    | '\000' .. '\031' | '\127' .. '\255' (* Control chars and non-ASCII *)
    | '(' | ')' | '<' | '>' | '@' | ',' | ';' | ':' | '\\' | '"'
    | '/' | '[' | ']' | '?' | '=' | '{' | '}' (* tspecials *)
    | '+' | '%' (* not in RFC2068 but they serve to encode *)
      (* ' ' must also be encoded but its encoding '+' takes a single char. *)
        -> true
    | _ -> false

  (* Escaping applied to cookie names and values. *)
  let encode = Cgi_escape.encode_wrt special_rfc2068

  (* A mutable cookie.  [max_age] (seconds, optional) takes precedence
     over the obsolete free-form [expires] string when the cookie is
     serialized; an empty [domain], [path] or [expires] means "not set". *)
  class cookie ~name ~value ~domain ~max_age ~expires ~path ~secure =
  object (self)
    val mutable name = name
    val mutable value = value
    val mutable domain = domain
    val mutable max_age = max_age
    val mutable expires = expires (* obsolete *)
    val mutable path = path
    val mutable secure = secure

    method name = name
    method value = value
    method domain = domain
    method max_age = max_age
    method expires = expires (* obsolete *)
    method path = path
    method secure = secure

    method set_name v = name <- v
    method set_value v = value <- v
    method set_domain v = domain <- v
    (* Setting the max age also refreshes the obsolete [expires] field:
       a date [s] seconds from now for [Some s], "" for [None]. *)
    method set_max_age sec =
      max_age <- sec;
      expires <- (match sec with Some s -> Expires.make s | None -> "")
    method set_expires v = expires <- v (* obsolete *)
    method set_path v = path <- v
    method set_secure v = secure <- v

    (* Serialize the cookie as the value of a [Set-Cookie] header. *)
    method to_string =
      let buf = Buffer.create 128 in
      let add_attr label v =
        Buffer.add_string buf label;
        Buffer.add_string buf v in
      if String.length name > 0 && name.[0] = '$' then
        (* TRICK: names cannot start with '$', so if it does add '+'
           in front to protect it. '+' will be decoded as space, then
           stripped. *)
        Buffer.add_char buf '+';
      Buffer.add_string buf (encode name);
      Buffer.add_char buf '=';
      Buffer.add_string buf (encode value);
      begin match max_age with
      | None ->
          (* The date is emitted verbatim, exactly like the date generated
             in the [Some] branch: running it through [encode] would
             escape ',' and ':' and turn spaces into '+', leaving
             something that is no longer a date. *)
          if expires <> "" then add_attr "; expires=" expires
      | Some sec ->
          if sec > 0 then add_attr "; expires=" (Expires.make sec)
          else
            (* A non-positive max age asks for deletion: use a date far in
               the past. *)
            Buffer.add_string buf "; expires=Thu, 1 Jan 1970 00:00:01 GMT"
      end;
      (* We do not encode the domain and path because they will be
         interpreted by the browser to determine whether the cookie
         must be sent back. *)
      if domain <> "" then add_attr "; domain=" domain;
      if path <> "" then add_attr "; path=" path;
      if secure then Buffer.add_string buf "; secure";
      Buffer.contents buf
  end

  (* [cookie name value] creates a cookie; every attribute is optional
     and defaults to "not set" ([secure] defaults to [false]). *)
  let cookie ?max_age ?(expires="") ?(domain="") ?(path="") ?(secure=false)
      name value =
    new cookie ~name ~value ~expires ~max_age ~domain ~path ~secure

  (* Cookies in a [Cookie] header are separated by ';' and an optional
     space. *)
  let split_re = Pcre.regexp "; ?"

  (* [parse header] turns the value of a [Cookie] request header into a
     list of cookies, in header order.  Each "name=value" item is cut at
     its first '='; an item without '=' becomes a cookie with an empty
     value.  Names and values go through [Cgi_escape.decode_range]
     (presumably the inverse of [encode] -- defined outside this file). *)
  let parse header =
    let cookie_of_item s =
      let len = String.length s in
      let name, value =
        try
          let i = String.index s '=' in
          (Cgi_escape.decode_range s 0 i,
           Cgi_escape.decode_range s (succ i) len)
        with Not_found ->
          (Cgi_escape.decode_range s 0 len, "") in
      cookie name value in
    List.map cookie_of_item (Pcre.split ~rex:split_re header)
end


(* Conversion between URL-encoded argument strings ("k1=v1&k2=v2") and
   association lists. *)
module Cgi_args =
struct
  (* [parse s] cuts [s] at every '&' and decodes each piece into a
     (key, value) pair, split at the first '='.  A piece without '='
     becomes a key bound to "". *)
  let parse s =
    let decode = Cgi_escape.decode_range in
    let binding_of_piece piece =
      let len = String.length piece in
      try
        let eq = String.index piece '=' in
        (decode piece 0 eq, decode piece (succ eq) len)
      with Not_found ->
        (decode piece 0 len, "")
    in
    List.map binding_of_piece (split '&' s)

  (* [make bindings] is the reverse operation: each pair is rendered as
     "key=value" with both sides escaped, and the pairs are joined
     with '&'. *)
  let make bindings =
    let render (key, value) =
      Cgi_escape.encode key ^ "=" ^ Cgi_escape.encode value in
    let pieces = List.map render bindings in
    String.concat "&" pieces
end

(* [get_post_body post_max r] reads the whole body of request [r] and
   returns it as a string, or "" when [Request.should_client_block]
   reports there is nothing to read.
   @raise HttpError (with cHTTP_REQUEST_ENTITY_TOO_LARGE) as soon as the
   accumulated body would exceed [post_max] bytes. *)
let get_post_body post_max r =
  (* http://www.auburn.edu/docs/apache/misc/client_block_api.html *)
  Request.setup_client_block r Request.REQUEST_CHUNKED_ERROR;

  if not (Request.should_client_block r) then ""
  else begin
    let body = Buffer.create 8192 in
    let rec read_blocks () =
      match Request.get_client_block r with
      | "" -> ()                        (* empty block: end of body *)
      | block ->
          (* Test the size before appending, so that the buffer is never
             asked to grow past [post_max] (which may be
             Sys.max_string_length). *)
          if Buffer.length body + String.length block > post_max then
            raise (HttpError cHTTP_REQUEST_ENTITY_TOO_LARGE);
          Buffer.add_string body block;
          read_blocks ()
    in
    read_blocks ();
    Buffer.contents body
  end


(* [parse_args post_max r] returns the URL-encoded arguments of request
   [r] as a (key, value) list: the query string for GET (the request body
   is discarded), the body for POST (limited to [post_max] bytes).
   @raise HttpError (with cHTTP_METHOD_NOT_ALLOWED) for any other method. *)
let parse_args post_max r =
  let query_string () =
    Request.discard_request_body r;
    try Request.args r with Not_found -> ""
  in
  let raw =
    match Request.method_number r with
    | M_GET -> query_string ()
    | M_POST -> get_post_body post_max r
    | _ -> raise (HttpError cHTTP_METHOD_NOT_ALLOWED)
  in
  Cgi_args.parse raw

(* multipart_args: parsing of the CGI arguments for multipart/form-data
   encoding *)

(* XXX RWMJ's note: slowly converting this file to use Pcre. *)
(* Regular expressions used to pick apart multipart/form-data requests.
   Each attribute comes as a pair: the "1" pattern matches a double-quoted
   value, the "2" pattern an unquoted token.  In both, group 1 is the
   value.  All patterns except [separator_re] ignore case. *)

(* "boundary" parameter (looked up in the Content-Type request header). *)
let boundary_re1 =
  Str.regexp_case_fold "boundary=\"\\([^\"]+\\)\""
let boundary_re2 =
  Str.regexp_case_fold "boundary=\\([^ \t\r\n]+\\)"
(* "name" attribute of a part.
   NOTE(review): the patterns are not anchored, so they also match the
   tail of a "filename=..." attribute if one precedes "name=...". *)
let name_re1 =
  Str.regexp_case_fold "name=\"\\([^\"]+\\)\""
let name_re2 =
  Str.regexp_case_fold "name=\\([^ \t\r\n;:]+\\)"
(* "filename" attribute of a part; the quoted form may be empty. *)
let filename_re1 =
  Str.regexp_case_fold "filename=\"\\([^\"]*\\)\""
let filename_re2 =
  Str.regexp_case_fold "filename=\\([^ \t\r\n;:]+\\)"
(* "Content-type:" header line of a part. *)
let content_type_re1 =
  Str.regexp_case_fold "Content-type:[ \t]*\"\\([^\"]+\\)\""
let content_type_re2 =
  Str.regexp_case_fold "Content-type:[ \t]*\\([^ \t\r\n;:]+\\)"
(* Blank line separating the headers of a part from its contents. *)
let separator_re =
  Str.regexp "\r\n\r\n"

(* [match_string re1 re2 str] searches [str] for [re1] and, failing that,
   for [re2]; it returns what group 1 of the successful pattern matched.
   @raise Not_found if neither pattern occurs in [str]. *)
let match_string re1 re2 str =
  let first_group re =
    let _ = Str.search_forward re str 0 in
    Str.matched_group 1 str
  in
  try first_group re1
  with Not_found -> first_group re2

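(* Extracting the field name and value from a chunk is done by
   [extract_field] below, which raises Not_found if the chunk is not
   valid.  The record type describing an extracted field comes first. *)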

(* One field of a multipart/form-data body, as filled in by
   [extract_field]. *)
type upload_data = {
  (* Contents of the part: everything after the blank line that ends the
     part's headers, minus the final CRLF. *)
  upload_value: string;
  (* Value of the "filename" attribute, or "" when there is none. *)
  upload_filename: string;
  (* Value of the part's "Content-type" header, or "" when there is none. *)
  upload_content_type: string
}

(* [extract_field chunk] parses one part of a multipart/form-data body
   (the text found between two boundary delimiters) and returns the field
   name paired with its [upload_data].

   The part's headers run up to the first blank line ("\r\n\r\n"); the
   field name is mandatory, the filename and content type default to "".
   The value is everything after the blank line, minus one final CRLF if
   present.

   @raise Not_found if the chunk contains no blank line or no field
   name. *)
let extract_field chunk =
  let pos_separator = Str.search_forward separator_re chunk 0 in
  let header = String.sub chunk 0 pos_separator in
  let field_name = match_string name_re1 name_re2 header in
  let optional_attr re1 re2 =
    try match_string re1 re2 header with Not_found -> "" in
  let field_filename = optional_attr filename_re1 filename_re2 in
  let field_content_type =
    optional_attr content_type_re1 content_type_re2 in
  let beg_value = pos_separator + 4 in
  let len = String.length chunk in
  (* Chop final \r\n that browsers insist on putting.  The CRLF must lie
     entirely inside the value: when the chunk stops right after the blank
     line, its last two bytes belong to the separator itself, and chopping
     them would make the value length negative. *)
  let end_value =
    if len - 2 >= beg_value
       && chunk.[len - 2] = '\r' && chunk.[len - 1] = '\n'
    then len - 2
    else len in
  let field_value =
    String.sub chunk beg_value (end_value - beg_value) in
  (field_name, { upload_filename = field_filename;
                 upload_content_type = field_content_type;
                 upload_value = field_value })

(* [extract_fields accu chunks] runs [extract_field] over every chunk and
   pushes each successfully parsed field on the front of [accu]; chunks
   rejected with Not_found are skipped.  Because fields are pushed on the
   front, they come out in the reverse of the chunk order. *)
let extract_fields accu chunks =
  let add_chunk fields chunk =
    try extract_field chunk :: fields
    with Not_found -> fields
  in
  List.fold_left add_chunk accu chunks

(* [string_starts_with s pref] tells whether [pref] is a prefix of [s].
   The empty string is a prefix of every string. *)
let string_starts_with s pref =
  let plen = String.length pref in
  let rec same_from i =
    i >= plen || (s.[i] = pref.[i] && same_from (i + 1))
  in
  plen <= String.length s && same_from 0

(* [parse_multipart_args post_max r] reads the body of request [r] (at
   most [post_max] bytes), cuts it at every "--<boundary>" delimiter and
   returns the fields found in the resulting chunks, as produced by
   [extract_fields] starting from an empty list.
   @raise Not_found if the request has no Content-Type header.
   @raise Failure if the Content-Type header names no boundary. *)
let parse_multipart_args post_max r =
  let content_type = Table.get (Request.headers_in r) "Content-Type" in
  let boundary =
    try match_string boundary_re1 boundary_re2 content_type
    with Not_found ->
      failwith ("Cgi: no boundary provided in " ^ content_type)
  in
  let delimiter = Str.regexp_string ("--" ^ boundary) in
  let body = get_post_body post_max r in
  extract_fields [] (Str.split delimiter body)

(* [downconvert_upload_data fields] reduces multipart fields to plain
   (name, value) parameters by keeping only each field's contents. *)
let downconvert_upload_data fields =
  let as_param (name, upload) = (name, upload.upload_value) in
  List.map as_param fields

(* [get_params post_max r] decodes the arguments of request [r] and
   returns [(params, uploads, is_multipart)].

   A POST whose Content-Type starts with "multipart/form-data" is parsed
   as multipart: [uploads] holds the full fields and [params] their
   (name, value) projection.  Any other request goes through
   [parse_args], with an empty [uploads]. *)
let get_params post_max r =
  let is_multipart =
    try
      Request.method_number r = M_POST
      && string_starts_with
           (Table.get (Request.headers_in r) "Content-Type")
           "multipart/form-data"
    with Not_found -> false             (* no Content-Type header *)
  in
  if not is_multipart then (parse_args post_max r, [], false)
  else
    let uploads = parse_multipart_args post_max r in
    (downconvert_upload_data uploads, uploads, true)

(* [get_cookies r] parses the "Cookie" header of request [r]; the result
   is empty when the header lookup raises Not_found. *)
let get_cookies r =
  try Cookie.parse (Table.get (Request.headers_in r) "Cookie")
  with Not_found -> []


(* Minimal interface of the objects accepted by [cgi#template].
   [output f] is given a string consumer [f]; [cgi#template] passes one
   that writes each string it receives to the client. *)
class type template =
object
  method output : (string -> unit) -> unit
end

(* A CGI-style view of an Apache request [r].  The parameters, uploads
   and cookies are decoded once, when the object is created; [post_max]
   bounds the size of the request body that will be read. *)
class cgi ?(post_max=Sys.max_string_length) r =
  let params, uploads, is_multipart = get_params post_max r in
  let cookies = get_cookies r in
object (self)

  (* [true] until [header] has sent the HTTP header. *)
  val mutable header_not_emitted = true

  (* Add a Set-Cookie header for [cookie] and for each of [cookies], to
     the error headers when [is_redirect] and to the normal output
     headers otherwise.  Unless [cookie_cache] is set, headers asking
     caches not to store the response are added as well. *)
  method private send_cookies ?cookie ?cookies ?(cookie_cache=false)
    is_redirect =
    let table =
      if is_redirect then Request.err_headers_out r
      else Request.headers_out r in
    let set_header = Table.add table in
    let emit (c : Cookie.cookie) = set_header "Set-Cookie" c#to_string in
    (match cookie with
     | Some c -> emit c
     | None -> ());
    (match cookies with
     | Some cs -> List.iter emit cs
     | None -> ());
    if not cookie_cache then begin
      set_header "Cache-control" "no-cache=\"set-cookie\"";
      (* For HTTP/1.0 proxies along the way.  Cache-control directives
         override this for HTTP/1.1.  *)
      set_header "Expires" "Thu, 1 Jan 1970 00:00:00 GMT";
      set_header "Pragma" "no-cache"
    end

  (* Send the HTTP header (content type and cookies).  Only the first
     call has any effect. *)
  method header ?(content_type="text/html") ?cookie ?cookies ?cookie_cache () =
    if header_not_emitted then begin
      Request.set_content_type r content_type;
      self#send_cookies ?cookie ?cookies ?cookie_cache false;
      Request.send_http_header r;
      header_not_emitted <- false
    end

  (* Send the header, then the output of [template].  Fails with
     "cgi#template" if a string could not be written in full. *)
  method template : 'a. ?content_type:string -> ?cookie:Cookie.cookie ->
    ?cookies:Cookie.cookie list -> ?cookie_cache:bool ->
    (#template as 'a) -> unit =
    fun ?content_type ?cookie ?cookies ?cookie_cache template ->
      let write s =
        let written = print_string r s in
        if written <> String.length s then failwith "cgi#template" in
      self#header ?content_type ?cookie ?cookies ?cookie_cache ();
      template#output write

  (* Abandon the script by raising [Exit]. *)
  method exit : 'a. unit -> 'a = fun () -> raise Exit

  (* Set the cookies and the Location header, then raise
     [HttpError cREDIRECT]; never returns. *)
  method redirect : 'a. ?cookie:Cookie.cookie ->
    ?cookies:Cookie.cookie list -> ?cookie_cache:bool -> string -> 'a
    = fun ?cookie ?cookies ?cookie_cache url ->
      self#send_cookies ?cookie ?cookies ?cookie_cache true;
      Table.set (Request.headers_out r) "Location" url;
      raise (HttpError cREDIRECT)

  (* Note that eventually we'll add some optional arguments to this method *)
  (* See apache/src/main/util_script.c and mod_cgi.c *)
  method url () =
    Request.uri r                       (* Catch Not_found ? XXX *)

  (* First value bound to [name]; raises Not_found if there is none. *)
  method param name = List.assoc name params

  (* Every value bound to [name], in parameter order. *)
  method param_all name =
    List.map snd (List.filter (fun (key, _) -> key = name) params)

  (* Whether [self#param name] finds a value. *)
  method param_exists name =
    try
      let _ = self#param name in
      true
    with Not_found -> false

  (* Whether the parameter exists and is neither "" nor "0". *)
  method param_true name =
    try
      match self#param name with
      | "" | "0" -> false
      | _ -> true
    with Not_found -> false

  method params =
    params

  method is_multipart =
    is_multipart

  (* First upload named [name]; raises Not_found if there is none. *)
  method upload name =
    List.assoc name uploads

  (* Every upload named [name], in the order of the upload list. *)
  method upload_all name =
    List.fold_right
      (fun (this_name, data) found ->
         if name = this_name then data :: found else found)
      uploads []

  (* First cookie named [name]; raises Not_found if there is none. *)
  method cookie name =
    List.find (fun c -> c#name = name) cookies

  method cookies =
    cookies

  method log msg = prerr_endline msg (* FIXME: Is this the recommended way? *)

  method request =
    r
end


(* Sending mail by piping a message into an external mailer command. *)
module Sendmail =
struct
  open Unix

  (* Raised by [close] (and hence [send_mail]) when the mailer does not
     exit with status 0.  Note that inside this module the name hides the
     standard [Failure] exception. *)
  exception Failure of string

  (* Mailer command and its arguments; the defaults come from the build
     configuration and can be overridden at run time. *)
  let sendmail = ref mailer_path
  let sendmail_args = ref mailer_args

  (* Start the mailer and return the channel connected to its standard
     input.  The channel must be handed to [close] afterwards. *)
  let send () =
    open_process_out (!sendmail ^ " " ^ !sendmail_args)

  (* Close the channel, wait for the mailer to finish and check how it
     ended.
     @raise Failure unless the mailer exited with status 0. *)
  let close chan =
    match close_process_out chan with
    | WEXITED 0 -> ()
    | WEXITED n ->
        raise (Failure (sprintf "%s: non-zero exit status: %i" !sendmail n))
    | WSIGNALED n ->
        raise (Failure (sprintf "%s: killed by signal %i" !sendmail n))
    | WSTOPPED n ->
        raise (Failure (sprintf "%s: stopped by signal %i" !sendmail n))


  (* Write the header line "k: v".  A newline inside [v] is followed by a
     tab, so that what comes after it is a continuation of the same header
     instead of a new header line. *)
  let output_header chan k v =
    output_string chan k;
    output_string chan ": ";
    output_string chan (Pcre.replace ~pat:"\n" ~templ:"\n\t" v);
    output_char chan '\n'

  (* Write an optional single-valued header; nothing for [None]. *)
  let output_header_opt1 chan k = function
    | None -> ()
    | Some v -> output_header chan k v

  (* Write an optional multi-valued header, values joined with ", ". *)
  let output_header_optN chan k = function
    | None -> ()
    | Some v -> output_header chan k (String.concat ", " v)

  (* [send_mail ... body] pipes a complete message to the mailer: the
     caller's extra [headers] first, then the standard ones that were
     given, a blank line, and [body].
     @raise Failure if the mailer does not exit with status 0.
     If writing the message raises, the mailer is closed and reaped before
     the exception is passed on, so that neither the pipe nor the child
     process is leaked. *)
  (* NB. 'to' is a reserved word. *)
  let send_mail ?subject ?to_addr ?cc ?bcc ?from ?content_type ?headers body =
    let chan = send () in
    begin try
      (match headers with
       | None -> ()
       | Some hs -> List.iter (fun (k,v) -> output_header chan k v) hs);
      output_header_opt1 chan "Subject" subject;
      output_header_optN chan "To" to_addr;
      output_header_optN chan "Cc" cc;
      output_header_optN chan "Bcc" bcc;
      output_header_opt1 chan "From" from;
      output_header_opt1 chan "Content-Type" content_type;
      output_char chan '\n'; (* blank line = end of headers. *)
      (* Send the body. *)
      output_string chan body
    with e ->
      (* Best-effort cleanup: the original error is the one worth
         reporting, so failures while closing are deliberately dropped. *)
      (try ignore (close_process_out chan)
       with Sys_error _ | Unix_error _ -> ());
      raise e
    end;
    (* Close connection. *)
    close chan

end


(* [char_of_hex i] is the uppercase hexadecimal digit for [i] in 0..15.
   (The original comment notes that cgi_escape.ml has the same helper.)
   @raise Invalid_argument if [i] is outside 0..15: the lookup is
   bounds-checked, so a bad index cannot read past the digit table. *)
let char_of_hex i = "0123456789ABCDEF".[i]


(* Generate a suitable random number (32 hex digits) for use in random
   cookies, session IDs, etc.  These numbers are supposed to be very
   hard to predict.

   The source of randomness is chosen once, when the module is
   initialized: 16 bytes read from /dev/urandom per call if that file
   exists, otherwise the [Random] generator (seeded with
   [Random.self_init]), which is much weaker.

   Each result is assembled in a buffer local to the call.  The
   /dev/urandom channel is closed even when reading from it fails. *)
let random_sessionid =
  if Sys.file_exists "/dev/urandom" then begin
    fun () ->
      let buf = Buffer.create 32 in
      let chan = open_in_bin "/dev/urandom" in
      (try
         for _i = 0 to 15 do
           (* One byte gives two digits, high nibble first. *)
           let b = input_byte chan in
           Buffer.add_char buf (char_of_hex (b lsr 4));
           Buffer.add_char buf (char_of_hex (b land 0x0F))
         done
       with e ->
         close_in_noerr chan;
         raise e);
      close_in chan;
      Buffer.contents buf
  end
  else begin
    Random.self_init ();
    fun () ->
      let buf = Buffer.create 32 in
      for _i = 0 to 7 do
        (* 16 random bits give four digits, low nibble first. *)
        let b = Random.int 0x10000 in
        Buffer.add_char buf (char_of_hex (b land 0xF));
        Buffer.add_char buf (char_of_hex ((b lsr 4) land 0xF));
        Buffer.add_char buf (char_of_hex ((b lsr 8) land 0xF));
        Buffer.add_char buf (char_of_hex (b lsr 12))
      done;
      Buffer.contents buf
  end

