Expand definitions in sygus operators at the SMT level (#7077)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 27 Aug 2021 18:03:41 +0000 (13:03 -0500)
committerGitHub <noreply@github.com>
Fri, 27 Aug 2021 18:03:41 +0000 (18:03 +0000)
Eliminates another call to currentSmtEngine.

This PR ensures we remember the mapping between operators that are embedded in sygus datatypes during preprocessing, instead of computing this within the sygus datatypes utilities when solving.

src/smt/smt_engine.cpp
src/smt/sygus_solver.cpp
src/smt/sygus_solver.h
src/theory/datatypes/sygus_datatype_utils.cpp
src/theory/datatypes/sygus_datatype_utils.h
src/theory/quantifiers/sygus/type_info.cpp

index be5c75af05f851ee7bddf1e3d7ee532ddf053136..0ca4d5b15ecc5bbf709891f36309c6e772db6048 100644 (file)
@@ -1021,6 +1021,7 @@ void SmtEngine::declareSynthFun(Node func,
                                 const std::vector<Node>& vars)
 {
   SmtScope smts(this);
+  finishInit();
   d_state->doPendingPops();
   d_sygusSolver->declareSynthFun(func, sygusType, isInv, vars);
 
index b7b6d9c18a2cd0b98f679e7a163c27874bf3d97c..240f96af741654cce7a3931b5ff1ac54344f16c3 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "base/modal_exception.h"
 #include "expr/dtype.h"
+#include "expr/dtype_cons.h"
 #include "expr/skolem_manager.h"
 #include "options/base_options.h"
 #include "options/option_exception.h"
@@ -28,6 +29,7 @@
 #include "smt/dump.h"
 #include "smt/preprocessor.h"
 #include "smt/smt_solver.h"
+#include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
 #include "theory/quantifiers/sygus/sygus_utils.h"
@@ -82,6 +84,8 @@ void SygusSolver::declareSynthFun(Node fn,
     // use an attribute to mark its grammar
     SygusSynthGrammarAttribute ssfga;
     fn.setAttribute(ssfga, sym);
+    // we must expand definitions for sygus operators in the block
+    expandDefinitionsSygusDt(sygusType);
   }
 
   // sygus conjecture is now stale
@@ -408,5 +412,44 @@ void SygusSolver::setSygusConjectureStale()
   // TODO (project #7): if incremental, we should pop a context
 }
 
+void SygusSolver::expandDefinitionsSygusDt(TypeNode tn) const
+{
+  std::unordered_set<TypeNode> processed;
+  std::vector<TypeNode> toProcess;
+  toProcess.push_back(tn);
+  size_t index = 0;
+  while (index < toProcess.size())
+  {
+    TypeNode tnp = toProcess[index];
+    index++;
+    Assert(tnp.isDatatype());
+    Assert(tnp.getDType().isSygus());
+    const std::vector<std::shared_ptr<DTypeConstructor>>& cons =
+        tnp.getDType().getConstructors();
+    for (const std::shared_ptr<DTypeConstructor>& c : cons)
+    {
+      Node op = c->getSygusOp();
+      // Only expand definitions if the operator is not constant, since
+      // calling expandDefinitions on them should be a no-op. This check
+      // ensures we don't try to expand e.g. bitvector extract operators,
+      // whose type is undefined, and thus should not be passed to
+      // expandDefinitions.
+      Node eop = op.isConst() ? op : d_pp.expandDefinitions(op);
+      datatypes::utils::setExpandedDefinitionForm(op, eop);
+      // also must consider the arguments
+      for (unsigned j = 0, nargs = c->getNumArgs(); j < nargs; ++j)
+      {
+        TypeNode tnc = c->getArgType(j);
+        if (tnc.isDatatype() && tnc.getDType().isSygus()
+            && processed.find(tnc) == processed.end())
+        {
+          toProcess.push_back(tnc);
+          processed.insert(tnc);
+        }
+      }
+    }
+  }
+}
+
 }  // namespace smt
 }  // namespace cvc5
index 82dfab3cccdd02dab661da304454b3f5583d3da0..3db61503762c4bf5b2d68d51a872e3cdc24c128c 100644 (file)
@@ -158,6 +158,18 @@ class SygusSolver
    * previously not stale.
    */
   void setSygusConjectureStale();
