Improve accuracy of stats for sygus sampler (#1755)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 10 Apr 2018 21:44:02 +0000 (16:44 -0500)
committerGitHub <noreply@github.com>
Tue, 10 Apr 2018 21:44:02 +0000 (16:44 -0500)
src/smt/smt_engine.cpp
src/theory/quantifiers/dynamic_rewrite.cpp
src/theory/quantifiers/dynamic_rewrite.h
src/theory/quantifiers/sygus/ce_guided_conjecture.cpp
src/theory/quantifiers/sygus_sampler.cpp
src/theory/quantifiers/sygus_sampler.h

index dd8a2e50264a2338e94658b6e62833aa78810499..029fb84c9642579b0221faf8d05c7b1e39949a94 100644 (file)
@@ -1481,8 +1481,11 @@ void SmtEngine::setDefaults() {
     options::boolToBitvector.set(false);
   }
 
-  if(options::produceAssignments() && !options::produceModels()) {
-    Notice() << "SmtEngine: turning on produce-models to support produce-assignments" << endl;
+  // cases where we need produce models
+  if (!options::produceModels()
+      && (options::produceAssignments() || options::sygusRewSynthCheck()))
+  {
+    Notice() << "SmtEngine: turning on produce-models" << endl;
     setOption("produce-models", SExpr("true"));
   }
 
index cb7379910318ffd2b2245333c9aea0d2c44ffaf7..00e527abad243a7a8ac87dd8e7d4d4427e363336 100644 (file)
@@ -31,39 +31,26 @@ DynamicRewriter::DynamicRewriter(const std::string& name, QuantifiersEngine* qe)
   d_equalityEngine.addFunctionKind(kind::APPLY_UF);
 }
 
-bool DynamicRewriter::addRewrite(Node a, Node b)
+void DynamicRewriter::addRewrite(Node a, Node b)
 {
   Trace("dyn-rewrite") << "Dyn-Rewriter : " << a << " == " << b << std::endl;
   if (a == b)
   {
-    Trace("dyn-rewrite") << "...fail, equal." << std::endl;
-    return false;
+    Trace("dyn-rewrite") << "...equal." << std::endl;
+    return;
   }
 
   // add to the equality engine
   Node ai = toInternal(a);
   Node bi = toInternal(b);
   Trace("dyn-rewrite-debug") << "Internal : " << ai << " " << bi << std::endl;
-  d_equalityEngine.addTerm(ai);
-  d_equalityEngine.addTerm(bi);
-
-  Trace("dyn-rewrite-debug") << "get reps..." << std::endl;
-  // may already be equal by congruence
-  Node air = d_equalityEngine.getRepresentative(ai);
-  Node bir = d_equalityEngine.getRepresentative(bi);
-  Trace("dyn-rewrite-debug") << "Reps : " << air << " " << bir << std::endl;
-  if (air == bir)
-  {
-    Trace("dyn-rewrite") << "...fail, congruent." << std::endl;
-    return false;
-  }
 
   Trace("dyn-rewrite-debug") << "assert eq..." << std::endl;
   Node eq = ai.eqNode(bi);
   d_rewrites.push_back(eq);
   d_equalityEngine.assertEquality(eq, true, eq);
+  Assert(d_equalityEngine.consistent());
   Trace("dyn-rewrite-debug") << "Finished" << std::endl;
-  return true;
 }
 
 bool DynamicRewriter::areEqual(Node a, Node b)
index 388173829baa97d0c6264b504d048870b71ddcf8..56f59147072085d1f28b43c06c2db66689a92cb7 100644 (file)
@@ -57,12 +57,8 @@ class DynamicRewriter
  public:
   DynamicRewriter(const std::string& name, QuantifiersEngine* qe);
   ~DynamicRewriter() {}
-  /** inform this class that the equality a = b holds.
-   *
-   * This function returns true if this class did not already know that
-   * a = b based on the previous equalities it has seen.
-   */
-  bool addRewrite(Node a, Node b);
+  /** inform this class that the equality a = b holds. */
+  void addRewrite(Node a, Node b);
   /**
    * Check whether this class knows that the equality a = b holds.
    */
index 98855a7cac9004cdc32b8c705456a0d1789cbd7c..046c5724e174c9fd1242e9aad4f8a8db91f7551e 100644 (file)
@@ -582,22 +582,11 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
       ss << prog;
       std::string f(ss.str());
       f.erase(f.begin());
-      out << "(define-fun " << f << " ";
-      if( dt.getSygusVarList().isNull() ){
-        out << "() ";
-      }else{
-        out << dt.getSygusVarList() << " ";
-      }
-      out << dt.getSygusType() << " ";
-      if( status==0 ){
-        out << sol;
-      }else{
-        Printer::getPrinter(options::outputLanguage())->toStreamSygus(out, sol);
-      }
-      out << ")" << std::endl;
       CegInstantiation* cei = d_qe->getCegInstantiation();
       ++(cei->d_statistics.d_solutions);
 
+      bool is_unique_term = true;
+
       if (status != 0 && options::sygusRewSynth())
       {
         TermDbSygus* sygusDb = d_qe->getTermDatabaseSygus();
@@ -612,9 +601,10 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
         // eq_sol is a candidate solution that is equivalent to sol
         if (eq_sol != sol)
         {
-          ++(cei->d_statistics.d_candidate_rewrites);
+          is_unique_term = false;
           // if eq_sol is null, then we have an uninteresting candidate rewrite,
           // e.g. one that is alpha-equivalent to another.
+          bool success = true;
           if (!eq_sol.isNull())
           {
             ExtendedRewriter* er = sygusDb->getExtRewriter();
@@ -622,12 +612,11 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
             Node solbr = er->extendedRewrite(solb);
             Node eq_solb = sygusDb->sygusToBuiltin(eq_sol);
             Node eq_solr = er->extendedRewrite(eq_solb);
-            bool success = true;
             bool verified = false;
+            Trace("rr-check") << "Check candidate rewrite..." << std::endl;
             // verify it if applicable
             if (options::sygusRewSynthCheck())
             {
-              Trace("rr-check") << "Check candidate rewrite..." << std::endl;
               // Notice we don't set produce-models. rrChecker takes the same
               // options as the SmtEngine we belong to, where we ensure that
               // produce-models is set.
@@ -645,6 +634,7 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
                 Trace("rr-check")
                     << "...rewrite does not hold for: " << std::endl;
                 success = false;
+                is_unique_term = true;
                 std::vector<Node> vars;
                 d_sampler[prog].getVariables(vars);
                 std::vector<Node> pt;
@@ -665,11 +655,18 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
                 verified = true;
               }
             }
+            else
+            {
+              // just insist that constants are not relevant pairs
+              success = !solb.isConst() || !eq_solb.isConst();
+            }
             if (success)
             {
-              // The analog of terms sol and eq_sol are equivalent under sample
-              // points but do not rewrite to the same term. Hence, this
-              // indicates a candidate rewrite.
+              // register this as a relevant pair (helps filtering)
+              d_sampler[prog].registerRelevantPair(sol, eq_sol);
+              // The analog of terms sol and eq_sol are equivalent under
+              // sample points but do not rewrite to the same term. Hence,
+              // this indicates a candidate rewrite.
               Printer* p = Printer::getPrinter(options::outputLanguage());
               out << "(" << (verified ? "" : "candidate-") << "rewrite ";
               p->toStreamSygus(out, sol);
@@ -711,7 +708,35 @@ void CegConjecture::printSynthSolution( std::ostream& out, bool singleInvocation
               }
             }
           }
+          // we count this as a rewrite if we did not explicitly rule it out
+          if (success)
+          {
+            ++(cei->d_statistics.d_candidate_rewrites);
+          }
+        }
+      }
+      if (is_unique_term)
+      {
+        out << "(define-fun " << f << " ";
+        if (dt.getSygusVarList().isNull())
+        {
+          out << "() ";
+        }
+        else
+        {
+          out << dt.getSygusVarList() << " ";
+        }
+        out << dt.getSygusType() << " ";
+        if (status == 0)
+        {
+          out << sol;
+        }
+        else
+        {
+          Printer::getPrinter(options::outputLanguage())
+              ->toStreamSygus(out, sol);
         }
+        out << ")" << std::endl;
       }
     }
   }
