From 7583e034bbd991877b634d50249bbccfd9e3c763 Mon Sep 17 00:00:00 2001 From: Alex Ozdemir Date: Fri, 31 Jan 2020 23:44:24 -0800 Subject: [PATCH] Handle `expectedType` in TheoryProofEngine (#3691) `TheoryProofEngine` now uses the `expectedType` optional argument. * When printing terms, it sets this for theories that it dispatches too * It occasionally asks theories for help determining the `expectedType` using `equalityType`, which has a sensible default implementation. * It is mindful of `expectedType` when using the let map. I also moved to hpp function implementations into the cpp. --- src/proof/theory_proof.cpp | 102 +++++++++++++++++++++++++++++-------- src/proof/theory_proof.h | 32 ++++++++---- 2 files changed, 104 insertions(+), 30 deletions(-) diff --git a/src/proof/theory_proof.cpp b/src/proof/theory_proof.cpp index b74d4a4d2..b516c250f 100644 --- a/src/proof/theory_proof.cpp +++ b/src/proof/theory_proof.cpp @@ -176,6 +176,30 @@ void TheoryProofEngine::printConstantDisequalityProof(std::ostream& os, Expr c1, getTheoryProof(theory::Theory::theoryOf(c1))->printConstantDisequalityProof(os, c1, c2, globalLetMap); } +void TheoryProofEngine::printTheoryTerm(Expr term, + std::ostream& os, + const ProofLetMap& map, + TypeNode expectedType) +{ + this->printTheoryTermAsType(term, os, map, expectedType); +} + +TypeNode TheoryProofEngine::equalityType(const Expr& left, const Expr& right) +{ + // Ask the two theories what they think.. + TypeNode leftType = getTheoryProof(theory::Theory::theoryOf(left))->equalityType(left, right); + TypeNode rightType = getTheoryProof(theory::Theory::theoryOf(right))->equalityType(left, right); + + // Error if the disagree. + Assert(leftType.isNull() || rightType.isNull() || leftType == rightType) + << "TheoryProofEngine::equalityType(" << left << ", " << right << "):" << std::endl + << "theories disagree about the type of an equality:" << std::endl + << "\tleft: " << leftType << std::endl + << "\tright:" << rightType; + + return leftType.isNull() ? rightType : leftType; +} + void TheoryProofEngine::registerTerm(Expr term) { Debug("pf::tp::register") << "TheoryProofEngine::registerTerm: registering term: " << term << std::endl; @@ -295,11 +319,11 @@ void LFSCTheoryProofEngine::printTheoryTermAsType(Expr term, if (theory_id == theory::THEORY_BUILTIN || term.getKind() == kind::ITE || term.getKind() == kind::EQUAL) { - printCoreTerm(term, os, map); + printCoreTerm(term, os, map, expectedType); return; } // dispatch to proper theory - getTheoryProof(theory_id)->printOwnedTerm(term, os, map); + getTheoryProof(theory_id)->printOwnedTerm(term, os, map, expectedType); } void LFSCTheoryProofEngine::printSort(Type type, std::ostream& os) { @@ -866,18 +890,29 @@ void LFSCTheoryProofEngine::printBoundTermAsType(Expr term, { Debug("pf::tp") << "LFSCTheoryProofEngine::printBoundTerm( " << term << " ) " << std::endl; - ProofLetMap::const_iterator it = map.find(term); - if (it != map.end()) { - unsigned id = it->second.id; - unsigned count = it->second.count; + // Since let-abbreviated terms are abbreviated with their default type, only + // use the let map if there is no expectedType or the expectedType matches + // the default. + if (expectedType.isNull() + || TypeNode::fromType(term.getType()) == expectedType) + { + ProofLetMap::const_iterator it = map.find(term); + if (it != map.end()) + { + unsigned id = it->second.id; + unsigned count = it->second.count; - if (count > LET_COUNT) { - os << "let" << id; - return; + if (count > LET_COUNT) + { + os << "let" << id; + Debug("pf::tp::letmap") << "Using let map for " << term << std::endl; + return; + } } } + Debug("pf::tp::letmap") << "Skipping let map for " << term << std::endl; - printTheoryTerm(term, os, map); + printTheoryTerm(term, os, map, expectedType); } void LFSCTheoryProofEngine::printBoundFormula(Expr term, @@ -900,7 +935,7 @@ void LFSCTheoryProofEngine::printBoundFormula(Expr term, void LFSCTheoryProofEngine::printCoreTerm(Expr term, std::ostream& os, const ProofLetMap& map, - Type expectedType) + TypeNode expectedType) { if (term.isVariable()) { os << ProofManager::sanitize(term); @@ -911,6 +946,9 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term, switch(k) { case kind::ITE: { + TypeNode armType = expectedType.isNull() + ? TypeNode::fromType(term.getType()) + : expectedType; bool useFormulaType = term.getType().isBoolean(); Assert(term[1].getType().isSubtypeOf(term.getType())); Assert(term[2].getType().isSubtypeOf(term.getType())); @@ -924,7 +962,7 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term, } else { - printBoundTerm(term[1], os, map); + printBoundTerm(term[1], os, map, armType); } os << " "; if (useFormulaType) @@ -941,6 +979,7 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term, case kind::EQUAL: { bool booleanCase = term[0].getType().isBoolean(); + TypeNode armType = equalityType(term[0], term[1]); os << "("; if (booleanCase) { @@ -952,13 +991,13 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term, } if (booleanCase && printsAsBool(term[0])) os << "(p_app "; - printBoundTerm(term[0], os, map); + printBoundTerm(term[0], os, map, armType); if (booleanCase && printsAsBool(term[0])) os << ")"; os << " "; if (booleanCase && printsAsBool(term[1])) os << "(p_app "; - printBoundTerm(term[1], os, map); + printBoundTerm(term[1], os, map, armType); if (booleanCase && printsAsBool(term[1])) os << ") "; os << ")"; @@ -966,16 +1005,18 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term, } case kind::DISTINCT: + { // Distinct nodes can have any number of chidlren. Assert(term.getNumChildren() >= 2); + TypeNode armType = equalityType(term[0], term[1]); if (term.getNumChildren() == 2) { os << "(not (= "; printSort(term[0].getType(), os); os << " "; - printBoundTerm(term[0], os, map); + printBoundTerm(term[0], os, map, armType); os << " "; - printBoundTerm(term[1], os, map); + printBoundTerm(term[1], os, map, armType); os << "))"; } else { unsigned numOfPairs = term.getNumChildren() * (term.getNumChildren() - 1) / 2; @@ -985,28 +1026,29 @@ void LFSCTheoryProofEngine::printCoreTerm(Expr term, for (unsigned i = 0; i < term.getNumChildren(); ++i) { for (unsigned j = i + 1; j < term.getNumChildren(); ++j) { + TypeNode armType = equalityType(term[i], term[j]); if ((i != 0) || (j != 1)) { os << "(not (= "; printSort(term[0].getType(), os); os << " "; - printBoundTerm(term[i], os, map); + printBoundTerm(term[i], os, map, armType); os << " "; - printBoundTerm(term[j], os, map); + printBoundTerm(term[j], os, map, armType); os << ")))"; } else { os << "(not (= "; printSort(term[0].getType(), os); os << " "; - printBoundTerm(term[0], os, map); + printBoundTerm(term[0], os, map, armType); os << " "; - printBoundTerm(term[1], os, map); + printBoundTerm(term[1], os, map, armType); os << "))"; } } } } - return; + } case kind::CHAIN: { // LFSC doesn't allow declarations with variable numbers of @@ -1345,6 +1387,24 @@ void TheoryProof::printRewriteProof(std::ostream& os, const Node &n1, const Node os << "))"; } +void TheoryProof::printOwnedTerm(Expr term, + std::ostream& os, + const ProofLetMap& map, + TypeNode expectedType) +{ + this->printOwnedTermAsType(term, os, map, expectedType); +} + +TypeNode TheoryProof::equalityType(const Expr& left, const Expr& right) +{ + Assert(left.getType() == right.getType()) + << "TheoryProof::equalityType(" << left << ", " << right << "):" << std::endl + << "types disagree:" << std::endl + << "\tleft: " << left.getType() << std::endl + << "\tright:" << right.getType(); + return TypeNode::fromType(left.getType()); +} + bool TheoryProof::match(TNode n1, TNode n2) { theory::TheoryId theoryId = this->getTheoryId(); diff --git a/src/proof/theory_proof.h b/src/proof/theory_proof.h index e8569d636..85c8e5fee 100644 --- a/src/proof/theory_proof.h +++ b/src/proof/theory_proof.h @@ -198,14 +198,15 @@ public: void printTheoryTerm(Expr term, std::ostream& os, const ProofLetMap& map, - TypeNode expectedType = TypeNode()) - { - this->printTheoryTermAsType(term, os, map, expectedType); - } + TypeNode expectedType = TypeNode()); virtual void printTheoryTermAsType(Expr term, std::ostream& os, const ProofLetMap& map, TypeNode expectedType) = 0; + /** + * Calls `TheoryProof::equalityType` on the appropriate theory. + */ + TypeNode equalityType(const Expr& left, const Expr& right); bool printsAsBool(const Node &n); }; @@ -227,7 +228,7 @@ public: void printCoreTerm(Expr term, std::ostream& os, const ProofLetMap& map, - Type expectedType = Type()); + TypeNode expectedType = TypeNode()); void printLetTerm(Expr term, std::ostream& os) override; void printBoundTermAsType(Expr term, std::ostream& os, @@ -306,15 +307,28 @@ protected: void printOwnedTerm(Expr term, std::ostream& os, const ProofLetMap& map, - TypeNode expectedType = TypeNode()) - { - this->printOwnedTermAsType(term, os, map, expectedType); - } + TypeNode expectedType = TypeNode()); + virtual void printOwnedTermAsType(Expr term, std::ostream& os, const ProofLetMap& map, TypeNode expectedType) = 0; + /** + * Return the type (at the SMT level, the sort) of an equality or disequality + * between `left` and `right`. + * + * The default implementation asserts that the two have the same type, and + * returns it. + * + * A theory may want to do something else. + * + * For example, the theory of arithmetic allows equalities between Reals and + * Integers. In this case the integer is upcast to a real, and the equality + * is over reals. + */ + virtual TypeNode equalityType(const Expr& left, const Expr& right); + /** * Print the proof representation of the given type that belongs to some theory. * -- 2.30.2