+  /**
+   * Expand definitions in sygus datatype tn, which ensures that all
+   * sygus constructors that are used to build values of sygus datatype
+   * tn are associated with their expanded definition form.
+   *
+   * This method is required at this level since sygus grammars may include
+   * user-defined functions. Thus, we must use the preprocessor here to
+   * associate the use of those functions with their expanded form, since
+   * the internal sygus solver must reason about sygus operators after
+   * expansion.
+   */
+  void expandDefinitionsSygusDt(TypeNode tn) const;
   /** Reference to the env class */
   Env& d_env;
   /** The SMT solver, which is used during checkSynth. */
index 72ddd7b0e79e1c2dea7a571290c74b19d7474694..f1f7b45a414444418f1a183653bf4bb636346e0c 100644 (file)
@@ -22,8 +22,6 @@
 #include "expr/node_algorithm.h"
 #include "expr/sygus_datatype.h"
 #include "smt/env.h"
-#include "smt/smt_engine.h"
-#include "smt/smt_engine_scope.h"
 #include "theory/evaluator.h"
 #include "theory/rewriter.h"
 
@@ -175,12 +173,9 @@ Node mkSygusTerm(const DType& dt,
       }
       else
       {
-        // Only expand definitions if the operator is not constant, since
-        // calling expandDefinitions on them should be a no-op. This check
-        // ensures we don't try to expand e.g. bitvector extract operators,
-        // whose type is undefined, and thus should not be passed to
-        // expandDefinitions.
-        opn = smt::currentSmtEngine()->expandDefinitions(op);
+        // Get the expanded definition form, if it has been marked. This ensures
+        // that user-defined functions have been eliminated from op.
+        opn = getExpandedDefinitionForm(op);
         opn = Rewriter::rewrite(opn);
         SygusOpRewrittenAttribute sora;
         op.setAttribute(sora, opn);
@@ -736,6 +731,28 @@ unsigned getSygusTermSize(Node n)
   return weight + sum;
 }
 
+/**
+ * Map terms to the result of expand definitions calling smt::expandDefinitions
+ * on it.
+ */
+struct SygusExpDefFormAttributeId
+{
+};
+typedef expr::Attribute<SygusExpDefFormAttributeId, Node>
+    SygusExpDefFormAttribute;
+
+void setExpandedDefinitionForm(Node op, Node eop)
+{
+  op.setAttribute(SygusExpDefFormAttribute(), eop);
+}
+
+Node getExpandedDefinitionForm(Node op)
+{
+  Node eop = op.getAttribute(SygusExpDefFormAttribute());
+  // if not set, assume original
+  return eop.isNull() ? op : eop;
+}
+
 }  // namespace utils
 }  // namespace datatypes
 }  // namespace theory
index 35672434c77f050b2ebdf4ae02cb769b01caafb2..5784fe34af9b4cae7828826ff15081f0927d2d09 100644 (file)
@@ -232,6 +232,22 @@ TypeNode substituteAndGeneralizeSygusType(TypeNode sdt,
  * in n.
  */
 unsigned getSygusTermSize(Node n);
+
+/**
+ * Set expanded definition form of sygus op to eop. This is called when
+ * we require associating a sygus operator op to its expanded form, which
+ * replaces user-defined functions with their definitions. This allows
+ * the utilities above to consider op to be the original form, which is
+ * printed in the final solution (see isExternal to sygusToBuiltin above),
+ * whereas the internal solver will reason about eop.
+ */
+void setExpandedDefinitionForm(Node op, Node eop);
+/**
+ * Get the expanded definition form of sygus operator op, returns the
+ * expanded form if the above method has been called for op, or returns op
+ * otherwise.
+ */
+Node getExpandedDefinitionForm(Node op);
 // ------------------------ end sygus utils
 
 }  // namespace utils
index f9aa5bdc3d5ac65b52335aaa8a326ee930567b7d..7a8ff0b1d9879960df46b2bc48f14e625ea8ac41 100644 (file)
@@ -190,6 +190,20 @@ void SygusTypeInfo::initialize(TermDbSygus* tds, TypeNode tn)
     {
       d_hasBoolConnective = true;
     }
+    if (Trace.isOn("sygus-db"))
+    {
+      Node eop = datatypes::utils::getExpandedDefinitionForm(sop);
+      Trace("sygus-db") << "Expanded form: ";
+      if (eop == sop)
+      {
+        Trace("sygus-db") << "same";
+      }
+      else
+      {
+        Trace("sygus-db") << eop;
+      }
+      Trace("sygus-db") << std::endl;
+    }
   }
   // compute minimum type depth information
   computeMinTypeDepthInternal(tn, 0);