index f9ae0b553dd34957fa932cf53363fb8f38d544c5..f15c1199c8a106796622cb421ee748f7ad3150f5 100644 (file)
@@ -14,7 +14,9 @@
 
 #include "theory/quantifiers/sygus_sampler.h"
 
+#include "options/base_options.h"
 #include "options/quantifiers_options.h"
+#include "printer/printer.h"
 #include "util/bitvector.h"
 #include "util/random.h"
 
@@ -700,13 +702,13 @@ void SygusSamplerExt::initializeSygusExt(QuantifiersEngine* qe,
 Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
 {
   Node eq_n = SygusSampler::registerTerm(n, forceKeep);
-  Trace("sygus-synth-rr") << "sygusSampleExt : " << n << "..." << eq_n
-                          << std::endl;
   if (eq_n == n)
   {
     // this is a unique term
     return n;
   }
+  Trace("sygus-synth-rr") << "sygusSampleExt : " << n << "..." << eq_n
+                          << std::endl;
   Node bn = n;
   Node beq_n = eq_n;
   if (d_use_sygus_type)
@@ -727,7 +729,7 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
     if (!d_match_trie.getMatches(bn, &d_ssenm))
     {
       keep = false;
-      Trace("sygus-synth-rr-debug") << "...redundant (matchable)" << std::endl;
+      Trace("sygus-synth-rr") << "...redundant (matchable)" << std::endl;
     }
   }
 
@@ -735,39 +737,64 @@ Node SygusSamplerExt::registerTerm(Node n, bool forceKeep)
   if (d_drewrite != nullptr)
   {
     Trace("sygus-synth-rr-debug") << "Add rewrite pair..." << std::endl;
-    if (!d_drewrite->addRewrite(bn, beq_n))
+    if (d_drewrite->areEqual(bn, beq_n))
     {
       // must be unique according to the dynamic rewriter
+      Trace("sygus-synth-rr") << "...redundant (rewritable)" << std::endl;
       keep = false;
-      Trace("sygus-synth-rr-debug") << "...redundant (rewritable)" << std::endl;
     }
   }
 
   if (keep)
   {
-    // add to match information
-    for (unsigned r = 0; r < 2; r++)
-    {
-      Node t = r == 0 ? bn : beq_n;
-      Node to = r == 0 ? beq_n : bn;
-      // insert in match trie if first time
-      if (d_pairs.find(t) == d_pairs.end())
-      {
-        Trace("sse-match") << "SSE add term : " << t << std::endl;
-        d_match_trie.addTerm(t);
-      }
-      d_pairs[t].insert(to);
-    }
     return eq_n;
   }
-  else if (Trace.isOn("sygus-synth-rr"))
+  Trace("sygus-synth-rr") << "Redundant pair : " << eq_n << " " << n;
+  Trace("sygus-synth-rr") << std::endl;
+  if (Trace.isOn("sygus-rr-filter"))
   {
-    Trace("sygus-synth-rr") << "Redundant pair : " << eq_n << " " << n;
-    Trace("sygus-synth-rr") << std::endl;
+    Printer* p = Printer::getPrinter(options::outputLanguage());
+    std::stringstream ss;
+    ss << "(redundant-rewrite ";
+    p->toStreamSygus(ss, n);
+    ss << " ";
+    p->toStreamSygus(ss, eq_n);
+    Trace("sygus-rr-filter") << ss.str() << std::endl;
   }
   return Node::null();
 }
 
