Ensure exported sygus solutions match grammar (#4270)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 11 Apr 2020 04:37:43 +0000 (23:37 -0500)
committerGitHub <noreply@github.com>
Sat, 11 Apr 2020 04:37:43 +0000 (23:37 -0500)
Previously we were doing rewriting/expand definitions during grammar normalization, which overwrote the original sygus operators. The connection to the original grammar was maintained via the SygusPrintCallback utility, which ensured that a sygus term printed in a way that matched the grammar.

We now have several use cases where solutions from SyGuS will be directly exported to the user, including the current use of get-abduct. This means that the terms must match the grammar, and we cannot simply rely on the print callback.

This moves the code to normalize sygus operators to datatypes utils, where the conversion between sygus and builtin terms takes place. This allows a version of this function where isExternal = true, which constructs terms matching the original grammar.

This PR enables the SyGuS API to have an accurate getSynthSolution method. It also will eliminate the need for SygusPrintCallback altogether, once the v1 parser is deleted.

src/expr/dtype_cons.h
src/theory/datatypes/theory_datatypes_utils.cpp
src/theory/datatypes/theory_datatypes_utils.h
src/theory/quantifiers/sygus/sygus_grammar_norm.cpp
src/theory/quantifiers/sygus/sygus_grammar_norm.h
src/theory/quantifiers/sygus/synth_conjecture.cpp
test/regress/CMakeLists.txt
test/regress/regress1/sygus/yoni-true-sol.smt2 [new file with mode: 0644]

index d5d0013de10d3e157d8297d852dbf23e44f8834b..ca48063167ae37d72f367c93f3ef82f8431ca286 100644 (file)
@@ -87,12 +87,9 @@ class DTypeConstructor
   void setSygus(Node op);
   /** get sygus op
    *
-   * This method returns the operator or
-   * term that this constructor represents
-   * in the sygus encoding. This may be a
-   * builtin operator, defined function, variable,
-   * or constant that this constructor encodes in this
-   * deep embedding.
+   * This method returns the operator or term that this constructor represents
+   * in the sygus encoding. This may be a builtin operator, defined function,
+   * variable, or constant that this constructor encodes in this deep embedding.
    */
   Node getSygusOp() const;
   /** is this a sygus identity function?
index 13cc8fc19a5f8834480007170ad8e8f6f555dae9..ee0fd814ef38ba73f2ba11b7c5cf6b688c051218 100644 (file)
 #include "expr/dtype.h"
 #include "expr/node_algorithm.h"
 #include "expr/sygus_datatype.h"
+#include "smt/smt_engine.h"
+#include "smt/smt_engine_scope.h"
 #include "theory/evaluator.h"
+#include "theory/rewriter.h"
 
 using namespace CVC4;
 using namespace CVC4::kind;
@@ -117,10 +120,99 @@ Kind getOperatorKindForSygusBuiltin(Node op)
   return UNDEFINED_KIND;
 }
 
+struct SygusOpRewrittenAttributeId
+{
+};
+typedef expr::Attribute<SygusOpRewrittenAttributeId, Node>
+    SygusOpRewrittenAttribute;
+
+Kind getEliminateKind(Kind ok)
+{
+  Kind nk = ok;
+  // We also must ensure that builtin operators which are eliminated
+  // during expand definitions are replaced by the proper operator.
+  if (ok == BITVECTOR_UDIV)
+  {
+    nk = BITVECTOR_UDIV_TOTAL;
+  }
+  else if (ok == BITVECTOR_UREM)
+  {
+    nk = BITVECTOR_UREM_TOTAL;
+  }
+  else if (ok == DIVISION)
+  {
+    nk = DIVISION_TOTAL;
+  }
+  else if (ok == INTS_DIVISION)
+  {
+    nk = INTS_DIVISION_TOTAL;
+  }
+  else if (ok == INTS_MODULUS)
+  {
+    nk = INTS_MODULUS_TOTAL;
+  }
+  return nk;
+}
+
+Node eliminatePartialOperators(Node n)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
+  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
+  std::vector<TNode> visit;
+  TNode cur;
+  visit.push_back(n);
+  do
+  {
+    cur = visit.back();
+    visit.pop_back();
+    it = visited.find(cur);
+
+    if (it == visited.end())
+    {
+      visited[cur] = Node::null();
+      visit.push_back(cur);
+      for (const Node& cn : cur)
+      {
+        visit.push_back(cn);
+      }
+    }
+    else if (it->second.isNull())
+    {
+      Node ret = cur;
+      bool childChanged = false;
+      std::vector<Node> children;
+      if (cur.getMetaKind() == metakind::PARAMETERIZED)
+      {
+        children.push_back(cur.getOperator());
+      }
+      for (const Node& cn : cur)
+      {
+        it = visited.find(cn);
+        Assert(it != visited.end());
+        Assert(!it->second.isNull());
+        childChanged = childChanged || cn != it->second;
+        children.push_back(it->second);
+      }
+      Kind ok = cur.getKind();
+      Kind nk = getEliminateKind(ok);
+      if (nk != ok || childChanged)
+      {
+        ret = nm->mkNode(nk, children);
+      }
+      visited[cur] = ret;
+    }
+  } while (!visit.empty());
+  Assert(visited.find(n) != visited.end());
+  Assert(!visited.find(n)->second.isNull());
+  return visited[n];
+}
+
 Node mkSygusTerm(const DType& dt,
                  unsigned i,
                  const std::vector<Node>& children,
-                 bool doBetaReduction)
+                 bool doBetaReduction,
+                 bool isExternal)
 {
   Trace("dt-sygus-util") << "Make sygus term " << dt.getName() << "[" << i
                          << "] with children: " << children << std::endl;
@@ -128,7 +220,49 @@ Node mkSygusTerm(const DType& dt,
   Assert(dt.isSygus());
   Assert(!dt[i].getSygusOp().isNull());
   Node op = dt[i].getSygusOp();
-  return mkSygusTerm(op, children, doBetaReduction);
+  Node opn = op;
+  if (!isExternal)
+  {
+    // Get the normalized version of the sygus operator. We do this by
+    // expanding definitions, rewriting it, and eliminating partial operators.
+    if (!op.hasAttribute(SygusOpRewrittenAttribute()))
+    {
+      if (op.isConst())
+      {
+        // If it is a builtin operator, convert to total version if necessary.
+        // First, get the kind for the operator.
+        Kind ok = NodeManager::operatorToKind(op);
+        Trace("sygus-grammar-normalize-debug")
+            << "...builtin kind is " << ok << std::endl;
+        Kind nk = getEliminateKind(ok);
+        if (nk != ok)
+        {
+          Trace("sygus-grammar-normalize-debug")
+              << "...replace by builtin operator " << nk << std::endl;
+          opn = NodeManager::currentNM()->operatorOf(nk);
+        }
+      }
+      else
+      {
+        // Only expand definitions if the operator is not constant, since
+        // calling expandDefinitions on them should be a no-op. This check
+        // ensures we don't try to expand e.g. bitvector extract operators,
+        // whose type is undefined, and thus should not be passed to
+        // expandDefinitions.
+        opn = Node::fromExpr(
+            smt::currentSmtEngine()->expandDefinitions(op.toExpr()));
+        opn = Rewriter::rewrite(opn);
+        opn = eliminatePartialOperators(opn);
+        SygusOpRewrittenAttribute sora;
+        op.setAttribute(sora, opn);
+      }
+    }
+    else
+    {
+      opn = op.getAttribute(SygusOpRewrittenAttribute());
+    }
+  }
+  return mkSygusTerm(opn, children, doBetaReduction);
 }
 
 Node mkSygusTerm(Node op,
@@ -386,7 +520,7 @@ struct SygusToBuiltinTermAttributeId
 typedef expr::Attribute<SygusToBuiltinTermAttributeId, Node>
     SygusToBuiltinTermAttribute;
 
-Node sygusToBuiltin(Node n)
+Node sygusToBuiltin(Node n, bool isExternal)
 {
   Assert(n.isConst());
   std::unordered_map<TNode, Node, TNodeHashFunction> visited;
@@ -404,7 +538,7 @@ Node sygusToBuiltin(Node n)
     {
       if (cur.getKind() == APPLY_CONSTRUCTOR)
       {
-        if (cur.hasAttribute(SygusToBuiltinTermAttribute()))
+        if (!isExternal && cur.hasAttribute(SygusToBuiltinTermAttribute()))
         {
           visited[cur] = cur.getAttribute(SygusToBuiltinTermAttribute());
         }
@@ -445,12 +579,15 @@ Node sygusToBuiltin(Node n)
           children.push_back(it->second);
         }
         index = indexOf(cur.getOperator());
-        ret = mkSygusTerm(dt, index, children);
+        ret = mkSygusTerm(dt, index, children, true, isExternal);
       }
       visited[cur] = ret;
       // cache
-      SygusToBuiltinTermAttribute stbt;
-      cur.setAttribute(stbt, ret);
+      if (!isExternal)
+      {
+        SygusToBuiltinTermAttribute stbt;
+        cur.setAttribute(stbt, ret);
+      }
     }
   } while (!visit.empty());
   Assert(visited.find(n) != visited.end());
index b2330227669f726f574901122ff5de1267acc663..58f719910ab90fdaf6b5c6353acbb1062f50f801 100644 (file)
@@ -146,17 +146,31 @@ bool checkClash(Node n1, Node n2, std::vector<Node>& rew);
  * function mkSygusTerm.
  */
 Kind getOperatorKindForSygusBuiltin(Node op);
+/**
+ * Returns the total version of Kind k if it is a partial operator, or
+ * otherwise k itself.
+ */
+Kind getEliminateKind(Kind k);
+/**
+ * Returns a version of n where all partial functions such as bvudiv
+ * have been replaced by their total versions like bvudiv_total.
+ */
+Node eliminatePartialOperators(Node n);
 /** make sygus term
  *
  * This function returns a builtin term f( children[0], ..., children[n] )
  * where f is the builtin op that the i^th constructor of sygus datatype dt
  * encodes. If doBetaReduction is true, then lambdas are eagerly eliminated
  * via beta reduction.
+ *
+ * If isExternal is true, then the returned term respects the original grammar
+ * that was provided. This includes the use of defined functions.
  */
 Node mkSygusTerm(const DType& dt,
                  unsigned i,
                  const std::vector<Node>& children,
-                 bool doBetaReduction = true);
+                 bool doBetaReduction = true,
+                 bool isExternal = false);
 /**
  * Same as above, but we already have the sygus operator op. The above method
  * is syntax sugar for calling this method on dt[i].getSygusOp().
@@ -201,8 +215,13 @@ Node applySygusArgs(const DType& dt,
  * equivalent. For example, given input C_*( C_x(), C_y() ), this method returns
  * x*y, assuming C_+, C_x, and C_y have sygus operators *, x, and y
  * respectively.
+ *
+ * If isExternal is true, then the returned term respects the original grammar
+ * that was provided. This includes the use of defined functions. This argument
+ * should typically be false, unless we are e.g. exporting the value of the
+ * term as a final solution.
  */
