Adding constant propagation code - needs more testing - off by default
authorClark Barrett <barrett@cs.nyu.edu>
Tue, 12 Jun 2012 03:01:41 +0000 (03:01 +0000)
committerClark Barrett <barrett@cs.nyu.edu>
Tue, 12 Jun 2012 03:01:41 +0000 (03:01 +0000)
src/smt/smt_engine.cpp
src/smt/smt_engine.h
src/theory/substitutions.cpp
src/theory/substitutions.h

index e0d507475fb5bdf20e2ea00fa9507994921f156b..5874692b548e53a53892e5fc96a95d1130be842a 100644 (file)
@@ -140,6 +140,8 @@ class SmtEnginePrivate {
    * then nothing has been pushed out yet. */
   context::CDO<theory::SubstitutionMap::iterator> d_lastSubstitutionPos;
 
+  static const bool d_doConstantProp = false;
+
   /**
    * Runs the nonclausal solver and tries to solve all the assigned
    * theory literals.
@@ -256,6 +258,7 @@ SmtEngine::SmtEngine(ExprManager* em) throw(AssertionException) :
   d_private(new smt::SmtEnginePrivate(*this)),
   d_definitionExpansionTime("smt::SmtEngine::definitionExpansionTime"),
   d_nonclausalSimplificationTime("smt::SmtEngine::nonclausalSimplificationTime"),
+  d_numConstantProps("smt::SmtEngine::numConstantProps", 0),
   d_staticLearningTime("smt::SmtEngine::staticLearningTime"),
   d_simpITETime("smt::SmtEngine::simpITETime"),
   d_unconstrainedSimpTime("smt::SmtEngine::unconstrainedSimpTime"),
@@ -270,6 +273,7 @@ SmtEngine::SmtEngine(ExprManager* em) throw(AssertionException) :
 
   StatisticsRegistry::registerStat(&d_definitionExpansionTime);
   StatisticsRegistry::registerStat(&d_nonclausalSimplificationTime);
+  StatisticsRegistry::registerStat(&d_numConstantProps);
   StatisticsRegistry::registerStat(&d_staticLearningTime);
   StatisticsRegistry::registerStat(&d_simpITETime);
   StatisticsRegistry::registerStat(&d_unconstrainedSimpTime);
@@ -389,6 +393,7 @@ SmtEngine::~SmtEngine() throw() {
 
     StatisticsRegistry::unregisterStat(&d_definitionExpansionTime);
     StatisticsRegistry::unregisterStat(&d_nonclausalSimplificationTime);
+    StatisticsRegistry::unregisterStat(&d_numConstantProps);
     StatisticsRegistry::unregisterStat(&d_staticLearningTime);
     StatisticsRegistry::unregisterStat(&d_simpITETime);
     StatisticsRegistry::unregisterStat(&d_unconstrainedSimpTime);
@@ -847,35 +852,61 @@ void SmtEnginePrivate::nonClausalSimplify() {
     d_assertionsToPreprocess.clear();
     d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst<bool>(false));
     return;
-  } else {
-    // No, conflict, go through the literals and solve them
-    unsigned j = 0;
-    for(unsigned i = 0, i_end = d_nonClausalLearnedLiterals.size(); i < i_end; ++ i) {
-      // Simplify the literal we learned wrt previous substitutions
-      Node learnedLiteral =
-        theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(d_nonClausalLearnedLiterals[i]));
-      // It might just simplify to a constant
-      if (learnedLiteral.isConst()) {
-        if (learnedLiteral.getConst<bool>()) {
-          // If the learned literal simplifies to true, it's redundant
-          continue;
-        } else {
-          // If the learned literal simplifies to false, we're in conflict
-          Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
-                            << "conflict with "
-                            << d_nonClausalLearnedLiterals[i] << endl;
-          d_assertionsToPreprocess.clear();
-          d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst<bool>(false));
-          return;
-        }
+  }
+
+  // No, conflict, go through the literals and solve them
+  theory::SubstitutionMap constantPropagations(d_smt.d_context);
+  unsigned j = 0;
+  for(unsigned i = 0, i_end = d_nonClausalLearnedLiterals.size(); i < i_end; ++ i) {
+    // Simplify the literal we learned wrt previous substitutions
+    Node learnedLiteral = d_nonClausalLearnedLiterals[i];
+    Node learnedLiteralNew = d_topLevelSubstitutions.apply(learnedLiteral);
+    if (learnedLiteral != learnedLiteralNew) {
+      learnedLiteral = theory::Rewriter::rewrite(learnedLiteralNew);
+    }
+    for (;;) {
+      learnedLiteralNew = constantPropagations.apply(learnedLiteral);
+      if (learnedLiteralNew == learnedLiteral) {
+        break;
+      }
+      ++d_smt.d_numConstantProps;
+      learnedLiteral = theory::Rewriter::rewrite(learnedLiteralNew);
+    }
+    // It might just simplify to a constant
+    if (learnedLiteral.isConst()) {
+      if (learnedLiteral.getConst<bool>()) {
+        // If the learned literal simplifies to true, it's redundant
+        continue;
+      } else {
+        // If the learned literal simplifies to false, we're in conflict
+        Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
+                          << "conflict with "
+                          << d_nonClausalLearnedLiterals[i] << endl;
+        d_assertionsToPreprocess.clear();
+        d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst<bool>(false));
+        return;
       }
-      // Solve it with the corresponding theory
-      Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
-                        << "solving " << learnedLiteral << endl;
-      Theory::PPAssertStatus solveStatus =
-        d_smt.d_theoryEngine->solve(learnedLiteral, d_topLevelSubstitutions);
+    }
+    // Solve it with the corresponding theory
+    Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
+                      << "solving " << learnedLiteral << endl;
+    Theory::PPAssertStatus solveStatus =
+      d_smt.d_theoryEngine->solve(learnedLiteral, d_topLevelSubstitutions);
 
-      switch (solveStatus) {
+    switch (solveStatus) {
+      case Theory::PP_ASSERT_STATUS_SOLVED: {
+        // The literal should rewrite to true
+        Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
+                          << "solved " << learnedLiteral << endl;
+        Assert(theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(learnedLiteral)).isConst());
+        vector<pair<Node, Node> > equations;
+        constantPropagations.simplifyLHS(d_topLevelSubstitutions, equations, true);
+        if (equations.empty()) {
+          break;
+        }
+        Assert(equations[0].first.isConst() && equations[0].second.isConst() && equations[0].first != equations[0].second);
+        // else fall through
+      }
       case Theory::PP_ASSERT_STATUS_CONFLICT:
         // If in conflict, we return false
         Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
@@ -884,46 +915,109 @@ void SmtEnginePrivate::nonClausalSimplify() {
         d_assertionsToPreprocess.clear();
         d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst<bool>(false));
         return;
-      case Theory::PP_ASSERT_STATUS_SOLVED:
-        // The literal should rewrite to true
-        Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
-                          << "solved " << learnedLiteral << endl;
-        Assert(theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(learnedLiteral)).isConst());
-        break;
       default:
-        // Keep the literal
-        d_nonClausalLearnedLiterals[j++] = d_nonClausalLearnedLiterals[i];
+        if (d_doConstantProp && learnedLiteral.getKind() == kind::EQUAL && (learnedLiteral[0].isConst() || learnedLiteral[1].isConst())) {
+          // constant propagation
+          TNode t;
+          TNode c;
+          if (learnedLiteral[0].isConst()) {
+            t = learnedLiteral[1];
+            c = learnedLiteral[0];
+          }
+          else {
+            t = learnedLiteral[0];
+            c = learnedLiteral[1];
+          }
+          Assert(!t.isConst());
+          Assert(constantPropagations.apply(t) == t);
+          Assert(d_topLevelSubstitutions.apply(t) == t);
+          vector<pair<Node,Node> > equations;
+          constantPropagations.simplifyLHS(t, c, equations, true);
+          if (!equations.empty()) {
+            Assert(equations[0].first.isConst() && equations[0].second.isConst() && equations[0].first != equations[0].second);
+            d_assertionsToPreprocess.clear();
+            d_assertionsToCheck.push_back(NodeManager::currentNM()->mkConst<bool>(false));
+            return;
+          }
+          d_topLevelSubstitutions.simplifyRHS(constantPropagations);
+        }
+        else {
+          // Keep the literal
+          d_nonClausalLearnedLiterals[j++] = d_nonClausalLearnedLiterals[i];
+        }
         break;
+    }
+
+    if( Options::current()->incrementalSolving ||
+        Options::current()->simplificationMode == Options::SIMPLIFICATION_MODE_INCREMENTAL ) {
+      // Tell PropEngine about new substitutions
+      SubstitutionMap::iterator pos = d_lastSubstitutionPos;
+      if(pos == d_topLevelSubstitutions.end()) {
+        pos = d_topLevelSubstitutions.begin();
+      } else {
+        ++pos;
       }
 
-      if( Options::current()->incrementalSolving ||
-          Options::current()->simplificationMode == Options::SIMPLIFICATION_MODE_INCREMENTAL ) {
-        // Tell PropEngine about new substitutions
-        SubstitutionMap::iterator pos = d_lastSubstitutionPos;
-        if(pos == d_topLevelSubstitutions.end()) {
-           pos = d_topLevelSubstitutions.begin();
-        } else {
-          ++pos;
-        }
+      while(pos != d_topLevelSubstitutions.end()) {
+        // Push out this substitution
+        TNode lhs = (*pos).first, rhs = (*pos).second;
+        Node n = NodeManager::currentNM()->mkNode(lhs.getType().isBoolean() ? kind::IFF : kind::EQUAL, lhs, rhs);
+        d_assertionsToCheck.push_back(n);
+        Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): will notify SAT layer of substitution: " << n << endl;
+        d_lastSubstitutionPos = pos;
+        ++pos;
+      }
+    }
 
-        while(pos != d_topLevelSubstitutions.end()) {
-          // Push out this substitution
-          TNode lhs = (*pos).first, rhs = (*pos).second;
-          Node n = NodeManager::currentNM()->mkNode(lhs.getType().isBoolean() ? kind::IFF : kind::EQUAL, lhs, rhs);
-          d_assertionsToCheck.push_back(n);
-          Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): will notify SAT layer of substitution: " << n << endl;
-          d_lastSubstitutionPos = pos;
-          ++pos;
-        }
+#ifdef CVC4_ASSERTIONS
+    // Check data structure invariants:
+    // 1. for each lhs of d_topLevelSubstitutions, does not appear anywhere in rhs of d_topLevelSubstitutions or anywhere in constantPropagations
+    // 2. each lhs of constantPropagations rewrites to itself
+    // 3. if l -> r is a constant propagation and l is a subterm of l' with l' -> r' another constant propagation, then l'[l/r] -> r' should be a
+    //    constant propagation too
+    // 4. each lhs of constantPropagations is different from each rhs
+    SubstitutionMap::iterator pos = d_topLevelSubstitutions.begin();
+    for (; pos != d_topLevelSubstitutions.end(); ++pos) {
+      Assert((*pos).first.isVar());
+      Assert(d_topLevelSubstitutions.apply((*pos).second) == (*pos).second);
+    }
+    for (pos = constantPropagations.begin(); pos != constantPropagations.end(); ++pos) {
+      Assert((*pos).second.isConst());
+      Assert(Rewriter::rewrite((*pos).first) == (*pos).first);
+      Node newLeft = d_topLevelSubstitutions.apply((*pos).first);
+      if (newLeft != (*pos).first) {
+        newLeft = Rewriter::rewrite(newLeft);
+        Assert(newLeft == (*pos).second ||
+               (constantPropagations.hasSubstitution(newLeft) && constantPropagations.apply(newLeft) == (*pos).second));
       }
+      newLeft = constantPropagations.apply((*pos).first);
+      if (newLeft != (*pos).first) {
+        newLeft = Rewriter::rewrite(newLeft);
+        Assert(newLeft == (*pos).second ||
+               (constantPropagations.hasSubstitution(newLeft) && constantPropagations.apply(newLeft) == (*pos).second));
+      }
+      Assert(constantPropagations.apply((*pos).second) == (*pos).second);
     }
-    // Resize the learnt
-    d_nonClausalLearnedLiterals.resize(j);
+#endif
   }
+  // Resize the learnt
+  d_nonClausalLearnedLiterals.resize(j);
 
   hash_set<TNode, TNodeHashFunction> s;
   for (unsigned i = 0; i < d_assertionsToPreprocess.size(); ++ i) {
-    Node assertion = theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(d_assertionsToPreprocess[i]));
+    Node assertion = d_assertionsToPreprocess[i];
+    Node assertionNew = d_topLevelSubstitutions.apply(assertion);
+    if (assertion != assertionNew) {
+      assertion = theory::Rewriter::rewrite(assertionNew);
+    }
+    for (;;) {
+      assertionNew = constantPropagations.apply(assertion);
+      if (assertionNew == assertion) {
+        break;
+      }
+      ++d_smt.d_numConstantProps;
+      assertion = theory::Rewriter::rewrite(assertionNew);
+    }
     s.insert(assertion);
     d_assertionsToCheck.push_back(assertion);
     Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
@@ -933,7 +1027,19 @@ void SmtEnginePrivate::nonClausalSimplify() {
   d_assertionsToPreprocess.clear();
 
   for (unsigned i = 0; i < d_nonClausalLearnedLiterals.size(); ++ i) {
-    Node learned = theory::Rewriter::rewrite(d_topLevelSubstitutions.apply(d_nonClausalLearnedLiterals[i]));
+    Node learned = d_nonClausalLearnedLiterals[i];
+    Node learnedNew = d_topLevelSubstitutions.apply(learned);
+    if (learned != learnedNew) {
+      learned = theory::Rewriter::rewrite(learnedNew);
+    }
+    for (;;) {
+      learnedNew = constantPropagations.apply(learned);
+      if (learnedNew == learned) {
+        break;
+      }
+      ++d_smt.d_numConstantProps;
+      learned = theory::Rewriter::rewrite(learnedNew);
+    }
     if (s.find(learned) != s.end()) {
       continue;
     }
@@ -945,8 +1051,21 @@ void SmtEnginePrivate::nonClausalSimplify() {
   }
   d_nonClausalLearnedLiterals.clear();
 
+  SubstitutionMap::iterator pos = constantPropagations.begin();
+  for (; pos != constantPropagations.end(); ++pos) {
+    Node cProp = (*pos).first.eqNode((*pos).second);
+    if (s.find(cProp) != s.end()) {
+      continue;
+    }
+    s.insert(cProp);
+    d_assertionsToCheck.push_back(cProp);
+    Trace("simplify") << "SmtEnginePrivate::nonClausalSimplify(): "
+                      << "non-clausal constant propagation : "
+                      << d_assertionsToCheck.back() << endl;
+  }
 }
 
+
 void SmtEnginePrivate::simpITE()
 {
   TimerStat::CodeTimer simpITETimer(d_smt.d_simpITETime);
index 4c0fed74cfd65374434178ea4a34da144def802f..e0f544f9e02d394adb3fe4342645817a4761cefb 100644 (file)
@@ -239,6 +239,8 @@ class CVC4_PUBLIC SmtEngine {
   TimerStat d_definitionExpansionTime;
   /** time spent in non-clausal simplification */
   TimerStat d_nonclausalSimplificationTime;
+  /** Num of constant propagations found during nonclausal simp */
+  IntStat d_numConstantProps;
   /** time spent in static learning */
   TimerStat d_staticLearningTime;
   /** time spent in simplifying ITEs */
index df4c919d8e8a662520d29ade2199ec7c8e38d5f5..3f8a6d630b933a792b141286f7c6fb755f8232e0 100644 (file)
@@ -17,6 +17,7 @@
  **/
 
 #include "theory/substitutions.h"
+#include "theory/rewriter.h"
 
 using namespace std;
 
@@ -103,21 +104,127 @@ Node SubstitutionMap::internalSubstitute(TNode t, NodeCache& substitutionCache)
   return substitutionCache[t];
 }
 
