External cache argument for evaluator (#3672)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 30 Jan 2020 15:41:12 +0000 (09:41 -0600)
committerGitHub <noreply@github.com>
Thu, 30 Jan 2020 15:41:12 +0000 (09:41 -0600)
src/theory/evaluator.cpp
src/theory/evaluator.h

index f95160df700fe9531f7e87df73bf0a19e86ad466..a3ea768d415c0b078c5975de133f48b85593e10b 100644 (file)
@@ -117,11 +117,34 @@ Node EvalResult::toNode() const
 Node Evaluator::eval(TNode n,
                      const std::vector<Node>& args,
                      const std::vector<Node>& vals) const
+{
+  std::unordered_map<Node, Node, NodeHashFunction> visited;
+  return eval(n, args, vals, visited);
+}
+Node Evaluator::eval(
+    TNode n,
+    const std::vector<Node>& args,
+    const std::vector<Node>& vals,
+    const std::unordered_map<Node, Node, NodeHashFunction>& visited) const
 {
   Trace("evaluator") << "Evaluating " << n << " under substitution " << args
-                     << " " << vals << std::endl;
+                     << " " << vals << " with visited size = " << visited.size()
+                     << std::endl;
   std::unordered_map<TNode, Node, NodeHashFunction> evalAsNode;
-  Node ret = evalInternal(n, args, vals, evalAsNode).toNode();
+  std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
+  // add visited to results
+  for (const std::pair<const Node, Node>& p : visited)
+  {
+    Trace("evaluator") << "Add " << p.first << " == " << p.second << std::endl;
+    results[p.first] = evalInternal(p.second, args, vals, evalAsNode, results);
+    if (results[p.first].d_tag == EvalResult::INVALID)
+    {
+      // could not evaluate, use the evalAsNode map
+      evalAsNode[p.first] = evalAsNode[p.second];
+    }
+  }
+  Trace("evaluator") << "Run eval internal..." << std::endl;
+  Node ret = evalInternal(n, args, vals, evalAsNode, results).toNode();
   // if we failed to evaluate
   if (ret.isNull())
   {
@@ -142,9 +165,9 @@ EvalResult Evaluator::evalInternal(
     TNode n,
     const std::vector<Node>& args,
     const std::vector<Node>& vals,
-    std::unordered_map<TNode, Node, NodeHashFunction>& evalAsNode) const
+    std::unordered_map<TNode, Node, NodeHashFunction>& evalAsNode,
+    std::unordered_map<TNode, EvalResult, TNodeHashFunction>& results) const
 {
-  std::unordered_map<TNode, EvalResult, TNodeHashFunction> results;
   std::vector<TNode> queue;
   queue.emplace_back(n);
   std::unordered_map<TNode, EvalResult, TNodeHashFunction>::iterator itr;
@@ -164,15 +187,20 @@ EvalResult Evaluator::evalInternal(
     if (currNode.getMetaKind() == kind::metakind::PARAMETERIZED)
     {
       TNode op = currNode.getOperator();
-      itr = results.find(op);
-      if (itr == results.end())
+      // Certain nodes are parameterized with constant operators, including
+      // bitvector extract. These operators do not need to be evaluated.
+      if (!op.isConst())
       {
-        queue.emplace_back(op);
-        doProcess = false;
-      }
-      else if (itr->second.d_tag == EvalResult::INVALID)
-      {
-        doEval = false;
+        itr = results.find(op);
+        if (itr == results.end())
+        {
+          queue.emplace_back(op);
+          doProcess = false;
+        }
+        else if (itr->second.d_tag == EvalResult::INVALID)
+        {
+          doEval = false;
+        }
       }
     }
     for (const auto& currNodeChild : currNode)
@@ -287,8 +315,9 @@ EvalResult Evaluator::evalInternal(
         // be cached. We could alternatively copy evalAsNode to evalAsNodeC but
         // favor avoiding this copy for performance reasons.
         std::unordered_map<TNode, Node, NodeHashFunction> evalAsNodeC;
+        std::unordered_map<TNode, EvalResult, TNodeHashFunction> resultsC;
         results[currNode] =
-            evalInternal(op[1], lambdaArgs, lambdaVals, evalAsNodeC);
+            evalInternal(op[1], lambdaArgs, lambdaVals, evalAsNodeC, resultsC);
         Trace("evaluator") << "Evaluated via arguments to "
                            << results[currNode].d_tag << std::endl;
         if (results[currNode].d_tag == EvalResult::INVALID)
@@ -807,19 +836,26 @@ Node Evaluator::reconstruct(
   if (n.getMetaKind() == kind::metakind::PARAMETERIZED)
   {
     TNode op = n.getOperator();
-    itr = eresults.find(op);
-    Assert(itr != eresults.end());
-    if (itr->second.d_tag == EvalResult::INVALID)
+    if (op.isConst())
     {
-      // could not evaluate the operator, look in the node cache
-      itn = evalAsNode.find(op);
-      Assert(itn != evalAsNode.end());
-      echildren.push_back(itn->second);
+      echildren.push_back(op);
     }
     else
     {
-      // otherwise, use the evaluation of the operator
-      echildren.push_back(itr->second.toNode());
+      itr = eresults.find(op);
+      Assert(itr != eresults.end());
+      if (itr->second.d_tag == EvalResult::INVALID)
+      {
+        // could not evaluate the operator, look in the node cache
+        itn = evalAsNode.find(op);
+        Assert(itn != evalAsNode.end());
+        echildren.push_back(itn->second);
+      }
+      else
+      {
+        // otherwise, use the evaluation of the operator
+        echildren.push_back(itr->second.toNode());
+      }
     }
   }
   for (const auto& currNodeChild : n)
index 533a0365716723ff986ee88c71c82423db883437..58e179fbe596ba2a421ac482cd4f6f508c2045f5 100644 (file)
@@ -95,6 +95,14 @@ class Evaluator
   Node eval(TNode n,
             const std::vector<Node>& args,
             const std::vector<Node>& vals) const;
+  /**
+   * Same as above, but with a precomputed visited map.
+   */
+  Node eval(
+      TNode n,
+      const std::vector<Node>& args,
+      const std::vector<Node>& vals,
+      const std::unordered_map<Node, Node, NodeHashFunction>& visited) const;
 
  private:
   /**
@@ -117,7 +125,8 @@ class Evaluator
       TNode n,
       const std::vector<Node>& args,
       const std::vector<Node>& vals,
-      std::unordered_map<TNode, Node, NodeHashFunction>& evalAsNode) const;
+      std::unordered_map<TNode, Node, NodeHashFunction>& evalAsNode,
+      std::unordered_map<TNode, EvalResult, TNodeHashFunction>& results) const;
   /** reconstruct
    *
    * This function reconstructs the result of evaluating n using a combination