Optimize the rewriter for DT_SYGUS_EVAL (#3529)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 6 Dec 2019 19:12:12 +0000 (13:12 -0600)
committerAndres Noetzli <andres.noetzli@gmail.com>
Fri, 6 Dec 2019 19:12:12 +0000 (11:12 -0800)
This makes it so that we don't construct intermediate unfoldings of applications of DT_SYGUS_EVAL, which wastes time in node construction. It makes the sygusToBuiltin utility in TermDbSygus use this implementation.

src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/theory_datatypes_utils.cpp
src/theory/datatypes/theory_datatypes_utils.h
src/theory/quantifiers/sygus/term_database_sygus.cpp

index be4226f694a40db58d24b0ba10903a37a13204c2..080306d39a2ed96e9d954e483b62abac05908f19 100644 (file)
@@ -120,34 +120,16 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
     if (ev.getKind() == APPLY_CONSTRUCTOR)
     {
       Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
-      const Datatype& dt = ev.getType().getDatatype();
-      unsigned i = utils::indexOf(ev.getOperator());
-      Node op = Node::fromExpr(dt[i].getSygusOp());
-      // if it is the "any constant" constructor, return its argument
-      if (op.getAttribute(SygusAnyConstAttribute()))
-      {
-        Assert(ev.getNumChildren() == 1);
-        Assert(ev[0].getType().isComparableTo(in.getType()));
-        return RewriteResponse(REWRITE_AGAIN_FULL, ev[0]);
-      }
+      Trace("dt-sygus-util") << "Type is " << in.getType() << std::endl;
       std::vector<Node> args;
       for (unsigned j = 1, nchild = in.getNumChildren(); j < nchild; j++)
       {
         args.push_back(in[j]);
       }
-      Assert(!dt.isParametric());
-      std::vector<Node> children;
-      for (const Node& evc : ev)
-      {
-        std::vector<Node> cc;
-        cc.push_back(evc);
-        cc.insert(cc.end(), args.begin(), args.end());
-        children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc));
-      }
-      Node ret = utils::mkSygusTerm(dt, i, children);
-      // apply the appropriate substitution
-      ret = utils::applySygusArgs(dt, op, ret, args);
+      Node ret = utils::sygusToBuiltinEval(ev, args);
       Trace("dt-sygus-util") << "...got " << ret << "\n";
+      Trace("dt-sygus-util") << "Type is " << ret.getType() << std::endl;
+      Assert(in.getType().isComparableTo(ret.getType()));
       return RewriteResponse(REWRITE_AGAIN_FULL, ret);
     }
   }
index 43d23b523cd80bb6b16c81571d0cc88159cd2201..d2833a85234c835c3991d9c15a656ade56ee0c1c 100644 (file)
@@ -18,6 +18,7 @@
 
 #include "expr/node_algorithm.h"
 #include "expr/sygus_datatype.h"
+#include "theory/evaluator.h"
 
 using namespace CVC4;
 using namespace CVC4::kind;
@@ -384,6 +385,200 @@ bool checkClash(Node n1, Node n2, std::vector<Node>& rew)
   return false;
 }
 
