Do not require sygus constructors to be flattened (#3049)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 11 Jun 2019 21:47:13 +0000 (16:47 -0500)
committerGitHub <noreply@github.com>
Tue, 11 Jun 2019 21:47:13 +0000 (16:47 -0500)
src/parser/smt2/Smt2.g
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/datatypes_rewriter.h
src/theory/quantifiers/sygus/sygus_eval_unfold.cpp
src/theory/quantifiers/sygus/sygus_grammar_cons.cpp
src/theory/quantifiers/sygus/sygus_grammar_red.cpp
src/theory/quantifiers/sygus/term_database_sygus.cpp
src/theory/quantifiers/sygus/term_database_sygus.h
test/regress/regress1/sygus/sygus-dt.sy

index 9ba7f4b2eb560297d0990efdca0cc33508df14ca..b224032e84811e0d5766a367108f9b7bffac994f 100644 (file)
@@ -1004,7 +1004,7 @@ sygusGTerm[CVC4::SygusGTerm& sgt, std::string& fun]
       }else if( PARSER_STATE->isDeclared(name,SYM_VARIABLE) ){
         Debug("parser-sygus") << "Sygus grammar " << fun << " : symbol "
                               << name << std::endl;
-        sgt.d_expr = PARSER_STATE->getVariable(name);
+        sgt.d_expr = PARSER_STATE->getExpressionForName(name);
         sgt.d_name = name;
         sgt.d_gterm_type = SygusGTerm::gterm_op;
       }else{
index be87b7e8d858a212c54de53d97a914eb83356322..ac3bff21bffa74267d74ec33a207718a45751301 100644 (file)
@@ -16,6 +16,8 @@
 
 #include "theory/datatypes/datatypes_rewriter.h"
 
+#include "expr/node_algorithm.h"
+
 using namespace CVC4;
 using namespace CVC4::kind;
 
@@ -115,8 +117,7 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
     if (ev.getKind() == APPLY_CONSTRUCTOR)
     {
       Trace("dt-sygus-util") << "Rewrite " << in << " by unfolding...\n";
-      const Datatype& dt =
-          static_cast<DatatypeType>(ev.getType().toType()).getDatatype();
+      const Datatype& dt = ev.getType().getDatatype();
       unsigned i = indexOf(ev.getOperator());
       Node op = Node::fromExpr(dt[i].getSygusOp());
       // if it is the "any constant" constructor, return its argument
@@ -141,14 +142,8 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
         children.push_back(nm->mkNode(DT_SYGUS_EVAL, cc));
       }
       Node ret = mkSygusTerm(dt, i, children);
-      // if it is a variable, apply the substitution
-      if (ret.getKind() == BOUND_VARIABLE)
-      {
-        Assert(ret.hasAttribute(SygusVarNumAttribute()));
-        int vn = ret.getAttribute(SygusVarNumAttribute());
-        Assert(Node::fromExpr(dt.getSygusVarList())[vn] == ret);
-        ret = args[vn];
-      }
+      // apply the appropriate substitution
+      ret = applySygusArgs(dt, op, ret, args);
       Trace("dt-sygus-util") << "...got " << ret << "\n";
       return RewriteResponse(REWRITE_AGAIN_FULL, ret);
     }
@@ -186,6 +181,67 @@ RewriteResponse DatatypesRewriter::postRewrite(TNode in)
   return RewriteResponse(REWRITE_DONE, in);
 }
 