-void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache, bool backSub, bool forwardSub) {
-  Debug("substitution") << "SubstitutionMap::addSubstitution(" << x << ", " << t << ")" << std::endl;
-  Assert(d_substitutions.find(x) == d_substitutions.end());
 
-  if (backSub) {
-    // Temporary substitution cache
+void SubstitutionMap::simplifyRHS(const SubstitutionMap& subMap)
+{
+  // Put the new substitutions into the old ones
+  NodeMap::iterator it = d_substitutions.begin();
+  NodeMap::iterator it_end = d_substitutions.end();
+  for(; it != it_end; ++ it) {
+    d_substitutions[(*it).first] = subMap.apply((*it).second);
+  }  
+}
+
+
+void SubstitutionMap::simplifyRHS(TNode x, TNode t) {
+  // Temporary substitution cache
+  NodeCache tempCache;
+  tempCache[x] = t;
+
+  // Put the new substitution into the old ones
+  NodeMap::iterator it = d_substitutions.begin();
+  NodeMap::iterator it_end = d_substitutions.end();
+  for(; it != it_end; ++ it) {
+    d_substitutions[(*it).first] = internalSubstitute((*it).second, tempCache);
+  }  
+}
+
+
+/* We use subMap to simplify the left-hand sides of the current substitution map.  If rewrite is true,
+ * we also apply the rewriter to the result.
+ * We want to maintain the invariant that all lhs are distinct from each other and from all rhs.
+ * If for some l -> r, l reduces to l', we try to add a new rule l' -> r.  There are two cases
+ * where this fails
+ *   (i) if l' is equal to some ll (in a rule ll -> rr), then if r != rr we add (r,rr) to the equation list
+ *   (i) if l' is equalto some rr (in a rule ll -> rr), then if r != rr we add (r,rr) to the equation list
+ */
+void SubstitutionMap::simplifyLHS(const SubstitutionMap& subMap, vector<pair<Node, Node> >& equalities, bool rewrite)
+{
+  Assert(d_worklist.empty());
+  // First, apply subMap to every LHS in d_substitutions
+  NodeMap::iterator it = d_substitutions.begin();
+  NodeMap::iterator it_end = d_substitutions.end();
+  Node newLeft;
+  for(; it != it_end; ++ it) {
+    newLeft = subMap.apply((*it).first);
+    if (newLeft != (*it).first) {
+      if (rewrite) {
+        newLeft = Rewriter::rewrite(newLeft);
+      }
+      d_worklist.push_back(pair<Node,Node>(newLeft, (*it).second));
+    }
+  }
+  processWorklist(equalities, rewrite);
+  Assert(d_worklist.empty());
+}
+
+
+void SubstitutionMap::simplifyLHS(TNode lhs, TNode rhs, vector<pair<Node,Node> >& equalities, bool rewrite)
+{
+  Assert(d_worklist.empty());
+  d_worklist.push_back(pair<Node,Node>(lhs,rhs));
+  processWorklist(equalities, rewrite);                       
+  Assert(d_worklist.empty());
+}
+
+
+void SubstitutionMap::processWorklist(vector<pair<Node, Node> >& equalities, bool rewrite)
+{
+  // Add each new rewrite rule, taking care not to invalidate invariants and looking
+  // for any new rewrite rules we can learn
+  Node newLeft, newRight;
+  while (!d_worklist.empty()) {
+    newLeft = d_worklist.back().first;
+    newRight = d_worklist.back().second;
+    d_worklist.pop_back();
+
     NodeCache tempCache;
-    tempCache[x] = t;
+    tempCache[newLeft] = newRight;
 
-    // Put in the new substitutions into the old ones
+    Node newLeft2;
+    unsigned size = d_worklist.size();
+    bool addThisRewrite = true;
     NodeMap::iterator it = d_substitutions.begin();
     NodeMap::iterator it_end = d_substitutions.end();
+
     for(; it != it_end; ++ it) {
-      d_substitutions[(*it).first] = internalSubstitute((*it).second, tempCache);
+
+      // Check for invariant violation.  If new rewrite is redundant, do nothing
+      // Otherwise, add an equality to the output equalities
+      // In either case undo any work done by this rewrite
+      if (newLeft == (*it).first || newLeft == (*it).second) {
+        if ((*it).second != newRight) {
+          equalities.push_back(pair<Node,Node>((*it).second, newRight));
+        }
+        while (d_worklist.size() > size) {
+          d_worklist.pop_back();
+        }
+        addThisRewrite = false;
+        break;
+      }
+
+      newLeft2 = internalSubstitute((*it).first, tempCache);
+      if (newLeft2 != (*it).first) {
+        if (rewrite) {
+          newLeft2 = Rewriter::rewrite(newLeft2);
+        }
+        d_worklist.push_back(pair<Node,Node>(newLeft2, (*it).second));
+      }
     }
+    if (addThisRewrite) {
+      d_substitutions[newLeft] = newRight;
+      d_cacheInvalidated = true;
+    }
+  }
+}
+
+
+void SubstitutionMap::addSubstitution(TNode x, TNode t, bool invalidateCache, bool backSub, bool forwardSub) {
+  Debug("substitution") << "SubstitutionMap::addSubstitution(" << x << ", " << t << ")" << std::endl;
+  Assert(d_substitutions.find(x) == d_substitutions.end());
+
+  if (backSub) {
+    simplifyRHS(x, t);
   }
 
   // Put the new substitution in
@@ -181,6 +288,10 @@ void SubstitutionMap::print(ostream& out) const {
   }
 }
 
+void SubstitutionMap::debugPrint() const {
+  print(std::cout);
+}
+
 }/* CVC4::theory namespace */
 
 std::ostream& operator<<(std::ostream& out, const theory::SubstitutionMap::iterator& i) {
index ee2a15f6f00031740b815316434626e40704038a..32ed35074befa089ca743edd77241eb7d85de13e 100644 (file)
@@ -91,6 +91,10 @@ private:
    */
   CacheInvalidator d_cacheInvalidator;
 
+  // Helper list and method for simplifyLHS methods
+  std::vector<std::pair<Node, Node> > d_worklist;
+  void processWorklist(std::vector<std::pair<Node, Node> >& equalities, bool rewrite);
+
 public:
 
   SubstitutionMap(context::Context* context) :
@@ -158,10 +162,27 @@ public:
   // should best interact with cache invalidation on context
   // pops.
 
+  // Simplify right-hand sides of current map using the given substitutions
+  void simplifyRHS(const SubstitutionMap& subMap);
+
+  // Simplify right-hand sides of current map with lhs -> rhs
+  void simplifyRHS(TNode lhs, TNode rhs);
+
+  // Simplify left-hand sides of current map using the given substitutions
+  void simplifyLHS(const SubstitutionMap& subMap,
+                   std::vector<std::pair<Node,Node> >& equalities,
+                   bool rewrite = true);
+
+  // Simplify left-hand sides of current map with lhs -> rhs and then add lhs -> rhs to the substitutions set
+  void simplifyLHS(TNode lhs, TNode rhs,
+                   std::vector<std::pair<Node,Node> >& equalities,
+                   bool rewrite = true);
+
   /**
    * Print to the output stream
    */
   void print(std::ostream& out) const;
+  void debugPrint() const;
 
 };/* class SubstitutionMap */