Use the proper evaluator for optimized SyGuS datatype rewriting (#7266)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 1 Oct 2021 04:49:08 +0000 (23:49 -0500)
committerGitHub <noreply@github.com>
Fri, 1 Oct 2021 04:49:08 +0000 (04:49 +0000)
This updates the datatypes rewriter to use the evaluator from Env instead of creating local copies of Evaluator. This makes all uses of the Evaluator dependent on the proper options (e.g. which will be based later on the cardinality of the alphabet for strings).

This moves one utility method (sygusToBuiltinEval) to the datatypes rewriter, as it uses an Evaluator that will be dependent on options.

Notice that this is another instance where it is important for us to make the cache for the rewriter local. The same issue occurs for other places where rewriting is dependent on options. This issue will be revisited when the option for strings alphabet cardinality is added.

src/smt/env.cpp
src/smt/env.h
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/datatypes_rewriter.h
src/theory/datatypes/sygus_datatype_utils.cpp
src/theory/datatypes/sygus_datatype_utils.h
src/theory/datatypes/theory_datatypes.cpp

index 0ffe1c4b95fc87fc1fed1bd02f795be73d38b4d6..5c7836fb72de4c646a8826f761b09cd3d2c4f1db 100644 (file)
@@ -105,6 +105,11 @@ bool Env::isTheoryProofProducing() const
 
 theory::Rewriter* Env::getRewriter() { return d_rewriter.get(); }
 
+theory::Evaluator* Env::getEvaluator(bool useRewriter)
+{
+  return useRewriter ? d_evalRew.get() : d_eval.get();
+}
+
 theory::TrustSubstitutionMap& Env::getTopLevelSubstitutions()
 {
   return *d_topLevelSubs.get();
index e3a34cf4ab2585422f7cbbb6d3f0f7facf2bdec9..8d2b1636eae736d8465e4e1acf0542c98819121a 100644 (file)
@@ -105,6 +105,14 @@ class Env
   /** Get a pointer to the Rewriter owned by this Env. */
   theory::Rewriter* getRewriter();
 
+  /**
+   * Get a pointer to the Evaluator owned by this Env. There are two variants
+   * of the evaluator, one that invokes the rewriter when evaluation is not
+   * applicable, and one that does not. The former evaluator is returned when
+   * useRewriter is true.
+   */
+  theory::Evaluator* getEvaluator(bool useRewriter = false);
+
   /** Get a reference to the top-level substitution map */
   theory::TrustSubstitutionMap& getTopLevelSubstitutions();
 
index 33d143a36ed36202f1d84fb57f073b73fbdb2f06..c446504fd8ac0a3f34bf1d336df4eefca0c9642a 100644 (file)
@@ -35,6 +35,11 @@ namespace cvc5 {
 namespace theory {
 namespace datatypes {
 
+DatatypesRewriter::DatatypesRewriter(Evaluator* sygusEval)
+    : d_sygusEval(sygusEval)
+{
+}
+
 RewriteResponse DatatypesRewriter::postRewrite(TNode in)
 {
   Trace("datatypes-rewrite-debug") << "post-rewriting " << in << std::endl;
@@ -137,7 +142,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
       {
         args.push_back(in[j]);
       }
-      Node ret = utils::sygusToBuiltinEval(ev, args);
+      Node ret = sygusToBuiltinEval(ev, args);
       Trace("dt-sygus-util") << "...got " << ret << "\n";
       Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
       Assert(in.getType().isComparableTo(ret.getType()));
@@ -920,6 +925,126 @@ TrustNode DatatypesRewriter::expandDefinition(Node n)
   return TrustNode::null();
 }
 
+Node DatatypesRewriter::sygusToBuiltinEval(Node n,
+                                           const std::vector<Node>& args)
+{
+  Assert(d_sygusEval != nullptr);
+  NodeManager* nm = NodeManager::currentNM();
+  // constant arguments?
+  bool constArgs = true;
+  for (const Node& a : args)
+  {
+    if (!a.isConst())
+    {
+      constArgs = false;
+      break;
+    }
+  }
+  std::vector<Node> eargs;
+  bool svarsInit = false;
+  std::vector<Node> svars;
+  std::unordered_map<TNode, Node> visited;
+  std::unordered_map<TNode, Node>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  unsigned index;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+    if (it == visited.end())
+    {
+      TypeNode tn = cur.getType();
+      if (!tn.isDatatype() || !tn.getDType().isSygus())
+      {
+        visited[cur] = cur;
+      }
+      else if (cur.isConst())
+      {
+        // convert to builtin term
+        Node bt = utils::sygusToBuiltin(cur);
+        // run the evaluator if possible
+        if (!svarsInit)
+        {
+          svarsInit = true;
+          TypeNode type = cur.getType();
+          Node varList = type.getDType().getSygusVarList();
+          for (const Node& v : varList)
+          {
+            svars.push_back(v);
+          }
+        }
+        Assert(args.size() == svars.size());
+        // try evaluation if we have constant arguments
+        Node ret =
+            constArgs ? d_sygusEval->eval(bt, svars, args) : Node::null();
+        if (ret.isNull())
+        {
+          // if evaluation was not available, use a substitution
+          ret = bt.substitute(
+              svars.begin(), svars.end(), args.begin(), args.end());
+        }
+        visited[cur] = ret;
+      }
+      else
+      {
+        if (cur.getKind() == APPLY_CONSTRUCTOR)
+        {
+          visited[cur] = Node::null();
+          visit.push_back(cur);
+          for (const Node& cn : cur)
+          {
+            visit.push_back(cn);
+          }
+        }
+        else
+        {
+          // it is the evaluation of this term on the arguments
+          if (eargs.empty())
+          {
+            eargs.push_back(cur);
+            eargs.insert(eargs.end(), args.begin(), args.end());
+          }
+          else
+          {
+            eargs[0] = cur;
+          }
+          visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs);
+        }
+      }
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      Assert(cur.getKind() == APPLY_CONSTRUCTOR);
+      const DType& dt = cur.getType().getDType();
+      // non sygus-datatype terms are also themselves
+      if (dt.isSygus())
+      {
+        std::vector<Node> children;
+        for (const Node& cn : cur)
+        {
+          it = visited.find(cn);
+          Assert(it != visited.end());
+          Assert(!it->second.isNull());
+          children.push_back(it->second);
+        }
+        index = utils::indexOf(cur.getOperator());
+        // apply to children, which constructs the builtin term
+        ret = utils::mkSygusTerm(dt, index, children);
+        // now apply it to arguments in args
+        ret = utils::applySygusArgs(dt, dt[index].getSygusOp(), ret, args);
+      }
+      visited[cur] = ret;
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
+}
+
 }  // namespace datatypes
 }  // namespace theory
 }  // namespace cvc5
index 56dde76a0697ef47b50b3d39a874fa738f39b2e2..31e2a1befa70f51cfe8161a367f4329b56eb53da 100644 (file)
@@ -18,6 +18,7 @@
 #ifndef CVC5__THEORY__DATATYPES__DATATYPES_REWRITER_H
 #define CVC5__THEORY__DATATYPES__DATATYPES_REWRITER_H
 
+#include "theory/evaluator.h"
 #include "theory/theory_rewriter.h"
 
 namespace cvc5 {
@@ -37,6 +38,7 @@ namespace datatypes {
 class DatatypesRewriter : public TheoryRewriter
 {
  public:
+  DatatypesRewriter(Evaluator* sygusEval);
   RewriteResponse postRewrite(TNode in) override;
   RewriteResponse preRewrite(TNode in) override;
 
@@ -164,7 +166,32 @@ class DatatypesRewriter : public TheoryRewriter
                               Node orig,
                               TypeNode orig_tn,
                               unsigned depth);
-}; /* class DatatypesRewriter */
+
+  /** Sygus to builtin eval
+   *
+   * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice
+   * that n does not necessarily need to be a constant.
+   *
+   * It does so by (1) converting constant subterms of n to builtin terms and
+   * evaluating them on the arguments args, (2) unfolding non-constant
+   * applications of sygus constructors in n with respect to args and (3)
+   * converting all other non-constant subterms of n to applications of
+   * DT_SYGUS_EVAL.
+   *
+   * For example, if
+   *   n = C_+( C_*( C_x(), C_y() ), n' ), and args = { 3, 4 }
+   * where n' is a variable, then this method returns:
+   *   12 + (DT_SYGUS_EVAL n' 3 4)
+   * Notice that the subterm C_*( C_x(), C_y() ) is converted to its builtin
+   * equivalent x*y and evaluated under the substition { x -> 3, y -> 4 } giving
+   * 12. The subterm n' is non-constant and thus we return its evaluation under
+   * 3,4, giving the term (DT_SYGUS_EVAL n' 3 4). Since the top-level
+   * constructor is C_+, these terms are added together to give the result.
+   */
+  Node sygusToBuiltinEval(Node n, const std::vector<Node>& args);
+  /** Pointer to the evaluator, used as an optimization for the above method */
+  Evaluator* d_sygusEval;
+};
 
 }  // namespace datatypes
 }  // namespace theory