+Node DatatypesRewriter::applySygusArgs(const Datatype& dt,
+                                       Node op,
+                                       Node n,
+                                       const std::vector<Node>& args)
+{
+  if (n.getKind() == BOUND_VARIABLE)
+  {
+    Assert(n.hasAttribute(SygusVarNumAttribute()));
+    int vn = n.getAttribute(SygusVarNumAttribute());
+    Assert(Node::fromExpr(dt.getSygusVarList())[vn] == n);
+    return args[vn];
+  }
+  // n is an application of operator op.
+  // We must compute the free variables in op to determine if there are
+  // any substitutions we need to make to n.
+  TNode val;
+  if (!op.hasAttribute(SygusVarFreeAttribute()))
+  {
+    std::unordered_set<Node, NodeHashFunction> fvs;
+    if (expr::getFreeVariables(op, fvs))
+    {
+      if (fvs.size() == 1)
+      {
+        for (const Node& v : fvs)
+        {
+          val = v;
+        }
+      }
+      else
+      {
+        val = op;
+      }
+    }
+    Trace("dt-sygus-fv") << "Free var in " << op << " : " << val << std::endl;
+    op.setAttribute(SygusVarFreeAttribute(), val);
+  }
+  else
+  {
+    val = op.getAttribute(SygusVarFreeAttribute());
+  }
+  if (val.isNull())
+  {
+    return n;
+  }
+  if (val.getKind() == BOUND_VARIABLE)
+  {
+    // single substitution case
+    int vn = val.getAttribute(SygusVarNumAttribute());
+    TNode sub = args[vn];
+    return n.substitute(val, sub);
+  }
+  // do the full substitution
+  std::vector<Node> vars;
+  Node bvl = Node::fromExpr(dt.getSygusVarList());
+  for (unsigned i = 0, nvars = bvl.getNumChildren(); i < nvars; i++)
+  {
+    vars.push_back(bvl[i]);
+  }
+  return n.substitute(vars.begin(), vars.end(), args.begin(), args.end());
+}
+
 Kind DatatypesRewriter::getOperatorKindForSygusBuiltin(Node op)
 {
   Assert(op.getKind() != BUILTIN);
@@ -224,6 +280,13 @@ Node DatatypesRewriter::mkSygusTerm(const Datatype& dt,
   Assert(!dt[i].getSygusOp().isNull());
   std::vector<Node> schildren;
   Node op = Node::fromExpr(dt[i].getSygusOp());
+  Trace("dt-sygus-util") << "Operator is " << op << std::endl;
+  if (children.empty())
+  {
+    // no children, return immediately
+    Trace("dt-sygus-util") << "...return direct op" << std::endl;
+    return op;
+  }
   // if it is the any constant, we simply return the child
   if (op.getAttribute(SygusAnyConstAttribute()))
   {
@@ -243,18 +306,13 @@ Node DatatypesRewriter::mkSygusTerm(const Datatype& dt,
     return ret;
   }
   Kind ok = NodeManager::operatorToKind(op);
+  Trace("dt-sygus-util") << "operator kind is " << ok << std::endl;
   if (ok != UNDEFINED_KIND)
   {
-    if (ok == APPLY_UF && schildren.size() == 1)
-    {
-      // This case is triggered for defined constant symbols. In this case,
-      // we return the operator itself instead of an APPLY_UF node.
-      ret = schildren[0];
-    }
-    else
-    {
-      ret = NodeManager::currentNM()->mkNode(ok, schildren);
-    }
+    // If it is an APPLY_UF operator, we should have at least an operator and
+    // a child.
+    Assert(ok != APPLY_UF || schildren.size() != 1);
+    ret = NodeManager::currentNM()->mkNode(ok, schildren);
     Trace("dt-sygus-util") << "...return (op) " << ret << std::endl;
     return ret;
   }
index 6c1d64e5b77ef69c856e1f18a523d688a2c2ed1a..1a1735402cf5706594a126d7c086475499ebdf45 100644 (file)
@@ -51,6 +51,28 @@ struct SygusSymBreakOkAttributeId
 typedef expr::Attribute<SygusSymBreakOkAttributeId, bool>
     SygusSymBreakOkAttribute;
 
+/** sygus var free
+ *
+ * This attribute is used to mark whether sygus operators have free occurrences
+ * of variables from the formal argument list of the function-to-synthesize.
+ *
+ * We store three possible cases for sygus operators op:
+ * (1) op.getAttribute(SygusVarFreeAttribute())==Node::null()
+ * In this case, op has no free variables from the formal argument list of the
+ * function-to-synthesize.
+ * (2) op.getAttribute(SygusVarFreeAttribute())==v, where v is a bound variable.
+ * In this case, op has exactly one free variable, v.
+ * (3) op.getAttribute(SygusVarFreeAttribute())==op
+ * In this case, op has an arbitrary set (cardinality >1) of free variables from
+ * the formal argument list of the function to synthesize.
+ *
+ * This attribute is used to compute applySygusArgs below.
+ */
+struct SygusVarFreeAttributeId
+{
+};
+typedef expr::Attribute<SygusVarFreeAttributeId, Node> SygusVarFreeAttribute;
+
 namespace datatypes {
 
 class DatatypesRewriter {
@@ -149,6 +171,37 @@ public:
  static Node mkSygusTerm(const Datatype& dt,
                          unsigned i,
                          const std::vector<Node>& children);
+ /**
+  * n is a builtin term that is an application of operator op.
+  *
+  * This returns an n' such that (eval n args) is n', where n' is a instance of
+  * n for the appropriate substitution.
+  *
+  * For example, given a function-to-synthesize with formal argument list (x,y),
+  * say we have grammar:
+  *   A -> A+A | A+x | A+(x+y) | y
+  * These lead to constructors with sygus ops:
+  *   C1 / (lambda w1 w2. w1+w2)
+  *   C2 / (lambda w1. w1+x)
+  *   C3 / (lambda w1. w1+(x+y))
+  *   C4 / y
+  * Examples of calling this function:
+  *   applySygusArgs( dt, C1, (APPLY_UF (lambda w1 w2. w1+w2) t1 t2), { 3, 5 } )
+  *     ... returns (APPLY_UF (lambda w1 w2. w1+w2) t1 t2).
+  *   applySygusArgs( dt, C2, (APPLY_UF (lambda w1. w1+x) t1), { 3, 5 } )
+  *     ... returns (APPLY_UF (lambda w1. w1+3) t1).
+  *   applySygusArgs( dt, C3, (APPLY_UF (lambda w1. w1+(x+y)) t1), { 3, 5 } )
+  *     ... returns (APPLY_UF (lambda w1. w1+(3+5)) t1).
+  *   applySygusArgs( dt, C4, y, { 3, 5 } )
+  *     ... returns 5.
+  * Notice the attribute SygusVarFreeAttribute is applied to C1, C2, C3, C4,
+  * to cache the results of whether the evaluation of this constructor needs
+  * a substitution over the formal argument list of the function-to-synthesize.
+  */
+ static Node applySygusArgs(const Datatype& dt,
+                            Node op,
+                            Node n,
+                            const std::vector<Node>& args);
  /**
   * Get the builtin sygus operator for constructor term n of sygus datatype
   * type. For example, if n is the term C_+( d1, d2 ) where C_+ is a sygus
index e44b604d0cdbf66090e5874f2b261cce33f6e4db..7324add50da6cc5e2256ab5bf9137cfbc53e5a0f 100644 (file)
@@ -133,7 +133,18 @@ void SygusEvalUnfold::registerModelValue(Node a,
         bool do_unfold = false;
         if (options::sygusEvalUnfoldBool())
         {
-          if (bTerm.getKind() == ITE || bTerm.getType().isBoolean())
+          Node bTermUse = bTerm;
+          if (bTerm.getKind() == APPLY_UF)
+          {
+            // if the builtin term is non-beta-reduced application of lambda,
+            // we look at the body of the lambda.
+            Node bTermOp = bTerm.getOperator();
+            if (bTermOp.getKind() == LAMBDA)
+            {
+              bTermUse = bTermOp[0];
+            }
+          }
+          if (bTermUse.getKind() == ITE || bTermUse.getType().isBoolean())
           {
             do_unfold = true;
           }
index 48da8e8e814e5e49d6e0c181a707cbe02c93a9f5..263c88d158dd9c678e31b7aa340e4806662d3d3d 100644 (file)
@@ -712,7 +712,14 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
       for (unsigned k = 0, size_k = dt.getNumConstructors(); k < size_k; ++k)
       {
         Trace("sygus-grammar-def") << "...for " << dt[k].getName() << std::endl;
-        ops[i].push_back( dt[k].getConstructor() );
+        Expr cop = dt[k].getConstructor();
+        if (dt[k].getNumArgs() == 0)
+        {
+          // Nullary constructors are interpreted as terms, not operators.
+          // Thus, we apply them to no arguments here.
+          cop = nm->mkNode(APPLY_CONSTRUCTOR, Node::fromExpr(cop)).toExpr();
+        }
+        ops[i].push_back(cop);
         cnames[i].push_back(dt[k].getName());
         cargs[i].push_back(std::vector<Type>());
         Trace("sygus-grammar-def") << "...add for selectors" << std::endl;
index 6ad590f2806d4f5061d665167fbaafab9c4f7154..2b2c87f3834abd7656bef9a5119494435dae5251 100644 (file)
@@ -43,7 +43,6 @@ void SygusRedundantCons::initialize(QuantifiersEngine* qe, TypeNode tn)
     std::map<int, Node> pre;
     Node g = tds->mkGeneric(dt, i, pre);
     Trace("sygus-red-debug") << "  ...pre-rewrite : " << g << std::endl;
-    Assert(g.getNumChildren() == dt[i].getNumArgs());
     d_gen_terms[i] = g;
     for (unsigned j = 0, nargs = dt[i].getNumArgs(); j < nargs; j++)
     {
index af820b0fcc54105f9386e96fb0459ee889a36127..01d08dad84a3d0ffc7f766b77ccdc09874016fd3 100644 (file)
@@ -411,6 +411,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ) {
             Trace("sygus-db") << ", kind = " << sk;
             d_kinds[tn][sk] = i;
             d_arg_kind[tn][i] = sk;
+            if (sk == ITE)
+            {
+              // mark that this type has an ITE
+              d_hasIte[tn] = true;
+            }
           }
           else if (sop.isConst() && dt[i].getNumArgs() == 0)
           {
@@ -432,6 +437,11 @@ void TermDbSygus::registerSygusType( TypeNode tn ) {
                   << ", argument to a lambda constructor is not " << lat
                   << std::endl;
             }
+            if (sop[0].getKind() == ITE)
+            {
+              // mark that this type has an ITE
+              d_hasIte[tn] = true;
+            }
           }
           // symbolic constructors
           if (n.getAttribute(SygusAnyConstAttribute()))
@@ -602,7 +612,7 @@ void TermDbSygus::registerEnumerator(Node e,
         // solution" clauses.
         const Datatype& dt = et.getDatatype();
         if (options::sygusStream()
-            || (!hasKind(et, ITE) && !dt.getSygusType().isBoolean()))
+            || (!hasIte(et) && !dt.getSygusType().isBoolean()))
         {
           isActiveGen = true;
         }
@@ -1003,6 +1013,10 @@ int TermDbSygus::getOpConsNum( TypeNode tn, Node n ) {
 bool TermDbSygus::hasKind( TypeNode tn, Kind k ) {
   return getKindConsNum( tn, k )!=-1;
 }
+bool TermDbSygus::hasIte(TypeNode tn) const
+{
+  return d_hasIte.find(tn) != d_hasIte.end();
+}
 bool TermDbSygus::hasConst( TypeNode tn, Node n ) {
   return getConstConsNum( tn, n )!=-1;
 }
@@ -1502,14 +1516,9 @@ Node TermDbSygus::unfold( Node en, std::map< Node, Node >& vtm, std::vector< Nod
     pre[j] = nm->mkNode(DT_SYGUS_EVAL, cc);
   }
   Node ret = mkGeneric(dt, i, pre);
-  // if it is a variable, apply the substitution
-  if (ret.getKind() == kind::BOUND_VARIABLE)
-  {
-    Assert(ret.hasAttribute(SygusVarNumAttribute()));
-    int i = ret.getAttribute(SygusVarNumAttribute());
-    Assert(Node::fromExpr(dt.getSygusVarList())[i] == ret);
-    return args[i];
-  }
+  // apply the appropriate substitution to ret
+  ret = datatypes::DatatypesRewriter::applySygusArgs(dt, sop, ret, args);
+  // rewrite
   ret = Rewriter::rewrite(ret);
   return ret;
 }
index 0f3d650d38a6cfb822f5e86ca94cc2d9af5fa14b..2854ecab6e853029556b75ba980a01f4b5ef969a 100644 (file)
@@ -393,6 +393,11 @@ class TermDbSygus {
   std::map<TypeNode, std::vector<Node> > d_var_list;
   std::map<TypeNode, std::map<int, Kind> > d_arg_kind;
   std::map<TypeNode, std::map<Kind, int> > d_kinds;
+  /**
+   * Whether this sygus type has a constructors whose sygus operator is ITE,
+   * or is a lambda whose body is ITE.
+   */
+  std::map<TypeNode, bool> d_hasIte;
   std::map<TypeNode, std::map<int, Node> > d_arg_const;
   std::map<TypeNode, std::map<Node, int> > d_consts;
   std::map<TypeNode, std::map<Node, int> > d_ops;
@@ -462,6 +467,11 @@ class TermDbSygus {
   int getConstConsNum( TypeNode tn, Node n );
   int getOpConsNum( TypeNode tn, Node n );
   bool hasKind( TypeNode tn, Kind k );
+  /**
+   * Returns true if this sygus type has a constructors whose sygus operator is
+   * ITE, or is a lambda whose body is ITE.
+   */
+  bool hasIte(TypeNode tn) const;
   bool hasConst( TypeNode tn, Node n );
   bool hasOp( TypeNode tn, Node n );
   Node getConsNumConst( TypeNode tn, int i );
index 2f3f4dbb9fb08729514490108f5176f21739379a..336c59b27b3b85e0c22fc4ec61f0d644bd24dffc 100644 (file)
@@ -7,7 +7,7 @@
 (define-fun g ((x Int)) List (cons (+ x 1) nil))
 (define-fun i () List (cons 3 nil))
 
-(synth-fun f ((x Int)) List ((Start List ((g StartInt) i (cons StartInt Start) (nil) (tail Start)))
+(synth-fun f ((x Int)) List ((Start List ((g StartInt) i (cons StartInt Start) nil (tail Start)))
                              (StartInt Int (x 0 1 (+ StartInt StartInt)))))
 
 (declare-var x Int)