Proper printing of proofs in the internal calculus (#6975)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 4 Aug 2021 22:28:55 +0000 (17:28 -0500)
committerGitHub <noreply@github.com>
Wed, 4 Aug 2021 22:28:55 +0000 (17:28 -0500)
src/proof/proof_node_to_sexpr.cpp
src/proof/proof_node_to_sexpr.h

index 85fc2395eb416e3d96eb37c03d1aa7c3d8000947..67a43fedc7017f72dc816665cca2b0cabec194f4 100644 (file)
@@ -19,7 +19,9 @@
 #include <sstream>
 
 #include "options/proof_options.h"
+#include "proof/proof_checker.h"
 #include "proof/proof_node.h"
+#include "theory/builtin/proof_checker.h"
 
 using namespace cvc5::kind;
 
@@ -71,7 +73,8 @@ Node ProofNodeToSExpr::convertToSExpr(const ProofNode* pn)
       traversing.pop_back();
       std::vector<Node> children;
       // add proof rule
-      children.push_back(getOrMkPfRuleVariable(cur->getRule()));
+      PfRule r = cur->getRule();
+      children.push_back(getOrMkPfRuleVariable(r));
       if (options::proofPrintConclusion())
       {
         children.push_back(d_conclusionMarker);
@@ -93,18 +96,14 @@ Node ProofNodeToSExpr::convertToSExpr(const ProofNode* pn)
         // needed to ensure builtin operators are not treated as operators
         // this can be the case for CONG where d_args may contain a builtin
         // operator
-        std::vector<Node> argsSafe;
-        for (const Node& a : args)
+        std::vector<Node> argsPrint;
+        for (size_t i = 0, nargs = args.size(); i < nargs; i++)
         {
-          Node av = a;
-          if (a.getNumChildren() == 0
-              && NodeManager::operatorToKind(a) != UNDEFINED_KIND)
-          {
-            av = getOrMkNodeVariable(a);
-          }
-          argsSafe.push_back(av);
+          ArgFormat f = getArgumentFormat(cur, i);
+          Node av = getArgument(args[i], f);
+          argsPrint.push_back(av);
         }
-        Node argsC = nm->mkNode(SEXPR, argsSafe);
+        Node argsC = nm->mkNode(SEXPR, argsPrint);
         children.push_back(argsC);
       }
       d_pnMap[cur] = nm->mkNode(SEXPR, children);
@@ -129,10 +128,96 @@ Node ProofNodeToSExpr::getOrMkPfRuleVariable(PfRule r)
   d_pfrMap[r] = var;
   return var;
 }
