Handle `expectedType` in TheoryProofEngine (#3691)
authorAlex Ozdemir <aozdemir@hmc.edu>
Sat, 1 Feb 2020 07:44:24 +0000 (23:44 -0800)
committerGitHub <noreply@github.com>
Sat, 1 Feb 2020 07:44:24 +0000 (23:44 -0800)
`TheoryProofEngine` now uses the `expectedType` optional argument.
  * When printing terms, it sets this for theories that it dispatches too
  * It occasionally asks theories for help determining the `expectedType` using `equalityType`, which has a sensible default implementation.
  * It is mindful of `expectedType` when using the let map.

I also moved to hpp function implementations into the cpp.

src/proof/theory_proof.cpp
src/proof/theory_proof.h

index b74d4a4d22967ff1b011a2735c6a65b819750c22..b516c250f4bdcc3aae00b77210fd25e7d964fa73 100644 (file)
@@ -176,6 +176,30 @@ void TheoryProofEngine::printConstantDisequalityProof(std::ostream& os, Expr c1,
   getTheoryProof(theory::Theory::theoryOf(c1))->printConstantDisequalityProof(os, c1, c2, globalLetMap);
 }
 
+void TheoryProofEngine::printTheoryTerm(Expr term,
+                                        std::ostream& os,
+                                        const ProofLetMap& map,
+                                        TypeNode expectedType)
+{
+  this->printTheoryTermAsType(term, os, map, expectedType);
+}
+
+TypeNode TheoryProofEngine::equalityType(const Expr& left, const Expr& right)
+{
+  // Ask the two theories what they think..
+  TypeNode leftType = getTheoryProof(theory::Theory::theoryOf(left))->equalityType(left, right);
+  TypeNode rightType = getTheoryProof(theory::Theory::theoryOf(right))->equalityType(left, right);
+
+  // Error if the disagree.
+  Assert(leftType.isNull() || rightType.isNull() || leftType == rightType)
+    << "TheoryProofEngine::equalityType(" << left << ", " << right << "):" << std::endl
+    << "theories disagree about the type of an equality:" << std::endl
+    << "\tleft: " << leftType << std::endl
+    << "\tright:" << rightType;
+
+  return leftType.isNull() ? rightType : leftType;
+}
+
 void TheoryProofEngine::registerTerm(Expr term) {
   Debug("pf::tp::register") << "TheoryProofEngine::registerTerm: registering term: " << term << std::endl;
 
@@ -295,11 +319,11 @@ void LFSCTheoryProofEngine::printTheoryTermAsType(Expr term,
   if (theory_id == theory::THEORY_BUILTIN ||
       term.getKind() == kind::ITE ||
       term.getKind() == kind::EQUAL) {
-    printCoreTerm(term, os, map);
+    printCoreTerm(term, os, map, expectedType);
     return;
   }
   // dispatch to proper theory
-  getTheoryProof(theory_id)->printOwnedTerm(term, os, map);
+  getTheoryProof(theory_id)->printOwnedTerm(term, os, map, expectedType);
 }
 
 void LFSCTheoryProofEngine::printSort(Type type, std::ostream& os) {
@@ -866,18 +890,29 @@ void LFSCTheoryProofEngine::printBoundTermAsType(Expr term,
 {
   Debug("pf::tp") << "LFSCTheoryProofEngine::printBoundTerm( " << term << " ) " << std::endl;
 
-  ProofLetMap::const_iterator it = map.find(term);
-  if (it != map.end()) {
-    unsigned id = it->second.id;
-    unsigned count = it->second.count;
+  // Since let-abbreviated terms are abbreviated with their default type, only
+  // use the let map if there is no expectedType or the expectedType matches
+  // the default.
+  if (expectedType.isNull()
+      || TypeNode::fromType(term.getType()) == expectedType)
+  {
+    ProofLetMap::const_iterator it = map.find(term);
+    if (it != map.end())
+    {
+      unsigned id = it->second.id;
+      unsigned count = it->second.count;
 
-    if (count > LET_COUNT) {
-      os << "let" << id;
-      return;
+      if (count > LET_COUNT)
+      {
+        os << "let" << id;
+        Debug("pf::tp::letmap") << "Using let map for " << term << std::endl;
+        return;
+      }
     }
   }
+  Debug("pf::tp::letmap") << "Skipping let map for " << term << std::endl;
 
-  printTheoryTerm(term, os, map);
+  printTheoryTerm(term, os, map, expectedType);
 }
 
 void LFSCTheoryProofEngine::printBoundFormula(Expr term,
@@ -900,7 +935,7 @@ void LFSCTheoryProofEngine::printBoundFormula(Expr term,
 void LFSCTheoryProofEngine::printCoreTerm(Expr term,
                                           std::ostream& os,
                                           const ProofLetMap& map,
-                                          Type expectedType)
+                                          TypeNode expectedType)
 {
   if (term.isVariable()) {
     os << ProofManager::sanitize(term);
@@ -911,6 +946,9 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term,
 
   switch(k) {
   case kind::ITE: {
+    TypeNode armType = expectedType.isNull()
+                           ? TypeNode::fromType(term.getType())
+                           : expectedType;
     bool useFormulaType = term.getType().isBoolean();
     Assert(term[1].getType().isSubtypeOf(term.getType()));
     Assert(term[2].getType().isSubtypeOf(term.getType()));
@@ -924,7 +962,7 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term,
     }
     else
     {
-      printBoundTerm(term[1], os, map);
+      printBoundTerm(term[1], os, map, armType);
     }
     os << " ";
     if (useFormulaType)
@@ -941,6 +979,7 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term,
 
   case kind::EQUAL: {
     bool booleanCase = term[0].getType().isBoolean();
+    TypeNode armType = equalityType(term[0], term[1]);
 
     os << "(";
     if (booleanCase) {
@@ -952,13 +991,13 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term,
     }
 
     if (booleanCase && printsAsBool(term[0])) os << "(p_app ";
-    printBoundTerm(term[0], os, map);
+    printBoundTerm(term[0], os, map, armType);
     if (booleanCase && printsAsBool(term[0])) os << ")";
 
     os << " ";
 
     if (booleanCase && printsAsBool(term[1])) os << "(p_app ";
-    printBoundTerm(term[1], os, map);
+    printBoundTerm(term[1], os, map, armType);
     if (booleanCase && printsAsBool(term[1])) os << ") ";
     os << ")";
 
@@ -966,16 +1005,18 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term,
   }
 
   case kind::DISTINCT:
+  {
     // Distinct nodes can have any number of chidlren.
     Assert(term.getNumChildren() >= 2);
+    TypeNode armType = equalityType(term[0], term[1]);
 
     if (term.getNumChildren() == 2) {
       os << "(not (= ";
       printSort(term[0].getType(), os);
       os << " ";
-      printBoundTerm(term[0], os, map);
+      printBoundTerm(term[0], os, map, armType);
       os << " ";
-      printBoundTerm(term[1], os, map);
+      printBoundTerm(term[1], os, map, armType);
       os << "))";
     } else {
       unsigned numOfPairs = term.getNumChildren() * (term.getNumChildren() - 1) / 2;
@@ -985,28 +1026,29 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term,
 
       for (unsigned i = 0; i < term.getNumChildren(); ++i) {
         for (unsigned j = i + 1; j < term.getNumChildren(); ++j) {
+          TypeNode armType = equalityType(term[i], term[j]);
           if ((i != 0) || (j != 1)) {
             os << "(not (= ";
             printSort(term[0].getType(), os);
             os << " ";
-            printBoundTerm(term[i], os, map);
+            printBoundTerm(term[i], os, map, armType);
             os << " ";
-            printBoundTerm(term[j], os, map);
+            printBoundTerm(term[j], os, map, armType);
             os << ")))";
           } else {
             os << "(not (= ";
             printSort(term[0].getType(), os);
             os << " ";
-            printBoundTerm(term[0], os, map);
+            printBoundTerm(term[0], os, map, armType);
             os << " ";
-            printBoundTerm(term[1], os, map);
+            printBoundTerm(term[1], os, map, armType);
             os << "))";
           }
         }
       }
     }
-
     return;
+  }
 
   case kind::CHAIN: {
     // LFSC doesn't allow declarations with variable numbers of
@@ -1345,6 +1387,24 @@ void TheoryProof::printRewriteProof(std::ostream& os, const Node &n1, const Node
   os << "))";
 }
 
+void TheoryProof::printOwnedTerm(Expr term,
+                                 std::ostream& os,
+                                 const ProofLetMap& map,
+                                 TypeNode expectedType)
+{
+  this->printOwnedTermAsType(term, os, map, expectedType);
+}
+
+TypeNode TheoryProof::equalityType(const Expr& left, const Expr& right)
+{
+  Assert(left.getType() == right.getType())
+    << "TheoryProof::equalityType(" << left << ", " << right << "):" << std::endl
+    << "types disagree:" << std::endl
+    << "\tleft: " << left.getType() << std::endl
+    << "\tright:" << right.getType();
+  return TypeNode::fromType(left.getType());
+}
+
 bool TheoryProof::match(TNode n1, TNode n2)
 {
   theory::TheoryId theoryId = this->getTheoryId();
index e8569d636aadde070fde31f33781c6a107ef760d..85c8e5fee4ce0933d7dbedc301d0cbdb31249416 100644 (file)
@@ -198,14 +198,15 @@ public:
   void printTheoryTerm(Expr term,
                        std::ostream& os,
                        const ProofLetMap& map,
-                       TypeNode expectedType = TypeNode())
-  {
-    this->printTheoryTermAsType(term, os, map, expectedType);
-  }
+                       TypeNode expectedType = TypeNode());
   virtual void printTheoryTermAsType(Expr term,
                                      std::ostream& os,
                                      const ProofLetMap& map,
                                      TypeNode expectedType) = 0;
+  /**
+   * Calls `TheoryProof::equalityType` on the appropriate theory.
+   */
+  TypeNode equalityType(const Expr& left, const Expr& right);
 
   bool printsAsBool(const Node &n);
 };
@@ -227,7 +228,7 @@ public:
   void printCoreTerm(Expr term,
                      std::ostream& os,
                      const ProofLetMap& map,
-                     Type expectedType = Type());
+                     TypeNode expectedType = TypeNode());
   void printLetTerm(Expr term, std::ostream& os) override;
   void printBoundTermAsType(Expr term,
                             std::ostream& os,
@@ -306,15 +307,28 @@ protected:
   void printOwnedTerm(Expr term,
                       std::ostream& os,
                       const ProofLetMap& map,
-                      TypeNode expectedType = TypeNode())
-  {
-    this->printOwnedTermAsType(term, os, map, expectedType);
-  }
+                      TypeNode expectedType = TypeNode());
+
   virtual void printOwnedTermAsType(Expr term,
                                     std::ostream& os,
                                     const ProofLetMap& map,
                                     TypeNode expectedType) = 0;
 
+  /**
+   * Return the type (at the SMT level, the sort) of an equality or disequality
+   * between `left` and `right`.
+   *
+   * The default implementation asserts that the two have the same type, and
+   * returns it.
+   *
+   * A theory may want to do something else.
+   *
+   * For example, the theory of arithmetic allows equalities between Reals and
+   * Integers. In this case the integer is upcast to a real, and the equality
+   * is over reals.
+   */
+  virtual TypeNode equalityType(const Expr& left, const Expr& right);
+
   /**
    * Print the proof representation of the given type that belongs to some theory.
    *