index 12c255f5770ee0499a8ee250cb80f8491a791231..c68b87d85b4e2b59c80bd74df7ea81a4d56acdb7 100644 (file)
@@ -388,124 +388,6 @@ Node sygusToBuiltin(Node n, bool isExternal)
   return visited[n];
 }
 
-Node sygusToBuiltinEval(Node n, const std::vector<Node>& args)
-{
-  NodeManager* nm = NodeManager::currentNM();
-  Evaluator eval(nullptr);
-  // constant arguments?
-  bool constArgs = true;
-  for (const Node& a : args)
-  {
-    if (!a.isConst())
-    {
-      constArgs = false;
-      break;
-    }
-  }
-  std::vector<Node> eargs;
-  bool svarsInit = false;
-  std::vector<Node> svars;
-  std::unordered_map<TNode, Node> visited;
-  std::unordered_map<TNode, Node>::iterator it;
-  std::vector<TNode> visit;
-  TNode cur;
-  unsigned index;
-  visit.push_back(n);
-  do
-  {
-    cur = visit.back();
-    visit.pop_back();
-    it = visited.find(cur);
-    if (it == visited.end())
-    {
-      TypeNode tn = cur.getType();
-      if (!tn.isDatatype() || !tn.getDType().isSygus())
-      {
-        visited[cur] = cur;
-      }
-      else if (cur.isConst())
-      {
-        // convert to builtin term
-        Node bt = sygusToBuiltin(cur);
-        // run the evaluator if possible
-        if (!svarsInit)
-        {
-          svarsInit = true;
-          TypeNode type = cur.getType();
-          Node varList = type.getDType().getSygusVarList();
-          for (const Node& v : varList)
-          {
-            svars.push_back(v);
-          }
-        }
-        Assert(args.size() == svars.size());
-        // try evaluation if we have constant arguments
-        Node ret = constArgs ? eval.eval(bt, svars, args) : Node::null();
-        if (ret.isNull())
-        {
-          // if evaluation was not available, use a substitution
-          ret = bt.substitute(
-              svars.begin(), svars.end(), args.begin(), args.end());
-        }
-        visited[cur] = ret;
-      }
-      else
-      {
-        if (cur.getKind() == APPLY_CONSTRUCTOR)
-        {
-          visited[cur] = Node::null();
-          visit.push_back(cur);
-          for (const Node& cn : cur)
-          {
-            visit.push_back(cn);
-          }
-        }
-        else
-        {
-          // it is the evaluation of this term on the arguments
-          if (eargs.empty())
-          {
-            eargs.push_back(cur);
-            eargs.insert(eargs.end(), args.begin(), args.end());
-          }
-          else
-          {
-            eargs[0] = cur;
-          }
-          visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs);
-        }
-      }
-    }
-    else if (it->second.isNull())
-    {
-      Node ret = cur;
-      Assert(cur.getKind() == APPLY_CONSTRUCTOR);
-      const DType& dt = cur.getType().getDType();
-      // non sygus-datatype terms are also themselves
-      if (dt.isSygus())
-      {
-        std::vector<Node> children;
-        for (const Node& cn : cur)
-        {
-          it = visited.find(cn);
-          Assert(it != visited.end());
-          Assert(!it->second.isNull());
-          children.push_back(it->second);
-        }
-        index = indexOf(cur.getOperator());
-        // apply to children, which constructs the builtin term
-        ret = mkSygusTerm(dt, index, children);
-        // now apply it to arguments in args
-        ret = applySygusArgs(dt, dt[index].getSygusOp(), ret, args);
-      }
-      visited[cur] = ret;
-    }
-  } while (!visit.empty());
-  Assert(visited.find(n) != visited.end());
-  Assert(!visited.find(n)->second.isNull());
-  return visited[n];
-}
-
 Node builtinVarToSygus(Node v)
 {
   BuiltinVarToSygusAttribute bvtsa;
index 5784fe34af9b4cae7828826ff15081f0927d2d09..3ea6b62e9635dd15d2fb545d158d640fb93437d3 100644 (file)
@@ -165,29 +165,6 @@ Node sygusToBuiltin(Node n, bool isExternal = false);
  */
 Node builtinVarToSygus(Node v);
 
-/** Sygus to builtin eval
- *
- * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice that
- * n does not necessarily need to be a constant.
- *
- * It does so by (1) converting constant subterms of n to builtin terms and
- * evaluating them on the arguments args, (2) unfolding non-constant
- * applications of sygus constructors in n with respect to args and (3)
- * converting all other non-constant subterms of n to applications of
- * DT_SYGUS_EVAL.
- *
- * For example, if
- *   n = C_+( C_*( C_x(), C_y() ), n' ), and args = { 3, 4 }
- * where n' is a variable, then this method returns:
- *   12 + (DT_SYGUS_EVAL n' 3 4)
- * Notice that the subterm C_*( C_x(), C_y() ) is converted to its builtin
- * equivalent x*y and evaluated under the substition { x -> 3, y -> 4 } giving
- * 12. The subterm n' is non-constant and thus we return its evaluation under
- * 3,4, giving the term (DT_SYGUS_EVAL n' 3 4). Since the top-level constructor
- * is C_+, these terms are added together to give the result.
- */
-Node sygusToBuiltinEval(Node n, const std::vector<Node>& args);
-
 /** Get free symbols in a sygus datatype type
  *
  * Add the free symbols (expr::getSymbols) in terms that can be generated by
index feb19b182a690c379e0dec2e24d28f6282dcfab3..a1c6942a5e09ac427b0ab72fc17e474103c4bf63 100644 (file)
@@ -61,6 +61,7 @@ TheoryDatatypes::TheoryDatatypes(Env& env,
       d_functionTerms(context()),
       d_singleton_eq(userContext()),
       d_sygusExtension(nullptr),
+      d_rewriter(env.getEvaluator()),
       d_state(env, valuation),
       d_im(env, *this, d_state, d_pnm),
       d_notify(d_im, *this)