+Node ProofNodeToSExpr::getOrMkKindVariable(TNode n)
+{
+  Kind k;
+  if (!ProofRuleChecker::getKind(n, k))
+  {
+    // just use self if we failed to get the node, throw a debug failure
+    Assert(false) << "Expected kind node, got " << n;
+    return n;
+  }
+  std::map<Kind, Node>::iterator it = d_kindMap.find(k);
+  if (it != d_kindMap.end())
+  {
+    return it->second;
+  }
+  std::stringstream ss;
+  ss << k;
+  NodeManager* nm = NodeManager::currentNM();
+  Node var = nm->mkBoundVar(ss.str(), nm->sExprType());
+  d_kindMap[k] = var;
+  return var;
+}
+
+Node ProofNodeToSExpr::getOrMkTheoryIdVariable(TNode n)
+{
+  theory::TheoryId tid;
+  if (!theory::builtin::BuiltinProofRuleChecker::getTheoryId(n, tid))
+  {
+    // just use self if we failed to get the node, throw a debug failure
+    Assert(false) << "Expected theory id node, got " << n;
+    return n;
+  }
+  std::map<theory::TheoryId, Node>::iterator it = d_tidMap.find(tid);
+  if (it != d_tidMap.end())
+  {
+    return it->second;
+  }
+  std::stringstream ss;
+  ss << tid;
+  NodeManager* nm = NodeManager::currentNM();
+  Node var = nm->mkBoundVar(ss.str(), nm->sExprType());
+  d_tidMap[tid] = var;
+  return var;
+}
+
+Node ProofNodeToSExpr::getOrMkMethodIdVariable(TNode n)
+{
+  MethodId mid;
+  if (!getMethodId(n, mid))
+  {
+    // just use self if we failed to get the node, throw a debug failure
+    Assert(false) << "Expected method id node, got " << n;
+    return n;
+  }
+  std::map<MethodId, Node>::iterator it = d_midMap.find(mid);
+  if (it != d_midMap.end())
+  {
+    return it->second;
+  }
+  std::stringstream ss;
+  ss << mid;
+  NodeManager* nm = NodeManager::currentNM();
+  Node var = nm->mkBoundVar(ss.str(), nm->sExprType());
+  d_midMap[mid] = var;
+  return var;
+}
+Node ProofNodeToSExpr::getOrMkInferenceIdVariable(TNode n)
+{
+  theory::InferenceId iid;
+  if (!theory::getInferenceId(n, iid))
+  {
+    // just use self if we failed to get the node, throw a debug failure
+    Assert(false) << "Expected inference id node, got " << n;
+    return n;
+  }
+  std::map<theory::InferenceId, Node>::iterator it = d_iidMap.find(iid);
+  if (it != d_iidMap.end())
+  {
+    return it->second;
+  }
+  std::stringstream ss;
+  ss << iid;
+  NodeManager* nm = NodeManager::currentNM();
+  Node var = nm->mkBoundVar(ss.str(), nm->sExprType());
+  d_iidMap[iid] = var;
+  return var;
+}
 
-Node ProofNodeToSExpr::getOrMkNodeVariable(Node n)
+Node ProofNodeToSExpr::getOrMkNodeVariable(TNode n)
 {
-  std::map<Node, Node>::iterator it = d_nodeMap.find(n);
+  std::map<TNode, Node>::iterator it = d_nodeMap.find(n);
   if (it != d_nodeMap.end())
   {
     return it->second;
@@ -145,4 +230,76 @@ Node ProofNodeToSExpr::getOrMkNodeVariable(Node n)
   return var;
 }
 
+Node ProofNodeToSExpr::getArgument(Node arg, ArgFormat f)
+{
+  switch (f)
+  {
+    case ArgFormat::KIND: return getOrMkKindVariable(arg);
+    case ArgFormat::THEORY_ID: return getOrMkTheoryIdVariable(arg);
+    case ArgFormat::METHOD_ID: return getOrMkMethodIdVariable(arg);
+    case ArgFormat::INFERENCE_ID: return getOrMkInferenceIdVariable(arg);
+    case ArgFormat::NODE_VAR: return getOrMkNodeVariable(arg);
+    default: return arg;
+  }
+}
+
+ProofNodeToSExpr::ArgFormat ProofNodeToSExpr::getArgumentFormat(
+    const ProofNode* pn, size_t i)
+{
+  PfRule r = pn->getRule();
+  switch (r)
+  {
+    case PfRule::CONG:
+    {
+      if (i == 0)
+      {
+        return ArgFormat::KIND;
+      }
+      const std::vector<Node>& args = pn->getArguments();
+      Assert(i < args.size());
+      if (args[i].getNumChildren() == 0
+          && NodeManager::operatorToKind(args[i]) != UNDEFINED_KIND)
+      {
+        return ArgFormat::NODE_VAR;
+      }
+    }
+    break;
+    case PfRule::SUBS:
+    case PfRule::REWRITE:
+    case PfRule::MACRO_SR_EQ_INTRO:
+    case PfRule::MACRO_SR_PRED_INTRO:
+    case PfRule::MACRO_SR_PRED_TRANSFORM:
+      if (i > 0)
+      {
+        return ArgFormat::METHOD_ID;
+      }
+      break;
+    case PfRule::MACRO_SR_PRED_ELIM: return ArgFormat::METHOD_ID; break;
+    case PfRule::THEORY_LEMMA:
+    case PfRule::THEORY_REWRITE:
+      if (i == 1)
+      {
+        return ArgFormat::THEORY_ID;
+      }
+      else if (r == PfRule::THEORY_REWRITE && i == 2)
+      {
+        return ArgFormat::METHOD_ID;
+      }
+      break;
+    case PfRule::INSTANTIATE:
+    {
+      Assert(!pn->getChildren().empty());
+      Node q = pn->getChildren()[0]->getResult();
+      Assert(q.getKind() == kind::FORALL);
+      if (i == q[0].getNumChildren())
+      {
+        return ArgFormat::INFERENCE_ID;
+      }
+    }
+    break;
+    default: break;
+  }
+  return ArgFormat::DEFAULT;
+}
+
 }  // namespace cvc5
