(************************************************************************)
(* This file is part of SKS.  SKS is free software; you can
   redistribute it and/or modify it under the terms of the GNU General
   Public License as published by the Free Software Foundation; either
   version 2 of the License, or (at your option) any later version.

   This program is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
   USA *)
(***********************************************************************)

(** Executable: server process that manages the key database and
  answers queries against it. *)

module F(M:sig end) = 
struct
  open StdLabels
  open MoreLabels
  open Printf
  open Common
  open Packet
  module Unix = UnixLabels
  open Unix
  open DbMessages
  open Request
  open Pstyle

  (* Select the ".db" log file before anything below is logged. *)
  let () = set_logfile ".db"

  (* Start-up banner: version, copyright, licence and the HTTP port. *)
  let () =
    ignore (plerror 0 "sks_db, SKS version %s" version);
    ignore (plerror 0 "Copyright Yaron Minsky 2002");
    ignore (plerror 0 "Licensed under GPL.  See COPYING file for details");
    ignore (plerror 3 "http port: %d" http_port)

  (* Database parameters, read from the [Settings] references at the
     time this functor is applied. *)
  let settings = {
    Keydb.dbdir = !Settings.dbdir;
    Keydb.dumpdir = !Settings.dumpdir;
    Keydb.withtxn = !Settings.transactions;
    Keydb.cache_bytes = !Settings.cache_bytes;
    Keydb.pagesize = !Settings.pagesize;
  }

  (* [settings] above is built against the outer [Keydb]; from here on
     the name [Keydb] refers to its [Safe] interface. *)
  module Keydb = Keydb.Safe

  (* Simple server code for handling DB requests.  This is the main control
     code for the DB. *)

  (* Snapshots of settings used by this server. *)
  let withtxn = !Settings.transactions
  let dbdir = !Settings.dbdir

  (* Transactional mode is mandatory: refuse to start without it. *)
  let () =
    match withtxn with
    | true -> ()
    | false ->
        failwith "Running sks_db without transactions is no longer supported."


  (* Listening socket for the HTTP interface on http_address:http_port. *)
  let addr = inet_addr_of_string http_address
  let websock = Eventloop.create_sock (ADDR_INET (addr, http_port))

  (* Remove any existing file at [db_command_name] (presumably a socket
     left over from an earlier run) before creating the command socket. *)
  let () =
    match Sys.file_exists db_command_name with
    | true -> Unix.unlink db_command_name
    | false -> ()
  let comsock = Eventloop.create_sock db_command_addr


  (*********************************************************************)
  (** Database checkpointing and syncing *)

  (** Calls [Keydb.sync], logging before and after. *)
  let sync () =
    ignore (perror "Syncing database");
    Keydb.sync ();
    ignore (perror "Syncing complete")

  (** Calls [Keydb.checkpoint], logging before and after. *)
  let checkpoint () =
    ignore (perror "Checkpointing database");
    Keydb.checkpoint ();
    ignore (perror "Checkpointing complete")

  (* Periods with which [run] schedules [sync] and [checkpoint]. *)
  let sync_interval = !Settings.db_sync_interval
  let checkpoint_interval = !Settings.checkpoint_interval

  (***************************************************************)
  (*  Helper functions for http request handler   ****************)
  (***************************************************************)

  (* Comparators for sorting: [ascending] is the polymorphic [compare],
     [descending] is the same comparison with its arguments swapped. *)
  let ascending a b = compare a b
  let descending a b = ascending b a

  (** [tsort_keys keys] sorts [keys] newest first by the [pk_ctime] parsed
      (with [ParsePGP.parse_pubkey_info]) from each key's first packet.
      Keys whose first packet cannot be parsed, including empty keys
      (where [List.hd] fails), are dropped; [Sys.Break] is re-raised.
      Ties on the time are broken by polymorphic comparison of the keys. *)
  let tsort_keys keys =
    let stamp acc key =
      match
        (try Some (ParsePGP.parse_pubkey_info (List.hd key)) with
         | Sys.Break as e -> raise e
         | _ -> None)
      with
      | Some info -> (info.pk_ctime, key) :: acc
      | None -> acc
    in
    let stamped = List.fold_left ~f:stamp ~init:[] keys in
    List.map ~f:snd (List.sort ~cmp:descending stamped)

  (******************************************************************)

  (** Returns [(log, size)]. [log] is [Keydb.reverse_logquery ~maxsize:20000]
      starting one week before [today], where [today] is the current time
      passed through [Stats.round_up_to_day]. [size] is
      [Keydb.get_num_keys ()]. *)
  let get_stats () =
    let seconds_per_week = 7. *. 24. *. 60. *. 60. in
    let today = Stats.round_up_to_day (Unix.gettimeofday ()) in
    let log =
      Keydb.reverse_logquery ~maxsize:20000 (today -. seconds_per_week)
    in
    let size = Keydb.get_num_keys () in
    (log, size)

  (** Most recently generated statistics page. It starts as the page from
      [Stats.generate_html_stats_page_nostats]. *)
  let last_stat_page = ref (Stats.generate_html_stats_page_nostats ())

  (** Recomputes the statistics page from [get_stats] and stores it in
      [last_stat_page]. Returns [[]] so that it can be scheduled directly
      as an event callback (see [run]). *)
  let calculate_stats_page () =
    signore (plerror 3 "Calculating DB stats");
    let (log, size) = get_stats () in
    let page = Stats.generate_html_stats_page log size in
    last_stat_page := page;
    signore (plerror 3 "Done calculating DB stats");
    []

  (** [get_keys_by_keyid keyid] returns the keys matching the binary
      identifier [keyid]. Its length in bytes selects the kind of match:
      - 4: 32-bit keyid; every candidate is returned;
      - 8: 64-bit keyid, compared with [(Fingerprint.from_key k).keyid];
      - 20: v4 fingerprint, compared with [(Fingerprint.from_key k).fp].

      Candidates come from [Keydb.get_by_short_subkeyid] applied to the
      last 4 bytes of [keyid].
      @raise Failure for 16-byte (v3 fingerprint) identifiers and for any
      other length. The length is checked before slicing [keyid] or
      querying the database, so identifiers shorter than 4 bytes are
      also reported through [Failure]. *)
  let get_keys_by_keyid keyid =
    let keyid_length = String.length keyid in
    (* Only called once the length is known to be at least 4. *)
    let candidates () =
      let short_keyid = String.sub keyid ~pos:(keyid_length - 4) ~len:4 in
      Keydb.get_by_short_subkeyid short_keyid
    in
    match keyid_length with
    | 4 -> (* 32-bit keyid.  No further filtering required. *)
        candidates ()
    | 8 -> (* 64-bit keyid *)
        List.filter (candidates ())
          ~f:(fun key -> (Fingerprint.from_key key).Fingerprint.keyid = keyid)
    | 20 -> (* 160-bit v4 fingerprint *)
        List.filter (candidates ())
          ~f:(fun key -> (Fingerprint.from_key key).Fingerprint.fp = keyid)
    | 16 -> (* 128-bit v3 fingerprint.  Not supported *)
        failwith "128-bit v3 fingerprints not implemented"
    | _ -> failwith "unknown keyid type"
	  

  (** [get_uids keyid] returns the [uids] of the merged key identified by
      [keyid] (see [get_keys_by_keyid]). It returns [[]] unless exactly one
      key matches. *)
  let get_uids keyid =
    match get_keys_by_keyid keyid with
    | [key] -> (KeyMerge.key_to_pkey key).KeyMerge.uids
    | [] | _ :: _ :: _ -> []

  (******************************************************************)
  (******************************************************************)

  (** [check_prefix string prefix] is [true] iff [string] starts with
      [prefix]. The empty prefix matches every string. *)
  let check_prefix string prefix =
    let plen = String.length prefix in
    if String.length string < plen then false
    else String.sub string ~pos:0 ~len:plen = prefix

  (** [lookup_keys search_terms] finds the keys matching a web query.

      If the first term starts with ["0x"], the remainder of that term is
      dehexified and passed to [get_keys_by_keyid]. A [Failure] from that
      lookup becomes [Wserver.Misc_error], and the remaining terms are
      ignored. Otherwise all terms go to [Keydb.get_by_words] (with
      [~max:!Settings.max_matches]) and the result is ordered by
      [tsort_keys].
      @raise Wserver.Misc_error ["No keys found"] when nothing matches. *)
  let lookup_keys search_terms =
    let by_keyid hex =
      let keyid = KeyHash.dehexify hex in
      try get_keys_by_keyid keyid
      with Failure s -> raise (Wserver.Misc_error s)
    in
    let by_words terms =
      tsort_keys (Keydb.get_by_words ~max:!Settings.max_matches terms)
    in
    let found =
      match search_terms with
      | [] -> []
      | first :: _ ->
          if check_prefix first "0x" then
            by_keyid (String.sub first ~pos:2 ~len:(String.length first - 2))
          else by_words search_terms
    in
    if found = [] then raise (Wserver.Misc_error "No keys found")
    else found


  (******************************************************************)

  (** [handle_get_request request] answers a parsed [/pks/lookup] GET
      request and returns [(mimetype, body)]:
      - [Stats]: the cached statistics page;
      - [Get]: the matching keys, ASCII-armored, in an HTML page;
      - [HGet]: the key whose hash is the first search term, armored;
      - [Index] / [VIndex]: an index of the matching keys, either as
        machine-readable text or as an HTML listing. [VIndex] uses the
        verbose per-key format.

      Search terms echoed into page titles are HTML-escaped. The failures
      [Invalid_argument "Insufficiently specific words"] and
      [Invalid_argument "Too many responses"] are turned into
      [Wserver.Misc_error] for [Get], [Index] and [VIndex]. This includes
      failures raised by the key lookup itself.
      @raise Wserver.Misc_error when no key matches, when the query is
      rejected, or when an [HGet] request carries no hash. *)
  let handle_get_request request =
    (* Escape HTML metacharacters so that user-supplied search terms
       echoed back in a page cannot inject markup. *)
    let html_escape s =
      let buf = Buffer.create (String.length s) in
      String.iter s ~f:(function
        | '&' -> Buffer.add_string buf "&amp;"
        | '<' -> Buffer.add_string buf "&lt;"
        | '>' -> Buffer.add_string buf "&gt;"
        | '"' -> Buffer.add_string buf "&quot;"
        | '\'' -> Buffer.add_string buf "&#39;"
        | c -> Buffer.add_char buf c);
      Buffer.contents buf
    in
    (* Run [f], reporting over-broad queries to the client.  The lookup
       must run inside this wrapper for the translation to apply. *)
    let with_query_errors f =
      try f () with
      | Invalid_argument "Insufficiently specific words" ->
          raise (Wserver.Misc_error
                   ("Insufficiently specific words: provide " ^
                    "at least one more specific keyword"))
      | Invalid_argument "Too many responses" ->
          raise (Wserver.Misc_error
                   "Too many responses, unable to process query")
    in
    (* HTML page wrapping ASCII-armored key material. *)
    let armored_page ~title aakeys =
      ("text/html",
       HtmlTemplates.page ~title
         ~body:(sprintf "\r\n<pre>\r\n%s\r\n</pre>\r\n" aakeys))
    in
    let search_string = String.concat ~sep:" " request.search in
    match request.kind with
    | Stats ->
        signore (plerror 4 "/pks/lookup: DB Stats request");
        ("text/html", !last_stat_page)
    | Get ->
        signore (plerror 4 "/pks/lookup: Get request (%s)" search_string);
        let keys = with_query_errors (fun () -> lookup_keys request.search) in
        let keystr = Key.to_string_multiple keys in
        let aakeys = Armor.encode_pubkey_string keystr in
        armored_page
          ~title:(sprintf "Public Key Server -- Get ``%s ''"
                    (html_escape search_string))
          aakeys
    | HGet ->
        let hash_str =
          match request.search with
          | first :: _ -> first
          | [] -> raise (Wserver.Misc_error "No hash specified")
        in
        signore (plerror 4 "/pks/lookup: Hash search: %s" hash_str);
        let hash = KeyHash.dehexify hash_str in
        flush Pervasives.stdout;
        let key =
          try Keydb.get_by_hash hash with Not_found ->
            raise (Wserver.Misc_error "Requested hash not found")
        in
        let keystr = Key.to_string key in
        let aakeys = Armor.encode_pubkey_string keystr in
        armored_page
          ~title:(sprintf "Public Key Server -- Get ``%s ''"
                    (html_escape hash_str))
          aakeys
    | Index | VIndex ->
        signore (plerror 4 "/pks/lookup: Index request: (%s)" search_string);
        with_query_errors (fun () ->
          let keys = lookup_keys request.search in
          if request.machine_readable then
            ("text/plain", MRindex.keys_to_index keys)
          else begin
            (* VIndex differs from Index only in the per-key line format. *)
            let output =
              if request.kind = VIndex then
                List.map ~f:(Index.key_to_lines_verbose ~get_uids request) keys
              else
                List.map ~f:(Index.key_to_lines_normal request) keys
            in
            let pre =
              HtmlTemplates.preformat_list
                (Index.keyinfo_header request :: List.flatten output)
            in
            ("text/html",
             HtmlTemplates.page ~body:pre
               ~title:(sprintf "Search results for '%s'"
                         (html_escape search_string)))
          end)

  (** [string_to_oplist s] splits the request URI [s] (after
      [Wserver.strip]) into its base path and its query parameters. The
      query is split with the regexp [amp], each parameter at ['='], and
      every value is passed through [Wserver.decode]. If any of this raises
      [Not_found] (presumably: no ['?'] in [s], or a parameter without
      ['=']), the stripped string is returned with no parameters. *)
  let string_to_oplist s =
    let s = Wserver.strip s in
    let decode_pair (key, value) = (key, Wserver.decode value) in
    try
      let (base, query) = chsplit '?' s in
      let pairs = List.map ~f:(chsplit '=') (Str.split amp query) in
      (base, List.map ~f:decode_pair pairs)
    with Not_found -> (s, [])

  (** [get_extension s] is the part of [s] from its last ['.'] onwards
      (as cut by [</>]), e.g. [".png"].
      @raise Not_found if [s] contains no ['.']. *)
  let get_extension s = s </> (String.rindex s '.', 0)

  (** Renders a boolean as ["true"] or ["false"]. *)
  let bool_to_string b = string_of_bool b
  (** Writes a human-readable dump of request [r] to [cout]: its kind,
      its [fingerprint] and [exact] flags, and its search terms. *)
  let print_request cout r =
    let kind_name =
      match r.kind with
      | Index -> "index"
      | VIndex -> "vindex"
      | Stats -> "stats"
      | Get -> "get"
      | HGet -> "hashget"
    in
    fprintf cout "   kind: %s\n" kind_name;
    fprintf cout "   fingerprint: %s\n" (bool_to_string r.fingerprint);
    fprintf cout "   exact: %s\n" (bool_to_string r.exact);
    fprintf cout "   search: %s\n" (MList.to_string ~f:(fun x -> x) r.search)

  (** [get_keystrings_from_hashes hashes] fetches the stored key string for
      each hash with [Keydb.get_keystring_by_hash]. The results come back
      in reverse order of [hashes].

      When a lookup fails, the exception is first handed to
      [Eventloop.reraise], following the same policy as the /pks/add web
      handler. If it is not re-raised, it is logged and that hash is
      skipped. *)
  let get_keystrings_from_hashes hashes =
    (* Only the lookup sits inside [try], so the traversal stays
       tail-recursive however many hashes a peer asks for. *)
    List.fold_left hashes ~init:[]
      ~f:(fun keystrings hash ->
            try Keydb.get_keystring_by_hash hash :: keystrings
            with e ->
              Eventloop.reraise e;
              ignore (eplerror 2 e "Error fetching key from hash %s"
                        (KeyHash.hexify hash));
              keystrings)

  (** [read_file ?binary fname] returns the entire contents of [fname].
      The file is opened with [open_in_bin] when [binary] is true
      (default [false]) and with [open_in] otherwise. The channel is
      closed through [protect ~finally], presumably on both success and
      failure.
      @raise Wserver.Page_not_found if [fname] does not exist. *)
  let read_file ?(binary=false) fname =
    if not (Sys.file_exists fname) then raise (Wserver.Page_not_found fname);
    let open_channel = if binary then open_in_bin else open_in in
    let ic = open_channel fname in
    let slurp () =
      let len = in_channel_length ic in
      let contents = String.create len in
      really_input ic contents 0 len;
      contents
    in
    protect ~f:slurp ~finally:(fun () -> close_in ic)


  (** [is_safe c] holds for ASCII letters, ASCII digits and ['.'], the
      only characters accepted in a web file name. *)
  let is_safe = function
    | 'A' .. 'Z' | 'a' .. 'z' | '0' .. '9' | '.' -> true
    | _ -> false

  (** [verify_web_fname fname] is [true] iff every character of [fname]
      satisfies [is_safe]. This is vacuously true for [""]. *)
  let verify_web_fname fname =
    let len = String.length fname in
    let rec all_safe i = i >= len || (is_safe fname.[i] && all_safe (i + 1)) in
    all_safe 0

  (** [convert_web_fname fname] maps a bare file name to its path under
      the [web] subdirectory of [!Settings.basedir].
      @raise Wserver.Misc_error if [verify_web_fname] rejects [fname]. Only
      ASCII letters, digits and ['.'] are accepted, so directory
      separators never reach the path. *)
  let convert_web_fname fname =
    if verify_web_fname fname then
      Filename.concat !Settings.basedir (Filename.concat "web" fname)
    else raise (Wserver.Misc_error "Malformed request")

  (** Handler for HTTP requests.

      GET requests:
      - [/pks/lookup] is answered by [handle_get_request];
      - [/], [""], [/index.html] and [/index.htm] serve [web/index.html];
      - paths ending in [.jpg], [.gif] or [.png] serve that file from the
        [web] directory (via [convert_web_fname]).

      POST requests:
      - [/pks/add] decodes the [keytext] form field, merges each
        ASCII-armored key into the database and reports the outcome as
        HTML;
      - [/pks/hashquery] reads a marshalled list of hashes from the body
        and writes back the marshalled key strings found for them;
      - any other path gets an "Unexpected POST request" page.

      The response body is written to [cout]; the result is its MIME type.
      @raise Wserver.Page_not_found for any other GET path. *)
  let webhandler addr msg cout =
    match msg with
    | Wserver.GET (request, _headers) ->
        ignore (plerror 5 "Get request received");
        let (base, oplist) = string_to_oplist request in
        if base = "/pks/lookup" then begin
          let request = request_of_oplist oplist in
          let (mimetype, body) = handle_get_request request in
          cout#write_string body;
          mimetype
        end else if base = "/index.html" || base = "/index.htm"
                 || base = "/" || base = "" then begin
          let text = read_file (convert_web_fname "index.html") in
          cout#write_string text;
          "text/html"
        end else begin
          try
            (* Only these image types are served; anything else is 404. *)
            let mimetype =
              match get_extension base with
              | ".jpg" -> "image/jpeg"
              | ".gif" -> "image/gif"
              | ".png" -> "image/png"
              | _ -> raise Not_found
            in
            (* [base </> (1, 0)] drops the first character, presumably
               the leading '/', leaving a bare file name. *)
            let fname = convert_web_fname (base </> (1, 0)) in
            let image = read_file ~binary:true fname in
            cout#write_string image;
            mimetype
          with Not_found -> raise (Wserver.Page_not_found base)
        end
    | Wserver.POST (request, _headers, body) ->
        (match Wserver.strip request with
         | "/pks/add" ->
             let keytext = Scanf.sscanf body "keytext=%s" (fun s -> s) in
             let keytext = Wserver.decode keytext in
             let keys = Armor.decode_pubkey keytext in
             ignore (plerror 3 "Handling /pks/add for %d keys"
                       (List.length keys));
             cout#write_string "<html><body>";
             let ctr = ref 0 in
             List.iter keys
               ~f:(fun origkey ->
                     try
                       let key = Fixkey.canonicalize origkey in
                       let hash = KeyHash.hexify (KeyHash.hash key) in
                       Keydb.add_key_merge ~newkey:true key;
                       incr ctr;
                       (* Logged only once the merge has succeeded. *)
                       signore (plerror 3 "/pks/add: key %s added to database"
                                  hash)
                     with
                     | Fixkey.Bad_key | KeyMerge.Unparseable_packet_sequence ->
                         cout#write_string
                           ("Add failed: Malformed Key --- unexpected packet " ^
                            "type and/or order of packets<br>");
                         signore
                           (plerror 2 "key %s %s"
                              (KeyHash.hexify (KeyHash.hash origkey))
                              "could not be parsed by KeyMerge.canonicalize")
                     | Bdb.Key_exists as e ->
                         cout#write_string
                           ("Add failed: identical key already " ^
                            "exists in database<br>");
                         signore (eperror e "Key add failed")
                     | e ->
                         Eventloop.reraise e;
                         cout#write_string "Add failed<br>";
                         signore (eperror e "Key add failed"));
             if !ctr > 0 then begin
               cout#write_string
                 ("Key block added to key server database.\n  " ^
                  "New public keys added: <br>");
               cout#write_string
                 (sprintf "%d keys added successfully.<br>" !ctr)
             end;
             (* Close the tags in the reverse order they were opened. *)
             cout#write_string "</body></html>";
             "text/html"
         | "/pks/hashquery" ->
             ignore (plerror 4 "Handling /pks/hashquery");
             let sin = new Channel.string_in_channel body 0 in
             let hashes =
               CMarshal.unmarshal_list ~f:CMarshal.unmarshal_string sin
             in
             let keystrings = get_keystrings_from_hashes hashes in
             signore (perror "%d keys found" (List.length keystrings));
             CMarshal.marshal_list ~f:CMarshal.marshal_string cout keystrings;
             "pgp/keys" (* This is a bogus content-type *)
         | _ ->
             cout#write_string (HtmlTemplates.page
                                  ~title:"Unexpected POST request"
                                  ~body:"");
             "text/html")


  (** Adapts [handle] for use with the event loop. The system channels
      [cin]/[cout] are wrapped as [Channel.sys_in_channel] /
      [Channel.sys_out_channel] objects before [handle] runs. Afterwards
      [[]] is returned in place of [handle]'s unit result, matching the
      other callbacks registered in [run]. *)
  let eventify_handler handle addr cin cout =
    let cin = new Channel.sys_in_channel cin in
    let cout = new Channel.sys_out_channel cout in
    handle addr cin cout;
    []

  (** Returns the database's "filters" meta entry split with [comma_rxp],
      or [[]] if the entry is absent ([Not_found]). The computation is
      wrapped by [Utils.unit_memoize]. *)
  let get_filters =
    let read_filters () =
      match (try Some (Keydb.get_meta "filters") with Not_found -> None) with
      | Some s -> Str.split comma_rxp s
      | None -> []
    in
    Utils.unit_memoize read_filters


  (** Handler for commands coming off of the db_command_addr.

      Unmarshals one message from [cin] and acts on it:
      - [LogQuery (count, timestamp)]: replies [LogResp] with
        [Keydb.logquery ~maxsize:count timestamp];
      - [WordQuery words]: replies [Keys] with the result of
        [Keydb.get_by_words];
      - [Keys keys]: canonicalizes the keys, silently dropping those that
        [Fixkey.canonicalize] rejects. It replies [Ack 0] first and only
        then merges the keys. Merge failures are logged, not reported to
        the peer;
      - [DeleteKey hash]: deletes the key with that (truncated) hash and
        replies [Ack 0]. On failure it replies [Ack (-1)] and re-raises;
      - [HashRequest hashes]: replies [Keys] with the keys found, logging
        each hash that is missing;
      - [Config ("checkpoint", `none)]: checkpoints, with no reply.
        [Config ("filters", `none)] replies [Filters]. Any other [Config]
        is logged, with no reply;
      - anything else: replies [ProtocolError]. *)
  let command_handler addr cin cout =
    match (unmarshal cin).msg with
    | LogQuery (count, timestamp) ->
        let logresp = Keydb.logquery ~maxsize:count timestamp in
        let length = List.length logresp in
        if length > 0 then
          ignore (plerror 3 "Sending LogResp size %d" length);
        marshal cout (LogResp logresp)

    | WordQuery words ->
        ignore (plerror 3 "Handling WordQuery");
        let keys = Keydb.get_by_words ~max:!Settings.max_matches words in
        marshal cout (Keys keys)

    | Keys keys ->
        let keys =
          List.fold_left ~init:[] keys
            ~f:(fun list key ->
                  try Fixkey.canonicalize key :: list
                  with KeyMerge.Unparseable_packet_sequence | Fixkey.Bad_key ->
                    list)
        in
        marshal cout (Ack 0);
        (try Keydb.add_keys_merge keys
         with e ->
           (* As in the /pks/add web handler, let Eventloop.reraise
              propagate the exceptions it must not lose before logging. *)
           Eventloop.reraise e;
           ignore (eplerror 2 e "Key addition failed"))

    | DeleteKey hash ->
        ignore (plerror 3 "Handling DeleteKey");
        (try
           let hash = RMisc.truncate hash KeyHash.hash_bytes in
           let key = Keydb.get_by_hash hash in
           Keydb.delete_key ~hash key;
           marshal cout (Ack 0)
         with e ->
           marshal cout (Ack (-1));
           raise e)

    | HashRequest hashes ->
        ignore (plerror 3 "Handling HashRequest");
        let keys =
          List.fold_left hashes ~init:[]
            ~f:(fun list hash ->
                  try Keydb.get_by_hash hash :: list
                  with Not_found ->
                    ignore (plerror 2 "Requested key %s not found"
                              (Utils.hexstring hash));
                    list)
        in
        ignore (plerror 3 "Returning set of %d keys" (List.length keys));
        marshal cout (Keys keys)

    | Config (s, cvar) ->
        ignore (plerror 4 "Received config message");
        (match (s, cvar) with
         | ("checkpoint", `none) -> checkpoint ()
         | ("filters", `none) -> marshal cout (Filters (get_filters ()))
         | (str, _) -> ignore (perror "Unexpected config request <%s>" str))

    | m ->
        marshal cout ProtocolError;
        ignore (perror "Unexpected (%s) message" (msg_to_string m))


  (***********************************************************************)

  (** Dequeues and transmits a single key.  Returns true if there
      might be more keys to be handled.

      Within one transaction:
      - dequeue a key; if [Keydb.dequeue_key] raises [Not_found], abort
        and return [false];
      - mail the key ASCII-armored to [Membership.get_mailsync_partners];
      - commit.
      Any exception aborts the transaction and is re-raised, so the
      dequeue is presumably rolled back and the key retried later. *)
  let transmit_single_key () =
    let txn = Keydb.txn_begin () in
    try
      begin
        match (try Some (Keydb.dequeue_key ~txn) with Not_found -> None) with
        | Some (_time, key) ->
            let body = Armor.encode_pubkey key in
            let to_header =
              ("To", String.concat ~sep:", "
                       (Membership.get_mailsync_partners ()))
            in
            let msg = { Sendmail.headers =
                          [ to_header;
                            "From", Settings.get_from_addr ();
                            "Reply-To", Settings.get_from_addr ();
                            "Errors-To", Settings.get_from_addr ();
                            "Subject", "incremental";
                            "Precedence", "list";
                            "Content-type", "application/pgp-keys";
                            "X-KeyServer-Sent", !Settings.hostname;
                          ];
                        Sendmail.body = body;
                      }
            in
            let key_hash = KeyHash.hexify (KeyHash.hash key) in
            let string = Sendmail.msg_to_string msg in
            signore (plerror 6 "%s" string);
            Sendmail.send msg;
            (* Report the transmission only once Sendmail.send returned. *)
            signore (plerror 3 "Message transmitted for key %s" key_hash);
            Keydb.txn_commit txn;
            signore (plerror 5 "transmission queue transaction committed");
            true
        | None ->
            (* nothing was done, so committing and aborting are the same *)
            Keydb.txn_abort txn;
            false
      end
    with e ->
      Keydb.txn_abort txn;
      raise e
	      

  (** Transmit all enqueued keys to other hosts. Calls
      [transmit_single_key] until it reports that nothing is left, then
      returns [[]] for use as an event callback. *)
  let transmit_keys () =
    let rec drain () = if transmit_single_key () then drain () in
    drain ();
    []

  (***********************************************************************)

  (** Main server loop.  Opens the databases with [settings]. It builds the
      statistics page up front when [!Settings.initial_stat] is set and
      logs the applied filters. It then calls [Eventloop.evloop] with two
      lists: the timed events (first argument) and the listening sockets
      paired with their handlers (second argument). *)
  let run () = 
    Keydb.open_dbs settings;
    if !Settings.initial_stat then ignore (calculate_stats_page ());
    ignore (plerror 2 "Database opened");
    ignore (plerror 0 "Applied filters: %s" (String.concat ~sep:", " 
					       (get_filters ())));
    Eventloop.evloop

      (
	(* Database maintenance.  [withtxn] is always true by now (start-up
	   fails otherwise), so the [sync] branch is not taken in practice. *)
	(if withtxn 
	 then (Ehandlers.repeat_forever_simple checkpoint_interval checkpoint)
	 else (Ehandlers.repeat_forever_simple sync_interval sync)) 
	@
	(* Membership.reset_membership_time every
	   !Settings.membership_reload_time. *)
	Ehandlers.repeat_forever_simple !Settings.membership_reload_time
	 Membership.reset_membership_time
	@
	(* Mail out queued keys, repeated with period 10. *)
	(Ehandlers.repeat_forever 10. 
	   (Eventloop.make_tc ~cb:transmit_keys ~timeout:0
	      ~name:"mail transmit keys" )
	)
	@
	(* Load keys received by mail, repeated with period 10. *)
	(Ehandlers.repeat_forever 10. 
	   (Eventloop.make_tc ~name:"mailsync" ~timeout:0
	      ~cb:(Mailsync.load_mailed_keys 
		     ~addkey:(Keydb.add_key_merge ~newkey:false)))
	)
	@
	(* Recompute the statistics page at !Settings.stat_calc_hour. *)
	(Ehandlers.repeat_at_hour !Settings.stat_calc_hour
	   calculate_stats_page)
      )

      (
	(* HTTP interface on the configured port. *)
	(websock, Eventloop.make_th ~name:"webserver" 
	  ~timeout:!Settings.wserver_timeout
	  ~cb:(Wserver.accept_connection webhandler ~recover_timeout:1))
	::
	 (* Command socket at db_command_addr. *)
	 (comsock, Eventloop.make_th ~name:"command handler" 
	    ~timeout:!Settings.command_timeout
	    ~cb:(eventify_handler command_handler))
	::
	 (* Optional second HTTP listener on port 80; its socket is only
	    created when !Settings.use_port_80 is set. *)
	 (if !Settings.use_port_80 then
	    let sock = Eventloop.create_sock (ADDR_INET (addr,80)) in
	    (sock,Eventloop.make_th ~name:"webserver80" 
	       ~timeout:!Settings.wserver_timeout
	       ~cb:(Wserver.accept_connection webhandler ~recover_timeout:1)
	    )::[]
	  else
	    []
	 )
      )



  (** Entry point: runs the server loop defined above under [protect].
      Its [~finally] shuts the database down:
      - disable break catching;
      - [Keydb.sync];
      - [Keydb.unconditional_checkpoint];
      - [Keydb.close_dbs].
      Each step is logged. [protect] presumably runs the shutdown whether
      the loop returns or raises. *)
  let run () =
    let shutdown () =
      set_catch_break false;
      ignore (plerror 0 "Shutting down database");
      Keydb.sync ();
      ignore (plerror 0 "Database sync'd");
      Keydb.unconditional_checkpoint ();
      ignore (plerror 0 "Database checkpointed");
      Keydb.close_dbs ();
      ignore (plerror 0 "Database closed")
    in
    protect ~f:run ~finally:shutdown
end