-Node sygusToBuiltin(Node c);
+Node sygusToBuiltin(Node c, bool isExternal = false);
 /** Sygus to builtin eval
  *
  * This method returns the rewritten form of (DT_SYGUS_EVAL n args). Notice that
index f00fd0092fa2014de60a20eadb378a0b3a7d30d7..3b2c569743e97c0174f2ee47bb654fd52323fbdd 100644 (file)
@@ -81,87 +81,6 @@ SygusGrammarNorm::TypeObject::TypeObject(TypeNode src_tn, TypeNode unres_tn)
       d_sdt(unres_tn.getAttribute(expr::VarNameAttr()))
 {
 }
-Kind SygusGrammarNorm::TypeObject::getEliminateKind(Kind ok)
-{
-  Kind nk = ok;
-  // We also must ensure that builtin operators which are eliminated
-  // during expand definitions are replaced by the proper operator.
-  if (ok == BITVECTOR_UDIV)
-  {
-    nk = BITVECTOR_UDIV_TOTAL;
-  }
-  else if (ok == BITVECTOR_UREM)
-  {
-    nk = BITVECTOR_UREM_TOTAL;
-  }
-  else if (ok == DIVISION)
-  {
-    nk = DIVISION_TOTAL;
-  }
-  else if (ok == INTS_DIVISION)
-  {
-    nk = INTS_DIVISION_TOTAL;
-  }
-  else if (ok == INTS_MODULUS)
-  {
-    nk = INTS_MODULUS_TOTAL;
-  }
-  return nk;
-}
-
-Node SygusGrammarNorm::TypeObject::eliminatePartialOperators(Node n)
-{
-  NodeManager* nm = NodeManager::currentNM();
-  std::unordered_map<TNode, Node, TNodeHashFunction> visited;
-  std::unordered_map<TNode, Node, TNodeHashFunction>::iterator it;
-  std::vector<TNode> visit;
-  TNode cur;
-  visit.push_back(n);
-  do
-  {
-    cur = visit.back();
-    visit.pop_back();
-    it = visited.find(cur);
-
-    if (it == visited.end())
-    {
-      visited[cur] = Node::null();
-      visit.push_back(cur);
-      for (const Node& cn : cur)
-      {
-        visit.push_back(cn);
-      }
-    }
-    else if (it->second.isNull())
-    {
-      Node ret = cur;
-      bool childChanged = false;
-      std::vector<Node> children;
-      if (cur.getMetaKind() == metakind::PARAMETERIZED)
-      {
-        children.push_back(cur.getOperator());
-      }
-      for (const Node& cn : cur)
-      {
-        it = visited.find(cn);
-        Assert(it != visited.end());
-        Assert(!it->second.isNull());
-        childChanged = childChanged || cn != it->second;
-        children.push_back(it->second);
-      }
-      Kind ok = cur.getKind();
-      Kind nk = getEliminateKind(ok);
-      if (nk != ok || childChanged)
-      {
-        ret = nm->mkNode(nk, children);
-      }
-      visited[cur] = ret;
-    }
-  } while (!visit.empty());
-  Assert(visited.find(n) != visited.end());
-  Assert(!visited.find(n)->second.isNull());
-  return visited[n];
-}
 
 void SygusGrammarNorm::TypeObject::addConsInfo(
     SygusGrammarNorm* sygus_norm,
@@ -174,41 +93,6 @@ void SygusGrammarNorm::TypeObject::addConsInfo(
   Node sygus_op = cons.getSygusOp();
   Trace("sygus-grammar-normalize-debug")
       << ".....operator is " << sygus_op << std::endl;
-  Node exp_sop_n = sygus_op;
-  if (exp_sop_n.isConst())
-  {
-    // If it is a builtin operator, convert to total version if necessary.
-    // First, get the kind for the operator.
-    Kind ok = NodeManager::operatorToKind(exp_sop_n);
-    Trace("sygus-grammar-normalize-debug")
-        << "...builtin kind is " << ok << std::endl;
-    Kind nk = getEliminateKind(ok);
-    if (nk != ok)
-    {
-      Trace("sygus-grammar-normalize-debug")
-          << "...replace by builtin operator " << nk << std::endl;
-      exp_sop_n = NodeManager::currentNM()->operatorOf(nk);
-    }
-  }
-  else
-  {
-    // Only expand definitions if the operator is not constant, since calling
-    // expandDefinitions on them should be a no-op. This check ensures we don't
-    // try to expand e.g. bitvector extract operators, whose type is undefined,
-    // and thus should not be passed to expandDefinitions.
-    exp_sop_n = Node::fromExpr(
-        smt::currentSmtEngine()->expandDefinitions(sygus_op.toExpr()));
-    exp_sop_n = Rewriter::rewrite(exp_sop_n);
-    Trace("sygus-grammar-normalize-debug")
-        << ".....operator (post-rewrite) is " << exp_sop_n << std::endl;
-    // eliminate all partial operators from it
-    exp_sop_n = eliminatePartialOperators(exp_sop_n);
-    Trace("sygus-grammar-normalize-debug")
-        << ".....operator (eliminate partial operators) is " << exp_sop_n
-        << std::endl;
-    // rewrite again
-    exp_sop_n = Rewriter::rewrite(exp_sop_n);
-  }
 
   std::vector<TypeNode> consTypes;
   const std::vector<std::shared_ptr<DTypeSelector> >& args = cons.getArgs();
@@ -222,10 +106,8 @@ void SygusGrammarNorm::TypeObject::addConsInfo(
     consTypes.push_back(atype);
   }
 
-  Trace("sygus-type-cons-defs") << "\tOriginal op: " << cons.getSygusOp()
-                                << "\n\tExpanded one: " << exp_sop_n << "\n\n";
   d_sdt.addConstructor(
-      exp_sop_n, cons.getName(), consTypes, spc, cons.getWeight());
+      sygus_op, cons.getName(), consTypes, spc, cons.getWeight());
 }
 
 void SygusGrammarNorm::TypeObject::initializeDatatype(
index 360762b38f463a619b7b25bfea8ba73751e69b9f..956228f3813be7439aacc697dc2b542fb62fc9ed 100644 (file)
@@ -200,16 +200,6 @@ class SygusGrammarNorm
     void addConsInfo(SygusGrammarNorm* sygus_norm,
                      const DTypeConstructor& cons,
                      std::shared_ptr<SygusPrintCallback> spc);
-    /**
-     * Returns the total version of Kind k if it is a partial operator, or
-     * otherwise k itself.
-     */
-    static Kind getEliminateKind(Kind k);
-    /**
-     * Returns a version of n where all partial functions such as bvudiv
-     * have been replaced by their total versions like bvudiv_total.
-     */
-    static Node eliminatePartialOperators(Node n);
 
     /** initializes a datatype with the information in the type object
      *
index 1596c30f05716bcde4596d24826efe2f9ed776ea..e69d746fe17a96ecec372221446e37a089f1f0e6 100644 (file)
@@ -1178,8 +1178,10 @@ bool SynthConjecture::getSynthSolutions(
   NodeManager* nm = NodeManager::currentNM();
   std::vector<Node> sols;
   std::vector<int> statuses;
+  Trace("cegqi-debug") << "getSynthSolutions..." << std::endl;
   if (!getSynthSolutionsInternal(sols, statuses))
   {
+    Trace("cegqi-debug") << "...failed internal" << std::endl;
     return false;
   }
   // we add it to the solution map, indexed by this conjecture
@@ -1188,12 +1190,16 @@ bool SynthConjecture::getSynthSolutions(
   {
     Node sol = sols[i];
     int status = statuses[i];
+    Trace("cegqi-debug") << "...got " << i << ": " << sol
+                         << ", status=" << status << std::endl;
     // get the builtin solution
     Node bsol = sol;
     if (status != 0)
     {
-      // convert sygus to builtin here
-      bsol = d_tds->sygusToBuiltin(sol, sol.getType());
+      // Convert sygus to builtin here.
+      // We must use the external representation to ensure bsol matches the
+      // grammar.
+      bsol = datatypes::utils::sygusToBuiltin(sol, true);
     }
     // convert to lambda
     TypeNode tn = d_embed_quant[0][i].getType();
@@ -1214,6 +1220,7 @@ bool SynthConjecture::getSynthSolutions(
     }
     // store in map
     smc[fvar] = bsol;
+    Trace("cegqi-debug") << "...return " << bsol << std::endl;
   }
   return true;
 }
index 8aae1890d85c1057ba168433e7e81b734401279b..06dc2d87c7e780e5066c13f7708ca51ffe67e81d 100644 (file)
@@ -1951,6 +1951,7 @@ set(regress_1_tests
   regress1/sygus/unbdd_inv_gen_ex7.sy
   regress1/sygus/unbdd_inv_gen_winf1.sy
   regress1/sygus/univ_2-long-repeat.sy
+  regress1/sygus/yoni-true-sol.smt2
   regress1/sym/q-constant.smt2
   regress1/sym/q-function.smt2
   regress1/sym/qf-function.smt2
diff --git a/test/regress/regress1/sygus/yoni-true-sol.smt2 b/test/regress/regress1/sygus/yoni-true-sol.smt2
new file mode 100644 (file)
index 0000000..464f7c7
--- /dev/null
@@ -0,0 +1,20 @@
+; COMMAND-LINE: --produce-abducts
+; EXPECT: (define-fun A () Bool (>= j i))
+(set-logic QF_LIA)
+(set-option :produce-abducts true)
+(declare-fun n () Int)
+(declare-fun m () Int)
+(declare-fun i () Int)
+(declare-fun j () Int)
+(assert (and (>= n 0) (>= m 0)))
+(assert (< n i))
+(assert (< (+ i j) m))
+(get-abduct A
+  (<= n m)
+  ((GA Bool) (GJ Int) (GI Int))
+  (
+  (GA Bool ((>= GJ GI)))
+  (GJ Int ( j))
+  (GI Int ( i))
+  )
+)