+void SygusSamplerExt::registerRelevantPair(Node n, Node eq_n)
+{
+  Node bn = n;
+  Node beq_n = eq_n;
+  if (d_use_sygus_type)
+  {
+    bn = d_tds->sygusToBuiltin(n);
+    beq_n = d_tds->sygusToBuiltin(eq_n);
+  }
+  // ----- check rewriting redundancy
+  if (d_drewrite != nullptr)
+  {
+    Trace("sygus-synth-rr-debug") << "Add rewrite pair..." << std::endl;
+    Assert(!d_drewrite->areEqual(bn, beq_n));
+    d_drewrite->addRewrite(bn, beq_n);
+  }
+  // add to match information
+  for (unsigned r = 0; r < 2; r++)
+  {
+    Node t = r == 0 ? bn : beq_n;
+    Node to = r == 0 ? beq_n : bn;
+    // insert in match trie if first time
+    if (d_pairs.find(t) == d_pairs.end())
+    {
+      Trace("sse-match") << "SSE add term : " << t << std::endl;
+      d_match_trie.addTerm(t);
+    }
+    d_pairs[t].insert(to);
+  }
+}
+
 bool SygusSamplerExt::notify(Node s,
                              Node n,
                              std::vector<Node>& vars,
index fa0d670d27bb2fbff65b5720705bef8e1568ffe4..18b8f5511ba465bfb5b92e6ac348a2376e6d908f 100644 (file)
@@ -431,6 +431,13 @@ class SygusSamplerExt : public SygusSampler
    */
   Node registerTerm(Node n, bool forceKeep = false) override;
 
+  /** register relevant pair
+   *
+   * This should be called after registerTerm( n ) returns eq_n.
+   * This registers ( n, eq_n ) as a relevant pair with this class.
+   */
+  void registerRelevantPair(Node n, Node eq_n);
+
  private:
   /** dynamic rewriter class */
   std::unique_ptr<DynamicRewriter> d_drewrite;