index c358f3445ba9b60c51f230a1323996df9d34b9f8..83d719aaf3d84e76acd175240d6c74f08b8b857b 100644 (file)
 
 #include <map>
 
+#include "expr/kind.h"
 #include "expr/node.h"
+#include "proof/method_id.h"
 #include "proof/proof_rule.h"
+#include "theory/inference_id.h"
+#include "theory/theory_id.h"
 
 namespace cvc5 {
 
@@ -46,8 +50,32 @@ class ProofNodeToSExpr
   Node convertToSExpr(const ProofNode* pn);
 
  private:
+  /** argument format, determines how to print an argument */
+  enum class ArgFormat
+  {
+    // just use the argument itself
+    DEFAULT,
+    // print the argument as a kind
+    KIND,
+    // print the argument as a theory id
+    THEORY_ID,
+    // print the argument as a method id
+    METHOD_ID,
+    // print the argument as an inference id
+    INFERENCE_ID,
+    // print a variable whose name is the term (see getOrMkNodeVariable)
+    NODE_VAR
+  };
   /** map proof rules to a variable */
   std::map<PfRule, Node> d_pfrMap;
+  /** map kind to a variable displaying the kind they represent */
+  std::map<Kind, Node> d_kindMap;
+  /** map theory ids to a variable displaying the theory id they represent */
+  std::map<theory::TheoryId, Node> d_tidMap;
+  /** map method ids to a variable displaying the method id they represent */
+  std::map<MethodId, Node> d_midMap;
+  /** map infer ids to a variable displaying the method id they represent */
+  std::map<theory::InferenceId, Node> d_iidMap;
   /** Dummy ":args" marker */
   Node d_argsMarker;
   /** Dummy ":conclusion" marker */
@@ -58,11 +86,27 @@ class ProofNodeToSExpr
    * map nodes to a bound variable, used for nodes that have special AST status
    * like builtin operators
    */
-  std::map<Node, Node> d_nodeMap;
+  std::map<TNode, Node> d_nodeMap;
   /** get or make pf rule variable */
   Node getOrMkPfRuleVariable(PfRule r);
-  /** get or make node variable */
-  Node getOrMkNodeVariable(Node n);
+  /** get or make kind variable from the kind embedded in n */
+  Node getOrMkKindVariable(TNode n);
+  /** get or make theory id variable */
+  Node getOrMkTheoryIdVariable(TNode n);
+  /** get or make method id variable */
+  Node getOrMkMethodIdVariable(TNode n);
+  /** get or make inference id variable */
+  Node getOrMkInferenceIdVariable(TNode n);
+  /**
+   * Get or make node variable that prints the same as n but has SEXPR type.
+   * This is used to ensure the type checker does not complain when trying to
+   * print e.g. builtin operators as first-class terms in the SEXPR.
+   */
+  Node getOrMkNodeVariable(TNode n);
+  /** get argument based on the provided format */
+  Node getArgument(Node arg, ArgFormat f);
+  /** get argument format for proof node */
+  ArgFormat getArgumentFormat(const ProofNode* pn, size_t i);
 };
 
 }  // namespace cvc5