structure hfCheck :> hfCheck =
struct

open HolKernel boolLib aiLib hfLib hfProofTerm hfProofRule hfDep hfTheory

(* Errors raised by this structure are tagged with origin "hfCheck". *)
fun ERR fname msg = mk_HOL_ERR "hfCheck" fname msg

(* A pfd maps the sequent (hypotheses, conclusion) of a theorem to its
   proof term. Keys are computed with ben_thm; the original note states
   that they are in beta-eta normal form. *)
type pfd = (term list * term, pf) Redblackmap.dict

local
  val sequent_compare = cpl_compare (list_compare Term.compare) Term.compare
in
  val empty_pfd = dempty sequent_compare
end

(* Look up the proof recorded for th (raises NotFound when absent). *)
fun dfind_pfd th d = dfind (ben_thm th) d
(* Is a proof of th recorded in d?  Any exception raised during the lookup
   (including one from ben_thm) is treated as "no". *)
fun dmem_pfd th d = can (dfind_pfd th) d
(* Record pf as the proof of th in d. *)
fun dadd_pfd th pf d = dadd (ben_thm th) pf d

(* ------------------------------------------------------------------------
   Checking that the result of simulating a rule corresponds to
   what was recorded.
   ------------------------------------------------------------------------ *)

(*
val thmd = HOLset.fromList goal_compare
  (map (dest_thm o snd) 
     (List.concat (map DB.thms (ancestry (current_theory ())))))
*)

(* Integer-indexed table of extracted lemmas. It is only written by the
   commented-out lemma-extraction code in assign_pf. *)
val lemmad = ref (dempty Int.compare)

(* Check that the recorded theorem th1 has exactly the sequent (asl,w)
   obtained by simulating its rule, and return the proof term pf.
   asl is a set of hypotheses; it is turned into a sorted list before being
   compared with dest_thm th1 through goal_compare.
   On mismatch, both sequents are printed and HOL_ERR is raised. *)
