Boolean terms rewriting for quantified variables of type Bool, when quantifier body...
authorMorgan Deters <mdeters@cs.nyu.edu>
Fri, 15 Mar 2013 17:30:05 +0000 (13:30 -0400)
committerMorgan Deters <mdeters@cs.nyu.edu>
Fri, 15 Mar 2013 17:30:05 +0000 (13:30 -0400)
src/smt/boolean_terms.cpp
src/smt/boolean_terms.h

index 262244f421c6800efde2ec589b39248f5e637014..35184e42e08dc8b1b5693543325aef817a30dfe0 100644 (file)
 #include "theory/theory_engine.h"
 #include "theory/model.h"
 #include "expr/kind.h"
+#include "util/hash.h"
+#include "util/bool.h"
 #include <string>
 #include <algorithm>
+#include <set>
+#include <map>
 
 using namespace std;
 using namespace CVC4::theory;
@@ -104,8 +108,34 @@ const Datatype& BooleanTermConverter::booleanTermsConvertDatatype(const Datatype
   return dt;
 }/* BooleanTermConverter::booleanTermsConvertDatatype() */
 
-Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw() {
-  Debug("boolean-terms") << "rewriteBooleanTerms: " << top << " - boolParent=" << boolParent << endl;
+// look for vars from "vars" that occur in a term-context in n; transfer them to output.
+static void collectVarsInTermContext(TNode n, std::set<TNode>& vars, std::set<TNode>& output, bool boolParent, std::hash_set< std::pair<TNode, bool>, PairHashFunction<TNode, bool, TNodeHashFunction, BoolHashFunction> >& alreadySeen) {
+  if(vars.empty()) {
+    return;
+  }
+  const pair<TNode, bool> cacheKey(n, boolParent);
+  if(alreadySeen.find(cacheKey) != alreadySeen.end()) {
+    return;
+  }
+  alreadySeen.insert(cacheKey);
+
+  if(n.isVar() && vars.find(n) != vars.end() && !boolParent) {
+    vars.erase(n);
+    output.insert(n);
+    if(vars.empty()) {
+      return;
+    }
+  }
+  for(size_t i = 0; i < n.getNumChildren(); ++i) {
+    collectVarsInTermContext(n[i], vars, output, isBoolean(n, i), alreadySeen);
+    if(vars.empty()) {
+      return;
+    }
+  }
+}
+
+Node BooleanTermConverter::rewriteBooleanTermsRec(TNode top, bool boolParent, std::map<TNode, Node>& quantBoolVars) throw() {
+  Debug("boolean-terms") << "rewriteBooleanTermsRec: " << top << " - boolParent=" << boolParent << endl;
 
   BooleanTermCache::iterator i = d_booleanTermCache.find(make_pair<Node, bool>(top, boolParent));
   if(i != d_booleanTermCache.end()) {
@@ -117,11 +147,21 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
 
   NodeManager* nm = NodeManager::currentNM();
 
+  Node one = nm->mkConst(BitVector(1u, 1u));
+  Node zero = nm->mkConst(BitVector(1u, 0u));
+
+  if(quantBoolVars.find(top) != quantBoolVars.end()) {
+    // this Bool variable is quantified over and we're changing it to a BitVector var
+    if(boolParent) {
+      return quantBoolVars[top].eqNode(one);
+    } else {
+      return quantBoolVars[top];
+    }
+  }
+
   if(!boolParent && top.getType().isBoolean()) {
-    Node one = nm->mkConst(BitVector(1u, 1u));
-    Node zero = nm->mkConst(BitVector(1u, 0u));
     // still need to rewrite e.g. function applications over boolean
-    Node topRewritten = rewriteBooleanTerms(top, true);
+    Node topRewritten = rewriteBooleanTermsRec(top, true, quantBoolVars);
     Node n = nm->mkNode(kind::ITE, topRewritten, one, zero);
     Debug("boolean-terms") << "constructed ITE: " << n << endl;
     return n;
@@ -143,7 +183,7 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
         }
         ArrayStoreAll asaRepl(nm->mkArrayType(indexType, nm->mkBitVectorType(1)).toType(), newConst.toExpr());
         Node n = nm->mkConst(asaRepl);
-        Debug("boolean-terms") << " returning new store_all: " << n << std::endl;
+        Debug("boolean-terms") << " returning new store_all: " << n << endl;
         return n;
       }
       if(indexType.isBoolean()) {
@@ -151,7 +191,7 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
         indexType = nm->mkBitVectorType(1);
         ArrayStoreAll asaRepl(nm->mkArrayType(indexType, TypeNode::fromType(constituentType)).toType(), asa.getExpr());
         Node n = nm->mkConst(asaRepl);
-        Debug("boolean-terms") << " returning new store_all: " << n << std::endl;
+        Debug("boolean-terms") << " returning new store_all: " << n << endl;
         return n;
       }
     }
@@ -176,8 +216,7 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
             Node var = nm->mkBoundVar(t[j]);
             boundVarsBuilder << var;
             if(t[j].isBoolean()) {
-              bodyBuilder << nm->mkNode(kind::ITE, var, nm->mkConst(BitVector(1u, 1u)),
-                                        nm->mkConst(BitVector(1u, 0u)));
+              bodyBuilder << nm->mkNode(kind::ITE, var, one, zero);
             } else {
               bodyBuilder << var;
             }
@@ -185,7 +224,7 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
           Node boundVars = boundVarsBuilder;
           Node body = bodyBuilder;
           Node lam = nm->mkNode(kind::LAMBDA, boundVars, body);
-          Debug("boolean-terms") << "substituting " << top << " ==> " << lam << std::endl;
+          Debug("boolean-terms") << "substituting " << top << " ==> " << lam << endl;
           d_smt.d_theoryEngine->getModel()->addSubstitution(top, lam);
           d_booleanTermCache[make_pair(top, boolParent)] = n;
           return n;
@@ -210,8 +249,6 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
                               NodeManager::SKOLEM_EXACT_NAME);
         top.setAttribute(BooleanTermAttr(), n);
         Debug("boolean-terms") << "constructed: " << n << " of type " << newType << endl;
