open Hammer_errors

(******************************************************************************)

(* [is_alpha c] holds exactly for the ASCII letters and the underscore, i.e.
   the characters [is_good_dep] accepts as the first character of a name. *)
let is_alpha c =
  match c with
  | 'a'..'z' | 'A'..'Z' | '_' -> true
  | _ -> false

(* [is_good_dep s] decides whether the ATP-reported name [s] is kept as a
   dependency: it must be non-empty, start with a letter or underscore, and
   must not be one of the internal names prefixed with "_HAMMER_".
   The emptiness test guards the access to position 0, which would otherwise
   raise [Invalid_argument] on "". *)
let is_good_dep s =
  s <> "" && is_alpha s.[0] && not (Hhlib.string_begins_with s "_HAMMER_")

(* [get_deps names] keeps the names accepted by [is_good_dep], in their
   original relative order. *)
let get_deps lst =
  lst |> List.filter (fun name -> is_good_dep name)

(* [get_defs names] selects the names carrying the "$_def_" prefix, strips
   that prefix, and keeps those of the stripped names that [is_good_dep]
   accepts. Relative order is preserved. *)
let get_defs lst =
  let prefix = "$_def_" in
  let plen = String.length prefix in
  lst
  |> List.filter (fun s -> Hhlib.string_begins_with s prefix)
  |> List.map (fun s -> String.sub s plen (String.length s - plen))
  |> List.filter is_good_dep

(******************************************************************************)

(* [call_eprover infile outfile] runs E on the TPTP problem [infile] with a
   CPU limit of [!Opt.atp_timelimit] seconds. Its stderr is discarded, and
   only the stdout lines matching "file('" or "# SZS" are written to
   [outfile].
   Returns [true] iff the shell pipeline exits with status 0 and [outfile]
   contains "SZS status Theorem".
   Both file names are passed through [Filename.quote], so paths containing
   spaces or shell metacharacters reach the commands verbatim. The original
   left [outfile] unquoted in the final grep. *)
let call_eprover infile outfile =
  let tmt = string_of_int !Opt.atp_timelimit in
  let cmd =
    "eprover -s --cpu-limit=" ^ tmt ^
    " --auto-schedule -R --print-statistics -p --tstp-format " ^
    Filename.quote infile ^
    " 2>/dev/null | grep \"file[(]'\\|# SZS\" > " ^ Filename.quote outfile
  in
  if !Opt.debug_mode then
    Msg.info cmd;
  Sys.command cmd = 0
  && Sys.command ("grep -q -s \"SZS status Theorem\" " ^ Filename.quote outfile) = 0

(* [extract_eprover_data outfile] parses the filtered E output written by
   [call_eprover] and returns [(get_deps names, get_defs names)], where
   [names] are the names found on the axiom lines.

   Line handling:
   - Lines starting with '#' are skipped.
   - A line contributes a name only when the five characters starting two
     positions after its first comma are "axiom".
   - The name is the text between the character two positions after the last
     comma and the last single quote, with OCaml escapes resolved by
     [Scanf.unescaped].
   - Lines that do not fit this shape ([Not_found] / [Invalid_argument]
     while slicing) are ignored.
   - [names] is accumulated in reverse file order.

   The channel is closed even if an unexpected exception escapes the reading
   loop. The loop is tail-recursive, so large outputs do not grow the stack.
   @raise HammerError if the file cannot be read or processed. *)
let extract_eprover_data outfile =
  (* [Some name] for an axiom line, [None] for lines to skip; may raise
     [Not_found] / [Invalid_argument] on malformed lines. *)
  let parse_line ln =
    if String.get ln 0 = '#' then
      None
    else if String.sub ln ((String.index ln ',') + 2) 5 = "axiom" then
      let i = String.rindex ln ',' + 2 in
      let j = String.rindex ln '\'' in
      Some (Scanf.unescaped (String.sub ln (i + 1) (j - i - 1)))
    else
      None
  in
  try
    let ic = open_in outfile in
    (* The recursive call is outside every [try], hence a real tail call. *)
    let rec pom acc =
      match (try Some (input_line ic) with End_of_file -> None) with
      | None -> acc
      | Some ln ->
        let parsed =
          try parse_line ln with Not_found | Invalid_argument _ -> None
        in
        (match parsed with
         | Some name -> pom (name :: acc)
         | None -> pom acc)
    in
    let names =
      try pom [] with e -> close_in_noerr ic; raise e
    in
    close_in ic;
    (get_deps names, get_defs names)
  with _ ->
    raise (HammerError "Failed to extract EProver data")

(* [call_z3 infile outfile] runs Z3 in TPTP mode on [infile], with unsat-core
   display enabled. It is bounded both by the external [timeout] command and
   by Z3's own [-T:] limit, each set to [!Opt.atp_timelimit] seconds. Stdout
   goes to [outfile] and stderr is discarded.
   Returns [true] iff the command exits with status 0 and [outfile] contains
   "SZS status Theorem".
   File names are shell-quoted with [Filename.quote]; the original passed
   them unquoted. *)
