Added cache to getModelValue
authorClark Barrett <barrett@cs.nyu.edu>
Sun, 2 Nov 2014 03:05:02 +0000 (20:05 -0700)
committerClark Barrett <barrett@cs.nyu.edu>
Sun, 2 Nov 2014 03:05:02 +0000 (20:05 -0700)
src/theory/theory_model.cpp
src/theory/theory_model.h

index 70ae2c03bc74aa54385905386dbbc7be08e6fb2f..54a647d896e3bacaccf27db59817f128d2f5f8ca 100644 (file)
@@ -53,6 +53,7 @@ TheoryModel::~TheoryModel() {
 }
 
 void TheoryModel::reset(){
+  d_modelCache.clear();
   d_reps.clear();
   d_rep_set.clear();
   d_uf_terms.clear();
@@ -98,6 +99,11 @@ Cardinality TheoryModel::getCardinality( Type t ) const{
 
 Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const
 {
+  std::hash_map<Node, Node, NodeHashFunction>::iterator it = d_modelCache.find(n);
+  if (it != d_modelCache.end()) {
+    return (*it).second;
+  }
+  Node ret = n;
   if(n.getKind() == kind::EXISTS || n.getKind() == kind::FORALL) {
     // We should have terms, thanks to TheoryQuantifiers::collectModelInfo().
     // However, if the Decision Engine stops us early, there might be a
@@ -114,19 +120,23 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const
     // checkModel(), and the quantifier actually matters, we'll get an
     // assert-fail since the quantifier isn't a constant.
     if(!d_equalityEngine->hasTerm(Rewriter::rewrite(n))) {
-      return n;
+      d_modelCache[n] = ret;
+      return ret;
     } else {
-      n = Rewriter::rewrite(n);
+      ret = Rewriter::rewrite(n);
     }
   } else {
     if(n.getKind() == kind::LAMBDA) {
       NodeManager* nm = NodeManager::currentNM();
       Node body = getModelValue(n[1], true);
       body = Rewriter::rewrite(body);
-      return nm->mkNode(kind::LAMBDA, n[0], body);
+      ret = nm->mkNode(kind::LAMBDA, n[0], body);
+      d_modelCache[n] = ret;
+      return ret;
     }
     if(n.isConst() || (hasBoundVars && n.getKind() == kind::BOUND_VARIABLE)) {
-      return n;
+      d_modelCache[n] = ret;
+      return ret;
     }
 
     TypeNode t = n.getType();
@@ -135,7 +145,9 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const
         std::map< Node, Node >::const_iterator it = d_uf_models.find(n);
         if (it != d_uf_models.end()) {
           // Existing function
-          return it->second;
+          ret = it->second;
+          d_modelCache[n] = ret;
+          return ret;
         }
         // Unknown function symbol: return LAMBDA x. c, where c is the first constant in the enumeration of the range type
         vector<TypeNode> argTypes = t.getArgTypes();
@@ -146,7 +158,9 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const
         }
         Node boundVarList = nm->mkNode(kind::BOUND_VAR_LIST, args);
         TypeEnumerator te(t.getRangeType());
-        return nm->mkNode(kind::LAMBDA, boundVarList, *te);
+        ret = nm->mkNode(kind::LAMBDA, boundVarList, *te);
+        d_modelCache[n] = ret;
+        return ret;
       }
       // TODO: if func models not enabled, throw an error?
       Unreachable();
@@ -165,34 +179,39 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const
       }
       //evaluate the children
       for (unsigned i = 0; i < n.getNumChildren(); ++i) {
-        Node val = getModelValue(n[i], hasBoundVars);
-        children.push_back(val);
+        ret = getModelValue(n[i], hasBoundVars);
+        children.push_back(ret);
       }
-      Node val = Rewriter::rewrite(NodeManager::currentNM()->mkNode(n.getKind(), children));
-      if(val.getKind() == kind::CARDINALITY_CONSTRAINT) {
-        val = NodeManager::currentNM()->mkConst(getCardinality(val[0].getType().toType()).getFiniteCardinality() <= val[1].getConst<Rational>().getNumerator());
+      ret = Rewriter::rewrite(NodeManager::currentNM()->mkNode(n.getKind(), children));
+      if(ret.getKind() == kind::CARDINALITY_CONSTRAINT) {
+        ret = NodeManager::currentNM()->mkConst(getCardinality(ret[0].getType().toType()).getFiniteCardinality() <= ret[1].getConst<Rational>().getNumerator());
       }
-      if(val.getKind() == kind::COMBINED_CARDINALITY_CONSTRAINT ){
+      if(ret.getKind() == kind::COMBINED_CARDINALITY_CONSTRAINT ){
         //FIXME
-        val = NodeManager::currentNM()->mkConst(false);
+        ret = NodeManager::currentNM()->mkConst(false);
       }
-      return val;
+      d_modelCache[n] = ret;
+      return ret;
     }
 
     if (!d_equalityEngine->hasTerm(n)) {
       // Unknown term - return first enumerated value for this type
       TypeEnumerator te(n.getType());
-      return *te;
+      ret = *te;
+      d_modelCache[n] = ret;
+      return ret;
     }
   }
-  Node val = d_equalityEngine->getRepresentative(n);
-  Assert(d_reps.find(val) != d_reps.end());
-  std::map< Node, Node >::const_iterator it = d_reps.find( val );
-  if( it!=d_reps.end() ){
-    return it->second;
-  }else{
-    return Node::null();
+  ret = d_equalityEngine->getRepresentative(ret);
+  Assert(d_reps.find(ret) != d_reps.end());
+  std::map< Node, Node >::const_iterator it2 = d_reps.find( ret );
+  if (it2 != d_reps.end()) {
+    ret = it2->second;
+  } else {
+    ret = Node::null();
   }
+  d_modelCache[n] = ret;
+  return ret;
 }
 
 Node TheoryModel::getDomainValue( TypeNode tn, std::vector< Node >& exclude ){
index 2eb33b0fbb05e6668c337b32bea483a5066dbb9f..eeaf3c8da65922be37a3a2b9a0716f61f4ca7108 100644 (file)
@@ -52,6 +52,7 @@ public:
   Node d_true;
   Node d_false;
   context::CDO<bool> d_modelBuilt;
+  mutable std::hash_map<Node, Node, NodeHashFunction> d_modelCache;
 
 protected:
   /** reset the model */