-        Node one = nm->mkConst(BitVector(1u, 1u));
-        Node zero = nm->mkConst(BitVector(1u, 0u));
         Node n_zero = nm->mkNode(kind::SELECT, n, zero);
         Node n_one = nm->mkNode(kind::SELECT, n, one);
         Node base = nm->mkConst(ArrayStoreAll(ArrayType(top.getType().toType()), nm->mkConst(false).toExpr()));
@@ -308,9 +345,60 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
     // not yet supported
     return top;
 
+  case kind::FORALL:
+  case kind::EXISTS: {
+    Debug("bt") << "looking at quantifier -> " << top << endl;
+    set<TNode> ourVars;
+    for(TNode::iterator i = top[0].begin(); i != top[0].end(); ++i) {
+      if((*i).getType().isBoolean()) {
+        ourVars.insert(*i);
+      }
+    }
+    if(ourVars.empty()) {
+      // Simple case, quantifier doesn't quantify over Boolean vars,
+      // no special handling needed for quantifier.  Fall through.
+      Debug("bt") << "- quantifier simple case (1), no Boolean vars bound" << endl;
+    } else {
+      set<TNode> output;
+      hash_set< pair<TNode, bool>, PairHashFunction<TNode, bool, TNodeHashFunction, BoolHashFunction> > alreadySeen;
+      collectVarsInTermContext(top[1], ourVars, output, true, alreadySeen);
+      if(output.empty()) {
+        // Simple case, quantifier quantifies over Boolean vars, but they
+        // don't occur in term context.  Fall through.
+        Debug("bt") << "- quantifier simple case (2), Boolean vars bound but not used in term context" << endl;
+      } else {
+        Debug("bt") << "- quantifier case (3), Boolean vars bound and used in term context" << endl;
+        // We have Boolean vars appearing in term context.  Convert their
+        // types in the quantifier.
+        for(set<TNode>::const_iterator i = output.begin(); i != output.end(); ++i) {
+          Node newVar = nm->mkBoundVar((*i).toString(), nm->mkBitVectorType(1));
+          Assert(quantBoolVars.find(*i) == quantBoolVars.end(), "bad quantifier: shares a bound var with another quantifier (don't do that!)");
+          quantBoolVars[*i] = newVar;
+        }
+        vector<TNode> boundVars;
+        for(TNode::iterator i = top[0].begin(); i != top[0].end(); ++i) {
+          map<TNode, Node>::const_iterator j = quantBoolVars.find(*i);
+          if(j == quantBoolVars.end()) {
+            boundVars.push_back(*i);
+          } else {
+            boundVars.push_back((*j).second);
+          }
+        }
+        Node boundVarList = nm->mkNode(kind::BOUND_VAR_LIST, boundVars);
+        Node body = rewriteBooleanTermsRec(top[1], true, quantBoolVars);
+        Node quant = nm->mkNode(top.getKind(), boundVarList, body);
+        Debug("bt") << "rewrote quantifier to -> " << quant << endl;
+        d_booleanTermCache[make_pair(top, true)] = quant;
+        d_booleanTermCache[make_pair(top, false)] = quant.iteNode(one, zero);
+        return quant;
+      }
+    }
+    /* intentional fall-through for some cases above */
+  }
+
   default:
     NodeBuilder<> b(k);