let call_z3 infile outfile =
  let tmt = string_of_int !Opt.atp_timelimit in
  let cmd =
    "timeout " ^ tmt ^
    " z3 -tptp DISPLAY_UNSAT_CORE=true ELIM_QUANTIFIERS=true PULL_NESTED_QUANTIFIERS=true -T:" ^
    tmt ^ " " ^ Filename.quote infile ^ " 2>/dev/null > " ^ Filename.quote outfile
  in
  if !Opt.debug_mode then
    Msg.info cmd;
  Sys.command cmd = 0
  && Sys.command ("grep -q -s \"SZS status Theorem\" " ^ Filename.quote outfile) = 0

(* [extract_z3_data outfile] reads the Z3 output written by [call_z3] and
   returns [(get_deps names, get_defs names)].

   Parsing:
   - The first line is skipped.
   - The second line must contain a bracketed list. Judging by the slicing
     below, it apparently has the shape ['n1', 'n2', ...].
   - The text strictly between "['" and "']" is split on "', '", and each
     piece is unescaped with [Scanf.unescaped].

   The channel is closed even when parsing fails (the original leaked it).
   @raise HammerError if the file cannot be read or has an unexpected shape. *)
let extract_z3_data outfile =
  try
    let ic = open_in outfile in
    let names =
      try
        ignore (input_line ic);
        let ln = input_line ic in
        let i = String.index ln '[' in
        let j = String.rindex ln ']' in
        let s = String.sub ln (i + 2) (j - i - 3) in
        List.map Scanf.unescaped (Str.split (Str.regexp "', '") s)
      with e ->
        close_in_noerr ic;
        raise e
    in
    close_in ic;
    (get_deps names, get_defs names)
  with _ ->
    raise (HammerError "Failed to extract Z3 data")

(* [call_vampire infile outfile] runs Vampire in CASC mode on [infile], with
   TPTP proof output and axiom names. It is bounded by both the external
   [timeout] and Vampire's [-t], each set to [!Opt.atp_timelimit] seconds.
   Only the stdout lines matching "file('" or "% SZS" are kept in [outfile].
   Returns [true] iff the pipeline exits with status 0 and [outfile] contains
   "SZS status Theorem".

   Changes from the original:
   - The grep alternation is written "\\|", which yields the same runtime
     string. The original "\|" is an illegal OCaml escape (warning 14).
   - File names are shell-quoted with [Filename.quote]. *)
let call_vampire infile outfile =
  let tmt = string_of_int !Opt.atp_timelimit in
  let cmd =
    "timeout " ^ tmt ^ " vampire --mode casc -t " ^ tmt ^
    " --proof tptp --output_axiom_names on " ^ Filename.quote infile ^
    " | grep \"file[(]'\\|% SZS\" > " ^ Filename.quote outfile
  in
  if !Opt.debug_mode then
    Msg.info cmd;
  Sys.command cmd = 0
  && Sys.command ("grep -q -s \"SZS status Theorem\" " ^ Filename.quote outfile) = 0

(* [extract_vampire_data outfile] parses the filtered Vampire output written
   by [call_vampire] and returns [(get_deps names, get_defs names)].

   Line handling:
   - Lines starting with '%' are skipped.
   - For every other line, the name is the text from two characters after the
     last comma up to the last single quote, unescaped with [Scanf.unescaped].
   - The name "HAMMER_GOAL" is dropped.
   - Malformed lines ([Not_found] / [Invalid_argument]) are ignored.
   - [names] is accumulated in reverse file order.

   The channel is closed even if an unexpected exception escapes the loop, and
   the loop is tail-recursive.
   @raise HammerError if the file cannot be read or processed. *)
let extract_vampire_data outfile =
  (* [Some name] for a kept name, [None] for lines to skip; may raise
     [Not_found] / [Invalid_argument] on malformed lines. *)
  let parse_line ln =
    if String.get ln 0 = '%' then
      None
    else
      let i = String.rindex ln ',' + 1 in
      let j = String.rindex ln '\'' in
      let name = Scanf.unescaped (String.sub ln (i + 1) (j - i - 1)) in
      if name <> "HAMMER_GOAL" then Some name else None
  in
  try
    let ic = open_in outfile in
    (* The recursive call is outside every [try], hence a real tail call. *)
    let rec pom acc =
      match (try Some (input_line ic) with End_of_file -> None) with
      | None -> acc
      | Some ln ->
        let parsed =
          try parse_line ln with Not_found | Invalid_argument _ -> None
        in
        (match parsed with
         | Some name -> pom (name :: acc)
         | None -> pom acc)
    in
    let names =
      try pom [] with e -> close_in_noerr ic; raise e
    in
    close_in ic;
    (get_deps names, get_defs names)
  with _ ->
    raise (HammerError "Failed to extract Vampire data")

(******************************************************************************)

