Merge equivalent sub-obligations instead of discarding them. (#6353)
authorAbdalrhman Mohamed <32971963+abdoo8080@users.noreply.github.com>
Wed, 14 Apr 2021 14:36:58 +0000 (07:36 -0700)
committerGitHub <noreply@github.com>
Wed, 14 Apr 2021 14:36:58 +0000 (14:36 +0000)
This PR modifies the behavior of the reconstruction algorithm when the term to reconstruct contains two or more equivalent sub-terms, but one is easier to reconstruct than the others. Since we do not know which one is easier to reconstruct by matching, we match against all sub-terms. If a solution is found for one sub-term, we use it to solve the others.

src/theory/quantifiers/sygus/rcons_obligation_info.cpp
src/theory/quantifiers/sygus/rcons_obligation_info.h
src/theory/quantifiers/sygus/sygus_reconstruct.cpp
src/theory/quantifiers/sygus/sygus_reconstruct.h
test/regress/CMakeLists.txt
test/regress/regress1/sygus/eq-sub-obs.sy [new file with mode: 0644]

index 0a35f43fcb53c84cc00b9a566234f20d2d6481c1..8a8fcf64b12b9cda7223fce8ee8d71f202fdd520 100644 (file)
@@ -15,6 +15,8 @@
 
 #include "rcons_obligation_info.h"
 
+#include <sstream>
+
 #include "expr/node_algorithm.h"
 #include "theory/datatypes/sygus_datatype_utils.h"
 
@@ -22,15 +24,26 @@ namespace cvc5 {
 namespace theory {
 namespace quantifiers {
 
-RConsObligationInfo::RConsObligationInfo(Node builtin) : d_builtin(builtin) {}
+RConsObligationInfo::RConsObligationInfo(Node builtin) : d_builtins({builtin})
+{
+}
 
-Node RConsObligationInfo::getBuiltin() const { return d_builtin; }
+const std::unordered_set<Node, NodeHashFunction>&
+RConsObligationInfo::getBuiltins() const
+{
+  return d_builtins;
+}
 
 void RConsObligationInfo::addCandidateSolution(Node candSol)
 {
   d_candSols.emplace(candSol);
 }
 
+void RConsObligationInfo::addBuiltin(Node builtin)
+{
+  d_builtins.emplace(builtin);
+}
+
 const std::unordered_set<Node, NodeHashFunction>&
 RConsObligationInfo::getCandidateSolutions() const
 {
@@ -51,8 +64,19 @@ RConsObligationInfo::getWatchSet() const
 std::string RConsObligationInfo::obToString(Node k,
                                             const RConsObligationInfo& obInfo)
 {
-  return "ob<" + obInfo.getBuiltin().toString() + ", " + k.getType().toString()
-         + ">";
+  std::stringstream ss;
+  ss << "([";
+  std::unordered_set<Node, NodeHashFunction>::const_iterator it =
+      obInfo.getBuiltins().cbegin();
+  ss << *it;
+  ++it;
+  while (it != obInfo.getBuiltins().cend())
+  {
+    ss << ", " << *it;
+    ++it;
+  }
+  ss << "]), " << k.getType() << ')' << std::endl;
+  return ss.str();
 }
 
 void RConsObligationInfo::printCandSols(
index c96d3d738d3a9e13bfe55bc36fcf249deed5b171..80bb207a59d5aa0c821d1b47c860cb9d9ead1dff 100644 (file)
@@ -45,9 +45,19 @@ class RConsObligationInfo
   explicit RConsObligationInfo(Node builtin = Node::null());
 
   /**
-   * @return builtin term to reconstruct for this class' obligation
+   * Add `builtin` to the set of equivalent builtins this class' obligation
+   * solves.
+   *
+   * \note `builtin` MUST be equivalent to the builtin terms in `d_builtins`
+   *
+   * @param builtin builtin term to add
+   */
+  void addBuiltin(Node builtin);
+
+  /**
+   * @return equivalent builtin terms to reconstruct for this class' obligation
    */
-  Node getBuiltin() const;
+  const std::unordered_set<Node, NodeHashFunction>& getBuiltins() const;
 
   /**
    * Add candidate solution to the set of candidate solutions for the
@@ -114,12 +124,12 @@ class RConsObligationInfo
           obInfo);
 
  private:
-  /** Builtin term for this class' obligation.
+  /** Equivalent builtin terms for this class' obligation.
    *
-   * To solve the obligation, this builtin term must be reconstructed in the
-   * specified grammar (sygus datatype type) of this class' obligation.
+   * To solve the obligation, one of these builtin terms must be reconstructed
+   * in the specified grammar (sygus datatype type) of the obligation.
    */
-  Node d_builtin;
+  std::unordered_set<Node, NodeHashFunction> d_builtins;
   /** A set of candidate solutions to this class' obligation.
    *
    * Each candidate solution is a sygus datatype term containing skolem subterms
index 0fe3032ffcd15d9de7ffdb5647c656592a6d16c4..1321ad8793c665c0a0ee2e5b7126846f2d094f19 100644 (file)
@@ -215,88 +215,109 @@ TypeObligationSetMap SygusReconstruct::matchNewObs(Node k, Node sz)
   // terms. So, we add redundant substitutions
   candObs.insert(d_sygusVars.cbegin(), d_sygusVars.cend());
 
-  // try to match the obligation's builtin term with the pattern sz
-  if (expr::match(Rewriter::rewrite(datatypes::utils::sygusToBuiltin(sz)),
-                  d_obInfo[k].getBuiltin(),
-                  candObs))
+  // try to match the obligation's builtin terms with the pattern sz
+  for (Node builtin : d_obInfo[k].getBuiltins())
   {
-    // the bound variables z generated by the enumerators are reused across
-    // enumerated terms, so we need to replace them with our own skolems
-    std::vector<std::pair<Node, Node>> subs;
-    Trace("sygus-rcons") << "-- ct: " << sz << std::endl;
-    // remove redundant substitutions
-    for (const std::pair<const Node, Node>& pair : d_sygusVars)
+    if (expr::match(Rewriter::rewrite(datatypes::utils::sygusToBuiltin(sz)),
+                    builtin,
+                    candObs))
     {
-      candObs.erase(pair.first);
-    }
-    // for each candidate obligation
-    for (const std::pair<const Node, Node>& candOb : candObs)
-    {
-      TypeNode stn =
-          datatypes::utils::builtinVarToSygus(candOb.first).getType();
-      Node newVar;
-      // have we come across a similar obligation before?
-      Node rep = d_stnInfo[stn].addTerm(candOb.second);
-      if (!d_stnInfo[stn].builtinToOb(rep).isNull())
+      // the bound variables z generated by the enumerators are reused across
+      // enumerated terms, so we need to replace them with our own skolems
+      std::vector<std::pair<Node, Node>> subs;
+      Trace("sygus-rcons") << "-- ct: " << sz << std::endl;
+      // remove redundant substitutions
+      for (const std::pair<const Node, Node>& pair : d_sygusVars)
       {
-        // if so, use the original obligation
-        newVar = d_stnInfo[stn].builtinToOb(rep);
+        candObs.erase(pair.first);
       }
-      else
+      // for each candidate obligation
+      for (const std::pair<const Node, Node>& candOb : candObs)
       {
-        // otherwise, create a new obligation of the corresponding sygus type
-        newVar = sm->mkDummySkolem("sygus_rcons", stn);
-        d_obInfo.emplace(newVar, candOb.second);
-        d_stnInfo[stn].setBuiltinToOb(candOb.second, newVar);
-        // if the candidate obligation is a constant and the grammar allows
-        // random constants
-        if (candOb.second.isConst()
-            && k.getType().getDType().getSygusAllowConst())
+        TypeNode stn =
+            datatypes::utils::builtinVarToSygus(candOb.first).getType();
+        Node newVar;
+        // did we come across an equivalent obligation before?
+        Node rep = d_stnInfo[stn].addTerm(candOb.second);
+        Node repOb = d_stnInfo[stn].builtinToOb(rep);
+        if (!repOb.isNull())
         {
-          // then immediately solve the obligation
-          markSolved(newVar, d_tds->getProxyVariable(stn, candOb.second));
+          // if so, use the original obligation
+          newVar = repOb;
+          // while `candOb.second` is equivalent to `rep`, it may be easier to
+          // reconstruct than `rep`. For example:
+          //
+          // Grammar: S -> p | q | (not S) | (and S S) | (or S S)
+          // rep = (= p q)
+          // candOb.second = (or (and p q) (and (not p) (not q)))
+          //
+          // In this case, `candOb.second` is easy to reconstruct by matching
+          // because it only uses operators that are already in the grammar.
+          // `rep`, on the other hand, is cannot be reconstructed by matching
+          // and can only be solved by enumeration (currently).
+          //
+          // At this point, we do not know which one is easier to reconstruct by
+          // matching, so we add `candOb.second` to the set of equivalent
+          // builtin terms corresponding to `k` and match against both terms.
+          d_obInfo[repOb].addBuiltin(candOb.second);
+          d_stnInfo[stn].setBuiltinToOb(candOb.second, repOb);
         }
         else
         {
-          // otherwise, add this candidate obligation to this list of
-          // obligations
-          obsPrime[stn].emplace(newVar);
+          // otherwise, create a new obligation of the corresponding sygus type
+          newVar = sm->mkDummySkolem("sygus_rcons", stn);
+          d_obInfo.emplace(newVar, candOb.second);
+          d_stnInfo[stn].setBuiltinToOb(candOb.second, newVar);
+          // if the candidate obligation is a constant and the grammar allows
+          // random constants
+          if (candOb.second.isConst()
+              && k.getType().getDType().getSygusAllowConst())
+          {
+            // then immediately solve the obligation
+            markSolved(newVar, d_tds->getProxyVariable(stn, candOb.second));
+          }
+          else
+          {
+            // otherwise, add this candidate obligation to this list of
+            // obligations
+            obsPrime[stn].emplace(newVar);
+          }
         }
+        subs.emplace_back(datatypes::utils::builtinVarToSygus(candOb.first),
+                          newVar);
       }
-      subs.emplace_back(datatypes::utils::builtinVarToSygus(candOb.first),
-                        newVar);
-    }
-    // replace original free vars in sz with new ones
-    if (!subs.empty())
-    {
-      sz = sz.substitute(subs.cbegin(), subs.cend());
-    }
-    // sz is solved if it has no sub-obligations or if all of them are solved
-    bool isSolved = true;
-    for (const std::pair<Node, Node>& sub : subs)
-    {
-      if (d_sol[sub.second].isNull())
+      // replace original free vars in sz with new ones
+      if (!subs.empty())
       {
-        isSolved = false;
-        d_subObs[sz].push_back(sub.second);
+        sz = sz.substitute(subs.cbegin(), subs.cend());
+      }
+      // sz is solved if it has no sub-obligations or if all of them are solved
+      bool isSolved = true;
+      for (const std::pair<Node, Node>& sub : subs)
+      {
+        if (d_sol[sub.second].isNull())
+        {
+          isSolved = false;
+          d_subObs[sz].push_back(sub.second);
+        }
       }
-    }
 
-    if (isSolved)
-    {
-      // As it traverses sz, substitute populates its input cache with TNodes
-      // that are not preserved by this module and maybe destroyed after the
-      // method call. To avoid referencing those unsafe TNodes throughout this
-      // module, we pass a iterators of d_sol instead.
-      Node s = sz.substitute(d_sol.cbegin(), d_sol.cend());
-      markSolved(k, s);
-    }
-    else
-    {
-      // add sz as a possible solution to obligation k
-      d_obInfo[k].addCandidateSolution(sz);
-      d_parentOb[sz] = k;
-      d_obInfo[d_subObs[sz].back()].addCandidateSolutionToWatchSet(sz);
+      if (isSolved)
+      {
+        // As it traverses sz, substitute populates its input cache with TNodes
+        // that are not preserved by this module and maybe destroyed after the
+        // method call. To avoid referencing those unsafe TNodes throughout this
+        // module, we pass a iterators of d_sol instead.
+        Node s = sz.substitute(d_sol.cbegin(), d_sol.cend());
+        markSolved(k, s);
+      }
+      else
+      {
+        // add sz as a possible solution to obligation k
+        d_obInfo[k].addCandidateSolution(sz);
+        d_parentOb[sz] = k;
+        d_obInfo[d_subObs[sz].back()].addCandidateSolutionToWatchSet(sz);
+      }
     }
   }
 
index bc3b3d476b0e99daaea2bbd36e6d88b51ed3a81a..af3a240079a7210518bf52225548cd63bc8f83d4 100644 (file)
@@ -43,12 +43,15 @@ using TypeObligationSetMap =
  * rcons(t_0, T_0) returns g
  * {
  *   Obs: A map from sygus types T to a set of triples to reconstruct into T,
- *        where each triple is of the form (k, t, s), where k is a skolem of
- *        type T, t is a builtin term of the type encoded by T, and s is a
- *        possibly null sygus term of type T representing the solution.
+ *        where each triple is of the form (k, ts, s), where k is a skolem of
+ *        type T, ts is a set of builtin terms of the type encoded by T, and s
+ *        is a possibly null sygus term of type T representing the solution.
  *
- *   Sol: A map from skolems k to solutions s in the triples (k, t, s). That is,
- *        Sol[k] = s.
+ *   Sol: A map from skolems k to solutions s in the triples (k, ts, s). That
+ *        is, Sol[k] = s.
+ *
+ *   Terms: A map from skolems k to a set of builtin terms in the triples
+ *          (k, ts, s). That is, Terms[k] = ts
  *
  *   CandSols : A map from a skolem k to a set of possible solutions for its
  *              corresponding obligation. Whenever there is a successful match,
@@ -59,51 +62,59 @@ using TypeObligationSetMap =
  *          for matching against the terms to reconstruct t in (k, t, s).
  *
  *   let k_0 be a fresh skolem of sygus type T_0
- *   Obs[T_0] += (k_0, t_0, null)
+ *   Obs[T_0] += (k_0, [t_0], null)
  *
  *   while Sol[k_0] == null
  *     Obs' = {} // map from T to sets of triples pending addition to Obs
  *     // enumeration phase
  *     for each subfield type T of T_0
- *       // enumerated terms may contain variables z ranging over all terms of
+ *       // enumerated terms may contain variables zs ranging over all terms of
  *       // their type (subfield types of T_0)
- *       s[z] = nextEnum(T)
- *       builtin = rewrite(toBuiltIn(s[z]))
- *       if (s[z] is ground)
+ *       s[zs] = nextEnum(T)
+ *       if (s[zs] is ground)
+ *         builtin = rewrite(toBuiltIn(s[zs]))
  *         // let X be the theory the solver is invoked with
- *         find (k, t, s) in Obs[T] s.t. |=_X t = builtin
+ *         find (k, ts, s) in Obs[T] s.t. |=_X ts[0] = builtin
  *         if no such triple exists
  *           let k be a new variable of type : T
- *           Obs[T] += (k, builtin, null)
- *         markSolved(k, s[z])
+ *           Obs[T] += (k, [builtin], null)
+ *         markSolved(k, s[zs])
  *       else if no s' in Pool[T] and matcher sigma s.t.
- *             rewrite(toBuiltIn(s')) * sigma = builtin
- *         Pool[T] += s[z]
- *         for each (k, t, null) in Obs[T]
- *           Obs' += matchNewObs(k, s[z])
+ *             rewrite(toBuiltIn(s')) * sigma = rewrite(toBuiltIn(s[zs]))
+ *         Pool[T] += s[zs]
+ *         for each (k, ts, null) in Obs[T]
+ *           Obs' += matchNewObs(k, s[zs])
  *     // match phase
  *     while Obs' != {}
  *       Obs'' = {}
- *       for each (k, t, null) in Obs' // s = null for all triples in Obs'
- *         Obs[T] += (k, t, null)
- *         for each s[z] in Pool[T]
- *           Obs'' += matchNewObs(k, s[z])
+ *       for each (k, ts, null) in Obs' // s = null for all triples in Obs'
+ *         Obs[T] += (k, ts, null)
+ *         for each s[zs] in Pool[T]
+ *           Obs'' += matchNewObs(k, s[zs])
  *       Obs' = Obs''
  *   g = Sol[k_0]
  *   instantiate free variables of g with arbitrary sygus datatype values
  * }
  *
- * matchNewObs(k, s[z]) returns Obs'
+ * matchNewObs(k, s[zs]) returns Obs'
  * {
- *   u = rewrite(toBuiltIn(s[z]))
- *   if match(u, t) == {toBuiltin(z) -> t'}
- *     // let X be the theory the solver is invoked with
- *     if forall t' exists (k'', t'', s'') in Obs[T] s.t. |=_X t'' = t'
- *       markSolved(k, s{z -> s''})
- *     else
- *       let k' be a new variable of type : typeOf(z)
- *       CandSol[k] += s{z -> k'}
- *       Obs'[typeOf(z)] += (k', t', null)
+ *   u = rewrite(toBuiltIn(s[zs]))
+ *   for each t in Terms[k]
+ *     if match(u, t) == {toBuiltin(zs) -> sts}
+ *       Sub = {} // substitution map from zs to corresponding new vars ks
+ *       for each (z, st) in {zs -> sts}
+ *         // let X be the theory the solver is invoked with
+ *         if exists (k', ts', s') in Obs[T] !=_X ts'[0] = st
+ *           ts' += st
+ *           Sub[z] = k'
+ *         else
+ *           let sk be a new variable of type : typeOf(z)
+ *           Sub[z] = sk
+ *           Obs'[typeOf(z)] += (sk, [st], null)
+ *       if Sol[sk] != null forall (z, sk) in Sub
+ *         markSolved(k, s{Sub})
+ *       else
+ *         CandSol[k] += s{Sub}
  * }
  *
  * markSolved(k, s)
index 5c3ceec217c7a6d9f153e78ff39e64ab499efa3d..710c06f969e3a91118ab73a45301e20339012053 100644 (file)
@@ -2170,6 +2170,7 @@ set(regress_1_tests
   regress1/sygus/double.sy
   regress1/sygus/dt-test-ns.sy
   regress1/sygus/dup-op.sy
+  regress1/sygus/eq-sub-obs.sy
   regress1/sygus/error1-dt.sy
   regress1/sygus/eval-uc.sy
   regress1/sygus/extract.sy
diff --git a/test/regress/regress1/sygus/eq-sub-obs.sy b/test/regress/regress1/sygus/eq-sub-obs.sy
new file mode 100644 (file)
index 0000000..4c24a49
--- /dev/null
@@ -0,0 +1,22 @@
+; COMMAND-LINE: --sygus-si=all  --sygus-out=status
+; EXPECT: unsat
+
+; This regression tests the behavior of the reconstruction algorithm when the
+; term to reconstruct contains two equivalent sub-terms, but one is easier to
+; reconstruct than the other.
+
+(set-logic UF)
+
+(synth-fun f ((p Bool) (q Bool) (r Bool)) Bool
+  ((Start Bool))
+  ((Start Bool (true false p q r (not Start) (and Start Start) (or Start Start)))))
+
+(define-fun eqReduce ((p Bool) (q Bool)) Bool (or (and p q) (and (not p) (not q))))
+
+(declare-var p Bool)
+(declare-var q Bool)
+(declare-var r Bool)
+
+(constraint (= (f p q r) (and (= (and p q) (and q r)) (eqReduce (and p q) (and q r)))))
+
+(check-synth)