From a84864da338f74958c6754696d98cd6355e798a8 Mon Sep 17 00:00:00 2001 From: Clark Barrett Date: Sat, 1 Nov 2014 20:05:02 -0700 Subject: [PATCH] Added cache to getModelValue --- src/theory/theory_model.cpp | 63 ++++++++++++++++++++++++------------- src/theory/theory_model.h | 1 + 2 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/theory/theory_model.cpp b/src/theory/theory_model.cpp index 70ae2c03b..54a647d89 100644 --- a/src/theory/theory_model.cpp +++ b/src/theory/theory_model.cpp @@ -53,6 +53,7 @@ TheoryModel::~TheoryModel() { } void TheoryModel::reset(){ + d_modelCache.clear(); d_reps.clear(); d_rep_set.clear(); d_uf_terms.clear(); @@ -98,6 +99,11 @@ Cardinality TheoryModel::getCardinality( Type t ) const{ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const { + std::hash_map::iterator it = d_modelCache.find(n); + if (it != d_modelCache.end()) { + return (*it).second; + } + Node ret = n; if(n.getKind() == kind::EXISTS || n.getKind() == kind::FORALL) { // We should have terms, thanks to TheoryQuantifiers::collectModelInfo(). // However, if the Decision Engine stops us early, there might be a @@ -114,19 +120,23 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const // checkModel(), and the quantifier actually matters, we'll get an // assert-fail since the quantifier isn't a constant. if(!d_equalityEngine->hasTerm(Rewriter::rewrite(n))) { - return n; + d_modelCache[n] = ret; + return ret; } else { - n = Rewriter::rewrite(n); + ret = Rewriter::rewrite(n); } } else { if(n.getKind() == kind::LAMBDA) { NodeManager* nm = NodeManager::currentNM(); Node body = getModelValue(n[1], true); body = Rewriter::rewrite(body); - return nm->mkNode(kind::LAMBDA, n[0], body); + ret = nm->mkNode(kind::LAMBDA, n[0], body); + d_modelCache[n] = ret; + return ret; } if(n.isConst() || (hasBoundVars && n.getKind() == kind::BOUND_VARIABLE)) { - return n; + d_modelCache[n] = ret; + return ret; } TypeNode t = n.getType(); @@ -135,7 +145,9 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const std::map< Node, Node >::const_iterator it = d_uf_models.find(n); if (it != d_uf_models.end()) { // Existing function - return it->second; + ret = it->second; + d_modelCache[n] = ret; + return ret; } // Unknown function symbol: return LAMBDA x. c, where c is the first constant in the enumeration of the range type vector argTypes = t.getArgTypes(); @@ -146,7 +158,9 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const } Node boundVarList = nm->mkNode(kind::BOUND_VAR_LIST, args); TypeEnumerator te(t.getRangeType()); - return nm->mkNode(kind::LAMBDA, boundVarList, *te); + ret = nm->mkNode(kind::LAMBDA, boundVarList, *te); + d_modelCache[n] = ret; + return ret; } // TODO: if func models not enabled, throw an error? Unreachable(); @@ -165,34 +179,39 @@ Node TheoryModel::getModelValue(TNode n, bool hasBoundVars) const } //evaluate the children for (unsigned i = 0; i < n.getNumChildren(); ++i) { - Node val = getModelValue(n[i], hasBoundVars); - children.push_back(val); + ret = getModelValue(n[i], hasBoundVars); + children.push_back(ret); } - Node val = Rewriter::rewrite(NodeManager::currentNM()->mkNode(n.getKind(), children)); - if(val.getKind() == kind::CARDINALITY_CONSTRAINT) { - val = NodeManager::currentNM()->mkConst(getCardinality(val[0].getType().toType()).getFiniteCardinality() <= val[1].getConst().getNumerator()); + ret = Rewriter::rewrite(NodeManager::currentNM()->mkNode(n.getKind(), children)); + if(ret.getKind() == kind::CARDINALITY_CONSTRAINT) { + ret = NodeManager::currentNM()->mkConst(getCardinality(ret[0].getType().toType()).getFiniteCardinality() <= ret[1].getConst().getNumerator()); } - if(val.getKind() == kind::COMBINED_CARDINALITY_CONSTRAINT ){ + if(ret.getKind() == kind::COMBINED_CARDINALITY_CONSTRAINT ){ //FIXME - val = NodeManager::currentNM()->mkConst(false); + ret = NodeManager::currentNM()->mkConst(false); } - return val; + d_modelCache[n] = ret; + return ret; } if (!d_equalityEngine->hasTerm(n)) { // Unknown term - return first enumerated value for this type TypeEnumerator te(n.getType()); - return *te; + ret = *te; + d_modelCache[n] = ret; + return ret; } } - Node val = d_equalityEngine->getRepresentative(n); - Assert(d_reps.find(val) != d_reps.end()); - std::map< Node, Node >::const_iterator it = d_reps.find( val ); - if( it!=d_reps.end() ){ - return it->second; - }else{ - return Node::null(); + ret = d_equalityEngine->getRepresentative(ret); + Assert(d_reps.find(ret) != d_reps.end()); + std::map< Node, Node >::const_iterator it2 = d_reps.find( ret ); + if (it2 != d_reps.end()) { + ret = it2->second; + } else { + ret = Node::null(); } + d_modelCache[n] = ret; + return ret; } Node TheoryModel::getDomainValue( TypeNode tn, std::vector< Node >& exclude ){ diff --git a/src/theory/theory_model.h b/src/theory/theory_model.h index 2eb33b0fb..eeaf3c8da 100644 --- a/src/theory/theory_model.h +++ b/src/theory/theory_model.h @@ -52,6 +52,7 @@ public: Node d_true; Node d_false; context::CDO d_modelBuilt; + mutable std::hash_map d_modelCache; protected: /** reset the model */ -- 2.30.2