(* The ATPs tried by [call_provers], in this order. Each entry is
   (enable flag, display name, runner, output extractor). *)
let provers =
  let vampire = (Opt.vampire_enabled, "Vampire", call_vampire, extract_vampire_data) in
  let z3 = (Opt.z3_enabled, "Z3", call_z3, extract_z3_data) in
  let eprover = (Opt.eprover_enabled, "EProver", call_eprover, extract_eprover_data) in
  [vampire; z3; eprover]

(* [call_provers fname ofname] tries the enabled entries of [provers], in
   order, on the TPTP problem [fname], using [ofname] as the output file.

   It returns [(prover_name, (deps, defs))] for the first prover that meets
   all of these conditions:
   - it reports success;
   - its output can be parsed;
   - its dependency list has at most [!Opt.max_atp_predictions] entries.

   If an extractor raises [HammerError], the message is logged and the search
   continues with the next prover. Previously one unparsable output aborted
   the whole search.
   @raise HammerFailure if no enabled prover yields an acceptable answer. *)
let call_provers fname ofname =
  let rec pom lst =
    match lst with
    | [] -> raise (HammerFailure "ATPs failed to find a proof")
    | (enabled, pname, call, extract) :: t when !enabled ->
      Msg.info ("Running " ^ pname ^ "...");
      if call fname ofname then
        begin
          let parsed =
            try Some (extract ofname) with
            | HammerError msg ->
              Msg.info (pname ^ " output could not be parsed: " ^ msg);
              None
          in
          match parsed with
          | None -> pom t
          | Some (deps, defs) ->
            let n = List.length deps in
            if n <= !Opt.max_atp_predictions then
              (pname, (deps, defs))
            else
              begin
                Msg.info (pname ^ " returned too many predictions (" ^ string_of_int n ^ ")");
                pom t
              end
        end
      else
        begin
          Msg.info (pname ^ " failed");
          pom t
        end
    | _ :: t ->
      pom t
  in
  pom provers

(******************************************************************************)
(* Main functions *)

(* [write_atp_file fname deps1 hyps deps goal] re-translates the problem built
   from [goal], [hyps] and [deps] and writes it to the file [fname] through
   [Coq.write_problem].
   The problem is written under the goal's name, with the names of
   [hyps @ deps1] as its dependencies.
   NOTE(review): [Coq.reinit] receives [goal :: hyps @ deps] while only the
   names from [deps1] are passed on. Presumably [deps1] is the subset of
   [deps] to include in the problem; confirm against callers. *)
let write_atp_file fname deps1 hyps deps goal =
  let name = Hh_term.get_hhdef_name goal in
  let depnames = List.map Hh_term.get_hhdef_name (hyps @ deps1) in
  (* Sets a global flag in [Coq]; it is not restored by this function. *)
  Coq.opt_compute_dependencies := false;
  (* Remove any existing definitions named like the goal and the hypotheses
     before reinitialising. Presumably this discards stale entries from an
     earlier run — TODO confirm. *)
  Coq.remove_def name;
  List.iter (fun d -> Coq.remove_def (Hh_term.get_hhdef_name d)) hyps;
  Coq.reinit (goal :: hyps @ deps);
  Msg.info ("Translating the problem to FOL...");
  Coq.retranslate (name :: depnames);
  if !Opt.debug_mode then
    Msg.info ("Writing translated problem to file '" ^ fname ^ "'...");
  Coq.write_problem fname name depnames

(* [predict deps1 hyps deps goal] proceeds in three steps:
   1. Writes the problem to a fresh temporary ".p" file via [write_atp_file].
   2. Runs the enabled ATPs on it via [call_provers].
   3. Logs the winning prover with the dependencies and definitions it
      reported (names printed with any "Top." prefix dropped), and returns
      [(deps, defs)].

   Unless [!Opt.debug_mode] is set, the problem file and its ".out" companion
   are removed afterwards. This also happens when translation or every prover
   fails: previously a translation failure leaked the temp file.
   Removing a file that was never created is not an error. Previously that
   [Sys_error] could mask the real [HammerFailure], e.g. when no prover was
   enabled.
   Any exception from translation or [call_provers] is re-raised. *)
let predict deps1 hyps deps goal =
  let prn_lst lst =
    String.concat ", " (List.map (fun x -> Hhlib.drop_prefix x "Top.") lst)
  in
  let fname = Filename.temp_file "coqhammer" ".p" in
  let ofname = fname ^ ".out" in
  let remove_quietly f = try Sys.remove f with Sys_error _ -> () in
  let clean () =
    if not !Opt.debug_mode then
      begin
        remove_quietly fname;
        remove_quietly ofname
      end
  in
  try
    write_atp_file fname deps1 hyps deps goal;
    let (pname, (deps, defs)) = call_provers fname ofname in
    Msg.info (pname ^ " succeeded\n - dependencies: " ^ prn_lst deps ^
                "\n - definitions: " ^ prn_lst defs);
    clean ();
    (deps, defs)
  with e ->
    clean ();
    raise e