-    Debug("bt") << "looking at: " << top << std::endl;
+    Debug("bt") << "looking at: " << top << endl;
     if(mk == kind::metakind::PARAMETERIZED) {
       if(kindToTheoryId(k) != THEORY_BV &&
          k != kind::APPLY_TYPE_ASCRIPTION &&
@@ -319,29 +407,29 @@ Node BooleanTermConverter::rewriteBooleanTerms(TNode top, bool boolParent) throw
          k != kind::RECORD_SELECT &&
          k != kind::RECORD_UPDATE &&
          k != kind::RECORD) {
-        Debug("bt") << "rewriting: " << top.getOperator() << std::endl;
-        b << rewriteBooleanTerms(top.getOperator(), false);
-        Debug("bt") << "got: " << b.getOperator() << std::endl;
+        Debug("bt") << "rewriting: " << top.getOperator() << endl;
+        b << rewriteBooleanTermsRec(top.getOperator(), false, quantBoolVars);
+        Debug("bt") << "got: " << b.getOperator() << endl;
       } else {
         b << top.getOperator();
       }
     }
     for(unsigned i = 0; i < top.getNumChildren(); ++i) {
-      Debug("bt") << "rewriting: " << top[i] << std::endl;
-      b << rewriteBooleanTerms(top[i], isBoolean(top, i));
-      Debug("bt") << "got: " << b[b.getNumChildren() - 1] << std::endl;
+      Debug("bt") << "rewriting: " << top[i] << endl;
+      b << rewriteBooleanTermsRec(top[i], isBoolean(top, i), quantBoolVars);
+      Debug("bt") << "got: " << b[b.getNumChildren() - 1] << endl;
     }
     Node n = b;
     Debug("boolean-terms") << "constructed: " << n << endl;
     if(boolParent &&
        n.getType().isBitVector() &&
        n.getType().getBitVectorSize() == 1) {
-      n = nm->mkNode(kind::EQUAL, n, nm->mkConst(BitVector(1u, 1u)));
+      n = nm->mkNode(kind::EQUAL, n, one);
     }
     d_booleanTermCache[make_pair(top, boolParent)] = n;
     return n;
   }
-}/* BooleanTermConverter::rewriteBooleanTerms() */
+}/* BooleanTermConverter::rewriteBooleanTermsRec() */
 
 }/* CVC4::smt namespace */
 }/* CVC4 namespace */
index e51a7bbb05578b2b3cf22dcf9cfb397bae5364c3..c53eadfa0a31e50f077bf0ae72470eac46ea4740 100644 (file)
@@ -24,6 +24,7 @@
 #include "expr/attribute.h"
 #include "expr/node.h"
 #include "util/hash.h"
+#include <map>
 #include <utility>
 
 namespace CVC4 {
@@ -52,6 +53,8 @@ class BooleanTermConverter {
    */
   const Datatype& booleanTermsConvertDatatype(const Datatype& dt) throw();
 
+  Node rewriteBooleanTermsRec(TNode n, bool boolParent, std::map<TNode, Node>& quantBoolVars) throw();
+
 public:
 
   BooleanTermConverter(SmtEngine& smt) :
@@ -61,7 +64,10 @@ public:
   /**
    * We rewrite Boolean terms in assertions as bitvectors of length 1.
    */
-  Node rewriteBooleanTerms(TNode n, bool boolParent = true) throw();
+  Node rewriteBooleanTerms(TNode n, bool boolParent = true) throw() {
+    std::map<TNode, Node> quantBoolVars;
+    return rewriteBooleanTermsRec(n, boolParent, quantBoolVars);
+  }
 
 };/* class BooleanTermConverter */