fun assign_pf th1 (pf,asl,w) =
  let val asl' = HOLset.listItems asl in
    if goal_compare (dest_thm th1, (asl',w)) <> EQUAL
    then
      (print_endline (thm_to_string th1);
       (* mk_thm is only used here to pretty-print the simulated sequent *)
       print_endline (thm_to_string (mk_thm (asl',w)));
       raise ERR "assign_pf" "recorded theorem differs from simulated sequent")
    else pf
      (*
      let val pfn = 
        ((dfind th1 freqd - 1) handle NotFound => 0) * pf_size pf 
      in
        if pfn > 64 then
          (
          print_endline ("lemma:" ^ thm_to_string th1);
          if null (hyp th1) andalso 
             null (free_vars_lr (concl th1)) andalso
             null (type_vars_in_term (concl th1))
          then 
            let val name = dlength (!lemmad) in
              print_endline ("yes: " ^ its pfn);
              lemmad := dadd name pf (!lemmad); 
              KnO (its name)
            end
          else (print_endline ("no: " ^ its pfn); pf)
          )
        else pf
      end
      *)
  end

(* ------------------------------------------------------------------------
   Simulating HOL4 rules using Proofgold rules
   ------------------------------------------------------------------------ *)

(* Dictionary from the sequent of every theorem stored in theory "hf" to the
   name under which it is stored. *)
val axiomd =
  dnew goal_compare
    (map (fn (name, th) => (dest_thm th, name)) (DB.thms "hf"))

(* Set to true by simul_rule when a step could not be translated and a
   placeholder sequent was returned. check_log_aux then takes the proof
   component without comparing it against the recorded theorem. *)
val notest2 = ref false

(* Translate one recorded HOL4 rule application into a proof step built with
   the pf* rules.
   Each premise theorem becomes a sequent (proof, hypothesis set, conclusion).
   Its proof is looked up in pfd; a premise with no recorded proof is
   reported as "missing: ..." and replaced by the placeholder KnO "D".
   Axioms other than ETA_AX, definitions and oracles are not translated.
   They are reported on stdout, notest2 is set, and the placeholder sequent
   (KnO "D", empty hypothesis set, T) is returned. *)
fun simul_rule pfd rule =
  let
    val dummy_seq = (KnO "D", HOLset.empty Term.compare, T)
    fun unsupported msg =
      (print_endline msg; notest2 := true; dummy_seq)
    fun lookup th =
      let
        val proof = dfind_pfd th pfd
          handle NotFound =>
            (print_endline ("missing: " ^ thm_to_string th); KnO "D")
      in
        (proof, hypset th, concl th)
      end
  in
    case rule of
      pABS (tm, th) => pfABS tm (lookup th)
    | pALPHA (tm1, tm2) => pfALPHA tm1 tm2
    | pAP_TERM (tm, th) => pfAP_TERM tm (lookup th)
    | pAP_THM (th, tm) => pfAP_THM (lookup th) tm
    | pASSUME tm => pfASSUME tm
    | pBETA_CONV tm => pfBETA_CONV tm
    | pBeta th => pfBeta (lookup th)
    | pCCONTR (tm, th) => pfCCONTR tm (lookup th)
    | pCHOOSE ((tm, th1), th2) => pfCHOOSE (tm, lookup th1) (lookup th2)
    | pCONJ (th1, th2) => pfCONJ (lookup th1) (lookup th2)
    | pCONJUNCT1 th => pfCONJUNCT1 (lookup th)
    | pCONJUNCT2 th => pfCONJUNCT2 (lookup th)
    | pDISCH (tm, th) => pfDISCH tm (lookup th)
    | pDISJ1 (th, tm) => pfDISJ1 (lookup th) tm
    | pDISJ2 (tm, th) => pfDISJ2 tm (lookup th)
    | pDISJ_CASES (th1, th2, th3) =>
      pfDISJ_CASES (lookup th1) (lookup th2) (lookup th3)
    | pEQ_IMP_RULE_LEFT th => pfEQ_IMP_RULE_LEFT (lookup th)
    | pEQ_IMP_RULE_RIGHT th => pfEQ_IMP_RULE_RIGHT (lookup th)
    | pEQ_MP (th1, th2) => pfEQ_MP (lookup th1) (lookup th2)
    | pEXISTS ((v, tm), th) => pfEXISTS (v, tm) (lookup th)
    | pGEN (tm, th) => pfGEN tm (lookup th)
    | pGENL (tml, th) => pfGENL tml (lookup th)
    | pGEN_ABS (tmo, vl, th) => pfGEN_ABS tmo vl (lookup th)
    | pINST (tmsubst, th) => pfINST tmsubst (lookup th)
    | pINST_TYPE (tysubst, th) => pfINST_TYPE tysubst (lookup th)
    | pMK_COMB (th1, th2) => pfMK_COMB (lookup th1) (lookup th2)
    | pMP (th1, th2) => pfMP (lookup th1) (lookup th2)
    | pMk_abs th => lookup th
    | pMk_comb th => lookup th
    | pNOT_ELIM th => pfNOT_ELIM (lookup th)
    | pNOT_INTRO th => pfNOT_INTRO (lookup th)
    | pSPEC (tm, th) => pfSPEC tm (lookup th)
    | pSUBST (oldsubst, template, th) =>
      let
        (* only substitution entries whose variable occurs in the template
           are kept *)
        val tmpl_fvs = Term.FVL [template] Term.empty_varset
        fun relevant {redex, residue = _} = HOLset.member (tmpl_fvs, redex)
        fun lift {redex, residue} = {redex = redex, residue = lookup residue}
      in
        pfSUBST (map lift (filter relevant oldsubst)) template (lookup th)
      end
    | pSYM th => pfSYM (lookup th)
    | pSpecialize (tm, th) => pfSpecialize tm (lookup th)
    | pTRANS (th1, th2) => pfTRANS (lookup th1) (lookup th2)
    | pmk_axiom_thm tm =>
      (if term_eq tm (concl ETA_AX)
       then pfETA_AX
       else unsupported ("mk_axiom: " ^ term_to_string tm))
    | pmk_defn_thm tm =>
      unsupported ("mk_defn: " ^ term_to_string tm)
    | pmk_oracle_thm (tml, tm) =>
      unsupported ("mk_oracle: " ^ thm_to_string (mk_thm (tml, tm)))
    | prefl_nocheck (_, tm) => pfREFL tm
  end

(* ------------------------------------------------------------------------
   Verify a rule using shortcut theorems (disabled: the code below is
   commented out)
   ------------------------------------------------------------------------ *)

(*
       fun try_match_thm thm tm =
           let 
             val subst = match_term (snd (strip_forall (concl thm))) tm 
             val _ = print_endline (
               String.concatWith " " 
                 (map (term_to_string o #residue) (fst subst)))
           in
             #1 (pfSPECL (map #residue (fst subst)) (find thm))
           end
        fun try_match_thml thml tm = case thml of
          [] => raise ERR "try_match_thml" ""
        | thm :: m => 
          (try_match_thm thm tm handle HOL_ERR _ => try_match_thml m tm)

fun fast_simul_rule pfd rule = 
  let 
    fun find x = 
      let 
        val goal = dest_thm x
        val pf = dfind goal pfd
          handle NotFound => 
          (
          if HOLset.member (thmd,goal)
          then (print_endline ("database: " ^ thm_to_string x); 
                KnO "database") 
          else (print_endline ("missing: " ^ thm_to_string x); KnO "missing") 
          ) 
     in
       (pf, hypset x, concl x)
      end
  in
    case rule of
      pABS (term,thm) => pfABS term (find thm)
    | pALPHA (t1,t2) => shALPHA t1 t2
    | pAP_TERM (term,thm) => shAP_TERM term (find thm)
    | pAP_THM (thm,term) => shAP_THM (find thm) term
    | pASSUME term => pfASSUME term
    | pBETA_CONV term => shBETA_CONV term
    | pBeta thm => pfBeta (find thm)
    | pCCONTR (term,thm) => pfCCONTR term (find thm)
    | pCHOOSE ((term,th1),th2) => pfCHOOSE (term, find th1) (find th2)
    | pCONJ (th1,th2) => pfCONJ (find th1) (find th2)
    | pCONJUNCT1 thm => pfCONJUNCT1 (find thm)
    | pCONJUNCT2 thm => pfCONJUNCT2 (find thm)
    | pDISCH (term,thm) => pfDISCH term (find thm)
    | pDISJ1 (thm,term) => pfDISJ1 (find thm) term
    | pDISJ2 (term,thm) => pfDISJ2 term (find thm)
    | pDISJ_CASES (th1,th2,th3) => 
      pfDISJ_CASES (find th1) (find th2) (find th3)
    | pEQ_IMP_RULE_LEFT thm => shEQ_IMP_RULE_LEFT (find thm)
    | pEQ_IMP_RULE_RIGHT thm => shEQ_IMP_RULE_RIGHT (find thm)
    | pEQ_MP (th1,th2) => shEQ_MP (find th1) (find th2)
    | pEXISTS ((v,term),thm) => pfEXISTS (v,term) (find thm)
    | pGEN (term,thm) => pfGEN term (find thm)
    | pGENL (terml,thm) => pfGENL terml (find thm)
    | pGEN_ABS (termo,vl,thm) => pfGEN_ABS termo vl (find thm)
    | pINST (tmsubst,thm) => pfINST tmsubst (find thm)
    | pINST_TYPE (tysubst,thm) => pfINST_TYPE tysubst (find thm)
    | pMK_COMB (th1,th2) => shMK_COMB (find th1) (find th2)
    | pMP (th1,th2) => pfMP (find th1) (find th2)
    | pMk_abs thm => find thm
    | pMk_comb thm => find thm
    | pNOT_ELIM thm => pfNOT_ELIM (find thm)
    | pNOT_INTRO thm => pfNOT_INTRO (find thm)
    | pSPEC (term,thm) => pfSPEC term (find thm)
    | pSUBST (oldsubst,template,thm) => 
      let
        val fvs = Term.FVL [template] Term.empty_varset
        val vsubst = filter (fn {redex,residue} => HOLset.member (fvs,redex)) 
          oldsubst
        fun f {redex,residue} = {redex = redex, residue = find residue} 
      in
        shSUBST (map f vsubst) template (find thm)
      end
    | pSYM thm => shSYM (find thm)
    | pSpecialize (term,thm) => pfSpecialize term (find thm)
    | pTRANS (th1,th2) => shTRANS (find th1) (find th2)
    | pmk_axiom_thm term => 
      (
      if term_eq term (concl ETA_AX) then pfETA_AX else
        (print_endline (term_to_string term);
         raise ERR "pmk_axiom_thm" "not supported")
      )
    | pmk_defn_thm term => 
      (
      print_endline (term_to_string term); 
      raise ERR "pmk_defn_thm" "not supported"
      )
    | pmk_oracle_thm (terml,term) => 
      ((KnO (dfind (terml,term) axiomd), 
        HOLset.fromList Term.compare terml, term)
      handle NotFound => 
      (print_endline (term_to_string term); 
       raise ERR "pmk_oracle_thm" ""))
    | prefl_nocheck (_,term) => shREFL term
  end
*)

(* ------------------------------------------------------------------------
   Verify a rule by replaying it with the HOL4 kernel (self-simulation)
   ------------------------------------------------------------------------ *)

(* Set to true by simul_rule_h4 for axiom, definition and oracle steps,
   which are not re-derived. check_log_h4_aux then skips the
   theorem-equality check for that step. *)
val notest = ref false

local
  (* Axioms, definitions and oracles cannot be re-derived: mark the step as
     untestable and return TRUTH as a placeholder theorem. *)
  fun untested () = (notest := true; TRUTH)
in
(* Re-run a recorded rule application with the HOL4 kernel itself and
   return the resulting theorem. *)
fun simul_rule_h4 (pABS (tm, th)) = ABS tm th
  | simul_rule_h4 (pALPHA (tm1, tm2)) = ALPHA tm1 tm2
  | simul_rule_h4 (pAP_TERM (tm, th)) = AP_TERM tm th
  | simul_rule_h4 (pAP_THM (th, tm)) = AP_THM th tm
  | simul_rule_h4 (pASSUME tm) = ASSUME tm
  | simul_rule_h4 (pBETA_CONV tm) = BETA_CONV tm
  | simul_rule_h4 (pBeta th) = Beta th
  | simul_rule_h4 (pCCONTR (tm, th)) = CCONTR tm th
  | simul_rule_h4 (pCHOOSE ((tm, th1), th2)) = CHOOSE (tm, th1) th2
  | simul_rule_h4 (pCONJ (th1, th2)) = CONJ th1 th2
  | simul_rule_h4 (pCONJUNCT1 th) = CONJUNCT1 th
  | simul_rule_h4 (pCONJUNCT2 th) = CONJUNCT2 th
  | simul_rule_h4 (pDISCH (tm, th)) = DISCH tm th
  | simul_rule_h4 (pDISJ1 (th, tm)) = DISJ1 th tm
  | simul_rule_h4 (pDISJ2 (tm, th)) = DISJ2 tm th
  | simul_rule_h4 (pDISJ_CASES (th1, th2, th3)) = DISJ_CASES th1 th2 th3
  | simul_rule_h4 (pEQ_IMP_RULE_LEFT th) = #1 (EQ_IMP_RULE th)
  | simul_rule_h4 (pEQ_IMP_RULE_RIGHT th) = #2 (EQ_IMP_RULE th)
  | simul_rule_h4 (pEQ_MP (th1, th2)) = EQ_MP th1 th2
  | simul_rule_h4 (pEXISTS ((v, tm), th)) = EXISTS (v, tm) th
  | simul_rule_h4 (pGEN (tm, th)) = GEN tm th
  | simul_rule_h4 (pGENL (tml, th)) = GENL tml th
  | simul_rule_h4 (pGEN_ABS (tmo, vl, th)) = GEN_ABS tmo vl th
  | simul_rule_h4 (pINST (tmsubst, th)) = INST tmsubst th
  | simul_rule_h4 (pINST_TYPE (tysubst, th)) = INST_TYPE tysubst th
  | simul_rule_h4 (pMK_COMB (th1, th2)) = MK_COMB (th1, th2)
  | simul_rule_h4 (pMP (th1, th2)) = MP th1 th2
  | simul_rule_h4 (pMk_abs th) = th
  | simul_rule_h4 (pMk_comb th) = th
  | simul_rule_h4 (pNOT_ELIM th) = NOT_ELIM th
  | simul_rule_h4 (pNOT_INTRO th) = NOT_INTRO th
  | simul_rule_h4 (pSPEC (tm, th)) = SPEC tm th
  | simul_rule_h4 (pSUBST (vsubst, tm, th)) = SUBST vsubst tm th
  | simul_rule_h4 (pSYM th) = SYM th
  | simul_rule_h4 (pSpecialize (tm, th)) = Specialize tm th
  | simul_rule_h4 (pTRANS (th1, th2)) = TRANS th1 th2
  | simul_rule_h4 (pmk_axiom_thm _) = untested ()
  | simul_rule_h4 (pmk_defn_thm _) = untested ()
  | simul_rule_h4 (pmk_oracle_thm _) = untested ()
  | simul_rule_h4 (prefl_nocheck (_, tm)) = REFL tm
end



(* ------------------------------------------------------------------------
   Checking a log
   ------------------------------------------------------------------------ *)

(* The log entry (theorem, rule) most recently processed by
   check_log_h4_aux, or by the non-oracle branch of check_log_aux.
   Useful for inspecting the failing step after an error. *)
val last_rule = ref NONE

(* Replay a log with the HOL4 kernel itself ("self-simulation").
   For each entry (thm, rule), rule is re-run with simul_rule_h4 and the
   result must equal thm. The check is skipped when simul_rule_h4 set notest
   (axiom, definition or oracle steps). On mismatch, both theorems are
   printed and HOL_ERR is raised. last_rule records the current entry for
   post-mortem inspection. i counts processed entries, and a progress line
   is printed every 100 entries. *)
fun check_log_h4_aux i log = case log of
    [] => print_endline "self-simulation completed"
  | (thm,rule) :: m => 
    let 
      val _ = last_rule := SOME (thm,rule)
      val _ = notest := false
      val newthm = simul_rule_h4 rule
      val _ =
        if not (!notest) andalso thm_compare (thm,newthm) <> EQUAL then
          (print_endline (thm_to_string thm);
           print_endline (thm_to_string newthm);
           raise ERR "check_log_h4_aux"
             "replayed theorem differs from logged theorem")
        else ()
    in
      if i mod 100 = 0 then print_endline (its i) else ();
      check_log_h4_aux (i+1) m
    end

(* Self-simulate a whole log, numbering entries from 0. *)
fun check_log_h4 log = check_log_h4_aux 0 log

(* Walk a log of (theorem, rule) entries and build the dictionary of proof
   terms, starting from pfd. Only the first proof found for a given sequent
   is kept.
   - Oracle entries are mapped to the named proof KnO name, where name is
     the name of the oracle's conclusion, which must be a variable. A
     progress line is printed each time the dictionary size is a multiple
     of 100.
   - Other entries are translated with simul_rule. The resulting sequent is
     checked against the logged theorem with assign_pf, unless simul_rule
     had to use a placeholder (notest2). *)
fun check_log_aux pfd log = case log of
    [] => pfd
  | (thm, pmk_oracle_thm (_,v)) :: m =>
    let
      val name = fst (dest_var v)
        handle HOL_ERR _ =>
          raise ERR "check_log_aux"
            ("oracle conclusion is not a variable: " ^ term_to_string v)
      val pf = KnO name
      val newpfd = if dmem_pfd thm pfd then pfd else dadd_pfd thm pf pfd
      val n = dlength newpfd
    in
      if n mod 100 = 0 then print_endline (its n) else ();
      check_log_aux newpfd m
    end
  | (thm,rule) :: m => 
    let 
      val _ = last_rule := SOME (thm,rule)
      val _ = last_rules := []
      val _ = notest2 := false
      val seq = simul_rule pfd rule
      (* A placeholder sequent (notest2) cannot match thm, so it is not checked *)
      val pf = if !notest2 
               then #1 seq 
               else assign_pf thm seq
      (* No dlength here: its result was unused (progress print disabled),
         and computing it may cost a traversal of the whole dictionary. *)
      val newpfd = 
        if dmem_pfd thm pfd then pfd else dadd_pfd thm pf pfd
    in
      check_log_aux newpfd m
    end

(* Build the proof dictionary for a whole log, starting from pfd. *)
fun check_log pfd log = check_log_aux pfd log

(* ------------------------------------------------------------------------
   Removing free variables from the produced proof term
   ------------------------------------------------------------------------ *)

(* Build a closed term of type ty = a1 -> ... -> an -> r by abstracting
   over dummy variables (all named "x") of types a1 ... an:
     \x ... x. F          when r is min$bool
     \x ... x. hf$Empty   when r is hf$set
   Any other result type raises HOL_ERR. A type variable as r is rejected
   by dest_thy_type, and any other type constructor by the explicit check. *)
fun choose_elem ty = 
  let 
    val (atyl,imty) = strip_type ty 
    val xl = map (fn ty => mk_var ("x",ty)) atyl
    val {Thy, Tyop, Args = _} = dest_thy_type imty 
    val imtm = 
      if Thy = "min" andalso Tyop = "bool" then F
      else if Thy = "hf" andalso Tyop = "set" then ``hf$Empty``
      else raise ERR "choose_elem" ("not supported type: " ^ type_to_string imty)
  in
    list_mk_abs (xl,imtm)
  end
  
(* Remove the free variables of a proof term.
   Every variable returned by free_vars_pf pf is instantiated, through
   pfINST_NOHYP, with a closed dummy term of its type built by choose_elem.
   The instantiation must leave the sequent of thm unchanged; if it does,
   the instantiated proof term is returned, otherwise HOL_ERR is raised. *)
fun rm_fv_in_pf (thm,pf) =
  let
    val (hyps, goal) = dest_thm thm
    fun dummy_inst v = {redex = v, residue = choose_elem (type_of v)}
    val inst = map dummy_inst (free_vars_pf pf)
    val (pf', hyps', goal') = pfINST_NOHYP inst (pf, hypset thm, goal)
    val unchanged =
      goal_compare ((hyps, goal), (HOLset.listItems hyps', goal')) = EQUAL
  in
    if unchanged then pf' else raise ERR "rm_fv_in_pf" ""
  end

end
