Generalize interface for candidate rewrite database (#4797)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 3 Aug 2020 22:05:35 +0000 (17:05 -0500)
committerGitHub <noreply@github.com>
Mon, 3 Aug 2020 22:05:35 +0000 (17:05 -0500)
This class will be used as a utility in a new algorithm for solution reconstruction and requires a generalized interface.

FYI @abdoo8080

src/theory/quantifiers/candidate_rewrite_database.cpp
src/theory/quantifiers/candidate_rewrite_database.h
src/theory/quantifiers/expr_miner_manager.cpp

index 4593f36f146d43c12c1e182dc3bb695d1af386d8..835dc1dba96ac22c06bee3b2bc088e6c1f3941c1 100644 (file)
@@ -16,7 +16,6 @@
 
 #include "api/cvc4cpp.h"
 #include "options/base_options.h"
-#include "options/quantifiers_options.h"
 #include "printer/printer.h"
 #include "smt/smt_engine.h"
 #include "smt/smt_engine_scope.h"
@@ -33,12 +32,16 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
-CandidateRewriteDatabase::CandidateRewriteDatabase()
+CandidateRewriteDatabase::CandidateRewriteDatabase(bool doCheck,
+                                                   bool rewAccel,
+                                                   bool silent)
     : d_qe(nullptr),
       d_tds(nullptr),
       d_ext_rewrite(nullptr),
-      d_using_sygus(false),
-      d_silent(false)
+      d_doCheck(doCheck),
+      d_rewAccel(rewAccel),
+      d_silent(silent),
+      d_using_sygus(false)
 {
 }
 void CandidateRewriteDatabase::initialize(const std::vector<Node>& vars,
@@ -69,13 +72,13 @@ void CandidateRewriteDatabase::initializeSygus(const std::vector<Node>& vars,
   ExprMiner::initialize(vars, ss);
 }
 
-bool CandidateRewriteDatabase::addTerm(Node sol,
+Node CandidateRewriteDatabase::addTerm(Node sol,
                                        bool rec,
                                        std::ostream& out,
                                        bool& rew_print)
 {
   // have we added this term before?
-  std::unordered_map<Node, bool, NodeHashFunction>::iterator itac =
+  std::unordered_map<Node, Node, NodeHashFunction>::iterator itac =
       d_add_term_cache.find(sol);
   if (itac != d_add_term_cache.end())
   {
@@ -127,7 +130,7 @@ bool CandidateRewriteDatabase::addTerm(Node sol,
       bool verified = false;
       Trace("rr-check") << "Check candidate rewrite..." << std::endl;
       // verify it if applicable
-      if (options::sygusRewSynthCheck())
+      if (d_doCheck)
       {
         Node crr = solbr.eqNode(eq_solr).negate();
         Trace("rr-check") << "Check candidate rewrite : " << crr << std::endl;
@@ -177,8 +180,8 @@ bool CandidateRewriteDatabase::addTerm(Node sol,
           d_sampler->addSamplePoint(pt);
           // add the solution again
           // by construction of the above point, we should be unique now
-          Node eq_sol_new = d_sampler->registerTerm(sol);
-          Assert(eq_sol_new == sol);
+          eq_sol = d_sampler->registerTerm(sol);
+          Assert(eq_sol == sol);
         }
         else
         {
@@ -188,7 +191,11 @@ bool CandidateRewriteDatabase::addTerm(Node sol,
       else
       {
         // just insist that constants are not relevant pairs
-        is_unique_term = solb.isConst() && eq_solb.isConst();
+        if (solb.isConst() && eq_solb.isConst())
+        {
+          is_unique_term = true;
+          eq_sol = sol;
+        }
       }
       if (!is_unique_term)
       {
@@ -222,7 +229,7 @@ bool CandidateRewriteDatabase::addTerm(Node sol,
           Trace("sygus-rr-debug")
               << "; candidate #2 ext-rewrites to: " << eq_solr << std::endl;
         }
-        if (options::sygusRewSynthAccel() && d_using_sygus)
+        if (d_rewAccel && d_using_sygus)
         {
           Assert(d_tds != nullptr);
           // Add a symmetry breaking clause that excludes the larger
@@ -258,18 +265,19 @@ bool CandidateRewriteDatabase::addTerm(Node sol,
     // it discards it as a redundant candidate rewrite rule before
     // checking its correctness.
   }
-  d_add_term_cache[sol] = is_unique_term;
-  return is_unique_term;
+  d_add_term_cache[sol] = eq_sol;
+  return eq_sol;
 }
 
-bool CandidateRewriteDatabase::addTerm(Node sol, bool rec, std::ostream& out)
+Node CandidateRewriteDatabase::addTerm(Node sol, bool rec, std::ostream& out)
 {
   bool rew_print = false;
   return addTerm(sol, rec, out, rew_print);
 }
 bool CandidateRewriteDatabase::addTerm(Node sol, std::ostream& out)
 {
-  return addTerm(sol, false, out);
+  Node rsol = addTerm(sol, false, out);
+  return sol == rsol;
 }
 
 void CandidateRewriteDatabase::setSilent(bool flag) { d_silent = flag; }
index 9173a654b544b3b15de6f80522d768a413561512..d64ba8a9925250a0d4ccb6dd1094cfd93c7a196c 100644 (file)
@@ -44,7 +44,17 @@ namespace quantifiers {
 class CandidateRewriteDatabase : public ExprMiner
 {
  public:
-  CandidateRewriteDatabase();
+  /**
+   * Constructor
+   * @param doCheck Whether to check rewrite rules using subsolvers.
+   * @param rewAccel Whether to construct symmetry breaking lemmas based on
+   * discovered rewrites (see option sygusRewSynthAccel()).
+   * @param silent Whether to silence the output of rewrites discovered by this
+   * class.
+   */
+  CandidateRewriteDatabase(bool doCheck,
+                           bool rewAccel = false,
+                           bool silent = false);
   ~CandidateRewriteDatabase() {}
   /**  Initialize this class */
   void initialize(const std::vector<Node>& var,
@@ -69,15 +79,22 @@ class CandidateRewriteDatabase : public ExprMiner
    *
    * Notifies this class that the solution sol was enumerated. This may
    * cause a candidate-rewrite to be printed on the output stream out.
-   * We return true if the term sol is distinct (up to equivalence) with
-   * all previous terms added to this class. The argument rew_print is set to
-   * true if this class printed a rewrite involving sol.
    *
-   * If the flag rec is true, then we also recursively add all subterms of sol
+   * @param sol The term to add to this class.
+   * @param rec If true, then we also recursively add all subterms of sol
    * to this class as well.
+   * @param out The stream to output rewrite rules on.
+   * @param rew_print Set to true if this class printed a rewrite involving sol.
+   * @return A previous term eq_sol added to this class, such that sol is
+   * equivalent to eq_sol based on the criteria used by this class.
+   */
+  Node addTerm(Node sol, bool rec, std::ostream& out, bool& rew_print);
+  Node addTerm(Node sol, bool rec, std::ostream& out);
+  /**
+   * Same as above, returns true if the return value of addTerm was equal to
+   * sol, in other words, sol was a new unique term. This assumes false for
+   * the argument rec.
    */
-  bool addTerm(Node sol, bool rec, std::ostream& out, bool& rew_print);
-  bool addTerm(Node sol, bool rec, std::ostream& out);
   bool addTerm(Node sol, std::ostream& out) override;
   /** sets whether this class should output candidate rewrites it finds */
   void setSilent(bool flag);
@@ -93,14 +110,21 @@ class CandidateRewriteDatabase : public ExprMiner
   ExtendedRewriter* d_ext_rewrite;
   /** the function-to-synthesize we are testing (if sygus) */
   Node d_candidate;
+  /** whether we are checking equivalence using subsolver */
+  bool d_doCheck;
+  /**
+   * If true, we use acceleration for symmetry breaking rewrites (see option
+   * sygusRewSynthAccel()).
+   */
+  bool d_rewAccel;
+  /** if true, we silence the output of candidate rewrites */
+  bool d_silent;
   /** whether we are using sygus */
   bool d_using_sygus;
   /** candidate rewrite filter */
   CandidateRewriteFilter d_crewrite_filter;
   /** the cache for results of addTerm */
-  std::unordered_map<Node, bool, NodeHashFunction> d_add_term_cache;
-  /** if true, we silence the output of candidate rewrites */
-  bool d_silent;
+  std::unordered_map<Node, Node, NodeHashFunction> d_add_term_cache;
 };
 
 } /* CVC4::theory::quantifiers namespace */
index f99b06567e85d771c36c34977da94fe114758f2f..36f152508a6a9fb554c8e90457436ca59a8dd2a6 100644 (file)
@@ -27,7 +27,8 @@ ExpressionMinerManager::ExpressionMinerManager()
       d_doFilterLogicalStrength(false),
       d_use_sygus_type(false),
       d_qe(nullptr),
-      d_tds(nullptr)
+      d_tds(nullptr),
+      d_crd(options::sygusRewSynthCheck(), options::sygusRewSynthAccel(), false)
 {
 }
 
@@ -142,7 +143,8 @@ bool ExpressionMinerManager::addTerm(Node sol,
   bool ret = true;
   if (d_doRewSynth)
   {
-    ret = d_crd.addTerm(sol, options::sygusRewSynthRec(), out, rew_print);
+    Node rsol = d_crd.addTerm(sol, options::sygusRewSynthRec(), out, rew_print);
+    ret = (sol == rsol);
   }
 
   // a unique term, let's try the query generator