+struct SygusToBuiltinTermAttributeId
+{
+};
+typedef expr::Attribute<SygusToBuiltinTermAttributeId, Node>
+    SygusToBuiltinTermAttribute;
+
+Node sygusToBuiltin(Node n)
+{
+  Assert(n.isConst());
+  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  unsigned index;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+    if (it == visited.end())
+    {
+      if (cur.getKind() == APPLY_CONSTRUCTOR)
+      {
+        if (cur.hasAttribute(SygusToBuiltinTermAttribute()))
+        {
+          visited[cur] = cur.getAttribute(SygusToBuiltinTermAttribute());
+        }
+        else
+        {
+          visited[cur] = Node::null();
+          visit.push_back(cur);
+          for (const Node& cn : cur)
+          {
+            visit.push_back(cn);
+          }
+        }
+      }
+      else
+      {
+        // non-datatypes are themselves
+        visited[cur] = cur;
+      }
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      Assert(cur.getKind() == APPLY_CONSTRUCTOR);
+      const Datatype& dt = cur.getType().getDatatype();
+      // Non sygus-datatype terms are also themselves. Notice we treat the
+      // case of non-sygus datatypes this way since it avoids computing
+      // the type / datatype of the node in the pre-traversal above. The
+      // case of non-sygus datatypes is very rare, so the extra addition to
+      // visited is justified performance-wise.
+      if (dt.isSygus())
+      {
+        std::vector<Node> children;
+        for (const Node& cn : cur)
+        {
+          it = visited.find(cn);
+          Assert(it != visited.end());
+          Assert(!it->second.isNull());
+          children.push_back(it->second);
+        }
+        index = indexOf(cur.getOperator());
+        ret = mkSygusTerm(dt, index, children);
+      }
+      visited[cur] = ret;
+      // cache
+      SygusToBuiltinTermAttribute stbt;
+      cur.setAttribute(stbt, ret);
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
+}
+
+Node sygusToBuiltinEval(Node n, const std::vector<Node>& args)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Evaluator eval;
+  // constant arguments?
+  bool constArgs = true;
+  for (const Node& a : args)
+  {
+    if (!a.isConst())
+    {
+      constArgs = false;
+      break;
+    }
+  }
+  std::vector<Node> eargs;
+  bool svarsInit = false;
+  std::vector<Node> svars;
+  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  unsigned index;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+    if (it == visited.end())
+    {
+      TypeNode tn = cur.getType();
+      if (!tn.isDatatype() || !tn.getDatatype().isSygus())
+      {
+        visited[cur] = cur;
+      }
+      else if (cur.isConst())
+      {
+        // convert to builtin term
+        Node bt = sygusToBuiltin(cur);
+        // run the evaluator if possible
+        if (!svarsInit)
+        {
+          svarsInit = true;
+          TypeNode tn = cur.getType();
+          Node varList = Node::fromExpr(tn.getDatatype().getSygusVarList());
+          for (const Node& v : varList)
+          {
+            svars.push_back(v);
+          }
+        }
+        Assert(args.size() == svars.size());
+        // try evaluation if we have constant arguments
+        Node ret = constArgs ? eval.eval(bt, svars, args) : Node::null();
+        if (ret.isNull())
+        {
+          // if evaluation was not available, use a substitution
+          ret = bt.substitute(
+              svars.begin(), svars.end(), args.begin(), args.end());
+        }
+        visited[cur] = ret;
+      }
+      else
+      {
+        if (cur.getKind() == APPLY_CONSTRUCTOR)
+        {
+          visited[cur] = Node::null();
+          visit.push_back(cur);
+          for (const Node& cn : cur)
+          {
+            visit.push_back(cn);
+          }
+        }
+        else
+        {
+          // it is the evaluation of this term on the arguments
+          if (eargs.empty())
+          {
+            eargs.push_back(cur);
+            eargs.insert(eargs.end(), args.begin(), args.end());
+          }
+          else
+          {
+            eargs[0] = cur;
+          }
+          visited[cur] = nm->mkNode(DT_SYGUS_EVAL, eargs);
+        }
+      }
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      Assert(cur.getKind() == APPLY_CONSTRUCTOR);
+      const Datatype& dt = cur.getType().getDatatype();
+      // non sygus-datatype terms are also themselves
+      if (dt.isSygus())
+      {
+        std::vector<Node> children;
+        for (const Node& cn : cur)
+        {
+          it = visited.find(cn);
+          Assert(it != visited.end());
+          Assert(!it->second.isNull());
+          children.push_back(it->second);
+        }
+        index = indexOf(cur.getOperator());
+        // apply to arguments
+        ret = mkSygusTerm(dt, index, children);
+      }
+      visited[cur] = ret;
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
+}
+
 }  // namespace utils
 }  // namespace datatypes
 }  // namespace theory
index 5f74a4bee693f6141fcfb57bed04318cd3d6f62a..46a6d56be93edc2ff55cf10da3e2abd3c1223e11 100644 (file)
@@ -185,12 +185,36 @@ Node applySygusArgs(const Datatype& dt,
                     Node op,
                     Node n,
                     const std::vector<Node>& args);
-/**
- * Get the builtin sygus operator for constructor term n of sygus datatype
- * type. For example, if n is the term C_+( d1, d2 ) where C_+ is a sygus
- * constructor whose sygus op is the builtin operator +, this method returns +.
+/** Sygus to builtin
+ *
+ * This method converts a constant term of SyGuS datatype type to its builtin
+ * equivalent. For example, given input C_*( C_x(), C_y() ), this method returns
+ * x*y, assuming C_+, C_x, and C_y have sygus operators *, x, and y
+ * respectively.
+ */
+Node sygusToBuiltin(Node c);
+/** Sygus to builtin eval
+ *
+ * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice that
+ * n does not necessarily need to be a constant.
+ *
+ * It does so by (1) converting constant subterms of n to builtin terms and
+ * evaluating them on the arguments args, (2) unfolding non-constant
+ * applications of sygus constructors in n with respect to args and (3)
+ * converting all other non-constant subterms of n to applications of
+ * DT_SYGUS_EVAL.
+ *
+ * For example, if
+ *   n = C_+( C_*( C_x(), C_y() ), n' ), and args = { 3, 4 }
+ * where n' is a variable, then this method returns:
+ *   12 + (DT_SYGUS_EVAL n' 3 4)
+ * Notice that the subterm C_*( C_x(), C_y() ) is converted to its builtin
+ * equivalent x*y and evaluated under the substition { x -> 3, x -> 4 } giving
+ * 12. The subterm n' is non-constant and thus we return its evaluation under
+ * 3,4, giving the term (DT_SYGUS_EVAL n' 3 4). Since the top-level constructor
+ * is C_+, these terms are added together to give the result.
  */
-Node getSygusOpForCTerm(Node n);
+Node sygusToBuiltinEval(Node n, const std::vector<Node>& args);
 
 // ------------------------ end sygus utils
 
index d664a462ddd70ae7eb5879e255e45e12b2f9456a..c5ea0f9f3af36612560e5ac5170f8f7f5864cba1 100644 (file)
@@ -277,6 +277,11 @@ typedef expr::Attribute<SygusToBuiltinAttributeId, Node>
 
 Node TermDbSygus::sygusToBuiltin(Node n, TypeNode tn)
 {
+  if (n.isConst())
+  {
+    // if its a constant, we use the datatype utility version
+    return datatypes::utils::sygusToBuiltin(n);
+  }
   Assert(n.getType().isComparableTo(tn));
   if (!tn.isDatatype())
   {