Move cardinality inference scheme to base solver in strings (#3792)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 22 Feb 2020 07:05:43 +0000 (01:05 -0600)
committerGitHub <noreply@github.com>
Sat, 22 Feb 2020 07:05:43 +0000 (23:05 -0800)
Moves handling of cardinality to the base solver, refactors how cardinality is accessed.

No intended behavior change in this commit.

src/theory/strings/base_solver.cpp
src/theory/strings/base_solver.h
src/theory/strings/regexp_operation.cpp
src/theory/strings/regexp_operation.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h
src/theory/strings/theory_strings_rewriter.cpp
src/theory/strings/theory_strings_rewriter.h
src/theory/strings/theory_strings_utils.cpp
src/theory/strings/theory_strings_utils.h
src/theory/strings/type_enumerator.h

index 2f5bc8e2bce8def82436624cacb70027efdeb743..343f6e4f86527dc1954025ac7786a41cc71595eb 100644 (file)
@@ -35,6 +35,7 @@ BaseSolver::BaseSolver(context::Context* c,
 {
   d_emptyString = NodeManager::currentNM()->mkConst(::CVC4::String(""));
   d_false = NodeManager::currentNM()->mkConst(false);
+  d_cardSize = utils::getAlphabetCardinality();
 }
 
 BaseSolver::~BaseSolver() {}
@@ -359,6 +360,138 @@ void BaseSolver::checkConstantEquivalenceClasses(TermIndex* ti,
   }
 }
 
+void BaseSolver::checkCardinality()
+{
+  // This will create a partition of eqc, where each collection has length that
+  // are pairwise propagated to be equal. We do not require disequalities
+  // between the lengths of each collection, since we split on disequalities
+  // between lengths of string terms that are disequal (DEQ-LENGTH-SP).
+  std::vector<std::vector<Node> > cols;
+  std::vector<Node> lts;
+  d_state.separateByLength(d_stringsEqc, cols, lts);
+  NodeManager* nm = NodeManager::currentNM();
+  Trace("strings-card") << "Check cardinality...." << std::endl;
+  // for each collection
+  for (unsigned i = 0, csize = cols.size(); i < csize; ++i)
+  {
+    Node lr = lts[i];
+    Trace("strings-card") << "Number of strings with length equal to " << lr
+                          << " is " << cols[i].size() << std::endl;
+    if (cols[i].size() <= 1)
+    {
+      // no restriction on sets in the partition of size 1
+      continue;
+    }
+    // size > c^k
+    unsigned card_need = 1;
+    double curr = static_cast<double>(cols[i].size());
+    while (curr > d_cardSize)
+    {
+      curr = curr / static_cast<double>(d_cardSize);
+      card_need++;
+    }
+    Trace("strings-card")
+        << "Need length " << card_need
+        << " for this number of strings (where alphabet size is " << d_cardSize
+        << ")." << std::endl;
+    // check if we need to split
+    bool needsSplit = true;
+    if (lr.isConst())
+    {
+      // if constant, compare
+      Node cmp = nm->mkNode(GEQ, lr, nm->mkConst(Rational(card_need)));
+      cmp = Rewriter::rewrite(cmp);
+      needsSplit = !cmp.getConst<bool>();
+    }
+    else
+    {
+      // find the minimimum constant that we are unknown to be disequal from, or
+      // otherwise stop if we increment such that cardinality does not apply
+      unsigned r = 0;
+      bool success = true;
+      while (r < card_need && success)
+      {
+        Node rr = nm->mkConst(Rational(r));
+        if (d_state.areDisequal(rr, lr))
+        {
+          r++;
+        }
+        else
+        {
+          success = false;
+        }
+      }
+      if (r > 0)
+      {
+        Trace("strings-card")
+            << "Symbolic length " << lr << " must be at least " << r
+            << " due to constant disequalities." << std::endl;
+      }
+      needsSplit = r < card_need;
+    }
+
+    if (!needsSplit)
+    {
+      // don't need to split
+      continue;
+    }
+    // first, try to split to merge equivalence classes
+    for (std::vector<Node>::iterator itr1 = cols[i].begin();
+         itr1 != cols[i].end();
+         ++itr1)
+    {
+      for (std::vector<Node>::iterator itr2 = itr1 + 1; itr2 != cols[i].end();
+           ++itr2)
+      {
+        if (!d_state.areDisequal(*itr1, *itr2))
+        {
+          // add split lemma
+          if (d_im.sendSplit(*itr1, *itr2, "CARD-SP"))
+          {
+            return;
+          }
+        }
+      }
+    }
+    // otherwise, we need a length constraint
+    uint32_t int_k = static_cast<uint32_t>(card_need);
+    EqcInfo* ei = d_state.getOrMakeEqcInfo(lr, true);
+    Trace("strings-card") << "Previous cardinality used for " << lr << " is "
+                          << ((int)ei->d_cardinalityLemK.get() - 1)
+                          << std::endl;
+    if (int_k + 1 > ei->d_cardinalityLemK.get())
+    {
+      Node k_node = nm->mkConst(Rational(int_k));
+      // add cardinality lemma
+      Node dist = nm->mkNode(DISTINCT, cols[i]);
+      std::vector<Node> vec_node;
+      vec_node.push_back(dist);
+      for (std::vector<Node>::iterator itr1 = cols[i].begin();
+           itr1 != cols[i].end();
+           ++itr1)
+      {
+        Node len = nm->mkNode(STRING_LENGTH, *itr1);
+        if (len != lr)
+        {
+          Node len_eq_lr = len.eqNode(lr);
+          vec_node.push_back(len_eq_lr);
+        }
+      }
+      Node len = nm->mkNode(STRING_LENGTH, cols[i][0]);
+      Node cons = nm->mkNode(GEQ, len, k_node);
+      cons = Rewriter::rewrite(cons);
+      ei->d_cardinalityLemK.set(int_k + 1);
+      if (!cons.isConst() || !cons.getConst<bool>())
+      {
+        std::vector<Node> emptyVec;
+        d_im.sendInference(emptyVec, vec_node, cons, "CARDINALITY", true);
+        return;
+      }
+    }
+  }
+  Trace("strings-card") << "...end check cardinality" << std::endl;
+}
+
 bool BaseSolver::isCongruent(Node n)
 {
   return d_congruent.find(n) != d_congruent.end();
index c87a3af9e8063121468508f32e730d0cb7a50e78..3681b49a4a1d7e5310efd3c86bcd068416dc947e 100644 (file)
@@ -70,6 +70,13 @@ class BaseSolver
    * case, we infer the fact x ++ "c" ++ y = "acb".
    */
   void checkConstantEquivalenceClasses();
+  /** check cardinality
+   *
+   * This function checks whether a cardinality inference needs to be applied
+   * to a set of equivalence classes. For details, see Step 5 of the proof
+   * procedure from Liang et al, CAV 2014.
+   */
+  void checkCardinality();
   //-----------------------end inference steps
 
   //-----------------------query functions
@@ -182,6 +189,8 @@ class BaseSolver
   std::vector<Node> d_stringsEqc;
   /** A term index for each function kind */
   std::map<Kind, TermIndex> d_termIndex;
+  /** the cardinality of the alphabet */
+  uint32_t d_cardSize;
 }; /* class BaseSolver */
 
 }  // namespace strings
index 8707593c733e3e9b6a41dbb5e3e621b3bce46b22..1b2de0eb5b216c784257d44f82707fdea36caf9b 100644 (file)
@@ -42,7 +42,7 @@ RegExpOpr::RegExpOpr()
                                                std::vector<Node>{})),
       d_sigma_star(NodeManager::currentNM()->mkNode(kind::REGEXP_STAR, d_sigma))
 {
-  d_lastchar = TheoryStringsRewriter::getAlphabetCardinality()-1;
+  d_lastchar = utils::getAlphabetCardinality() - 1;
 }
 
 RegExpOpr::~RegExpOpr() {}
index c7464760d91ebfc0233d6b336b2ed4d7429bc48e..91d5df7443909d7aa8ae211d07d9d84de6c13232 100644 (file)
@@ -59,7 +59,7 @@ class RegExpOpr {
 
  private:
   /** the code point of the last character in the alphabet we are using */
-  unsigned d_lastchar;
+  uint32_t d_lastchar;
   Node d_emptyString;
   Node d_true;
   Node d_false;
index 8a3c32fd8f582a8b6c93d2921bf105b158679711..5be2f96d4a92433262b151424d0a9a673a0a23d0 100644 (file)
@@ -127,7 +127,7 @@ TheoryStrings::TheoryStrings(context::Context* c,
   d_true = NodeManager::currentNM()->mkConst( true );
   d_false = NodeManager::currentNM()->mkConst( false );
 
-  d_card_size = TheoryStringsRewriter::getAlphabetCardinality();
+  d_cardSize = utils::getAlphabetCardinality();
 }
 
 TheoryStrings::~TheoryStrings() {
@@ -565,7 +565,7 @@ void TheoryStrings::preRegisterTerm(TNode n) {
             std::vector<unsigned> vec = n.getConst<String>().getVec();
             for (unsigned u : vec)
             {
-              if (u >= d_card_size)
+              if (u >= d_cardSize)
               {
                 std::stringstream ss;
                 ss << "Characters in string \"" << n
@@ -1146,110 +1146,6 @@ void TheoryStrings::checkRegisterTermsNormalForms()
   }
 }
 
-void TheoryStrings::checkCardinality() {
-  //int cardinality = options::stringCharCardinality();
-  //Trace("strings-solve-debug2") << "get cardinality: " << cardinality << endl;
-
-  //AJR: this will create a partition of eqc, where each collection has length that are pairwise propagated to be equal.
-  //  we do not require disequalities between the lengths of each collection, since we split on disequalities between lengths of string terms that are disequal (DEQ-LENGTH-SP).
-  //  TODO: revisit this?
-  const std::vector<Node>& seqc = d_bsolver.getStringEqc();
-  std::vector< std::vector< Node > > cols;
-  std::vector< Node > lts;
-  d_state.separateByLength(seqc, cols, lts);
-
-  Trace("strings-card") << "Check cardinality...." << std::endl;
-  for( unsigned i = 0; i<cols.size(); ++i ) {
-    Node lr = lts[i];
-    Trace("strings-card") << "Number of strings with length equal to " << lr << " is " << cols[i].size() << std::endl;
-    if( cols[i].size() > 1 ) {
-      // size > c^k
-      unsigned card_need = 1;
-      double curr = (double)cols[i].size();
-      while( curr>d_card_size ){
-        curr = curr/(double)d_card_size;
-        card_need++;
-      }
-      Trace("strings-card") << "Need length " << card_need << " for this number of strings (where alphabet size is " << d_card_size << ")." << std::endl;
-      //check if we need to split
-      bool needsSplit = true;
-      if( lr.isConst() ){
-        // if constant, compare
-        Node cmp = NodeManager::currentNM()->mkNode( kind::GEQ, lr, NodeManager::currentNM()->mkConst( Rational( card_need ) ) );
-        cmp = Rewriter::rewrite( cmp );
-        needsSplit = cmp!=d_true;
-      }else{
-        // find the minimimum constant that we are unknown to be disequal from, or otherwise stop if we increment such that cardinality does not apply
-        unsigned r=0; 
-        bool success = true;
-        while( r<card_need && success ){
-          Node rr = NodeManager::currentNM()->mkConst<Rational>( Rational(r) );
-          if (d_state.areDisequal(rr, lr))
-          {
-            r++;
-          }
-          else
-          {
-            success = false;
-          }
-        }
-        if( r>0 ){
-          Trace("strings-card") << "Symbolic length " << lr << " must be at least " << r << " due to constant disequalities." << std::endl;
-        }
-        needsSplit = r<card_need;
-      }
-
-      if( needsSplit ){
-        unsigned int int_k = (unsigned int)card_need;
-        for( std::vector< Node >::iterator itr1 = cols[i].begin();
-            itr1 != cols[i].end(); ++itr1) {
-          for( std::vector< Node >::iterator itr2 = itr1 + 1;
-            itr2 != cols[i].end(); ++itr2) {
-            if (!d_state.areDisequal(*itr1, *itr2))
-            {
-              // add split lemma
-              if (d_im.sendSplit(*itr1, *itr2, "CARD-SP"))
-              {
-                return;
-              }
-            }
-          }
-        }
-        EqcInfo* ei = d_state.getOrMakeEqcInfo(lr, true);
-        Trace("strings-card")
-            << "Previous cardinality used for " << lr << " is "
-            << ((int)ei->d_cardinalityLemK.get() - 1) << std::endl;
-        if (int_k + 1 > ei->d_cardinalityLemK.get())
-        {
-          Node k_node = NodeManager::currentNM()->mkConst( ::CVC4::Rational( int_k ) );
-          //add cardinality lemma
-          Node dist = NodeManager::currentNM()->mkNode( kind::DISTINCT, cols[i] );
-          std::vector< Node > vec_node;
-          vec_node.push_back( dist );
-          for( std::vector< Node >::iterator itr1 = cols[i].begin();
-              itr1 != cols[i].end(); ++itr1) {
-            Node len = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, *itr1 );
-            if( len!=lr ) {
-              Node len_eq_lr = len.eqNode(lr);
-              vec_node.push_back( len_eq_lr );
-            }
-          }
-          Node len = NodeManager::currentNM()->mkNode( kind::STRING_LENGTH, cols[i][0] );
-          Node cons = NodeManager::currentNM()->mkNode( kind::GEQ, len, k_node );
-          cons = Rewriter::rewrite( cons );
-          ei->d_cardinalityLemK.set(int_k + 1);
-          if( cons!=d_true ){
-            d_im.sendInference(
-                d_empty_vec, vec_node, cons, "CARDINALITY", true);
-            return;
-          }
-        }
-      }
-    }
-  }
-  Trace("strings-card") << "...end check cardinality" << std::endl;
-}
-
 Node TheoryStrings::ppRewrite(TNode atom) {
   Trace("strings-ppr") << "TheoryStrings::ppRewrite " << atom << std::endl;
   Node atomElim;
@@ -1328,7 +1224,7 @@ void TheoryStrings::runInferStep(InferStep s, int effort)
     case CHECK_REGISTER_TERMS_NF: checkRegisterTermsNormalForms(); break;
     case CHECK_EXTF_REDUCTION: d_esolver->checkExtfReductions(effort); break;
     case CHECK_MEMBERSHIP: checkMemberships(); break;
-    case CHECK_CARDINALITY: checkCardinality(); break;
+    case CHECK_CARDINALITY: d_bsolver.checkCardinality(); break;
     default: Unreachable(); break;
   }
   Trace("strings-process") << "Done " << s
index 55852490f366c0482f0155f196fa91925a5c8187..f40af6e67ce758c784e3c912e35f1f09209bc17a 100644 (file)
@@ -206,7 +206,7 @@ class TheoryStrings : public Theory {
   Node d_one;
   Node d_neg_one;
   /** the cardinality of the alphabet */
-  unsigned d_card_size;
+  uint32_t d_cardSize;
   /** The notify class */
   NotifyClass d_notify;
   /** Equaltity engine */
@@ -401,13 +401,6 @@ private:
    * FroCoS 2015.
    */
   void checkMemberships();
-  /** check cardinality
-   *
-   * This function checks whether a cardinality inference needs to be applied
-   * to a set of equivalence classes. For details, see Step 5 of the proof
-   * procedure from Liang et al, CAV 2014.
-   */
-  void checkCardinality();
   //-----------------------end inference steps
 
   //-----------------------representation of the strategy
index e9a4ebfd1f02458c6f13e7129f8882509723b199..339d11dd2715332e28a4b6a8f417b07d34dce5a4 100644 (file)
@@ -230,17 +230,6 @@ Node TheoryStringsRewriter::simpleRegexpConsume( std::vector< Node >& mchildren,
   return Node::null();
 }
 
-unsigned TheoryStringsRewriter::getAlphabetCardinality()
-{
-  if (options::stdPrintASCII())
-  {
-    Assert(128 <= String::num_codes());
-    return 128;
-  }
-  Assert(256 <= String::num_codes());
-  return 256;
-}
-
 Node TheoryStringsRewriter::rewriteEquality(Node node)
 {
   Assert(node.getKind() == kind::EQUAL);
index 7d76234bcca6182ec2180e84cd126c7cb72f1606..c9733433c1bec8e9feaaff28fa9cfb4d58657077 100644 (file)
@@ -159,8 +159,6 @@ class TheoryStringsRewriter : public TheoryRewriter
   RewriteResponse postRewrite(TNode node) override;
   RewriteResponse preRewrite(TNode node) override;
 
-  /** get the cardinality of the alphabet used, based on the options */
-  static unsigned getAlphabetCardinality();
   /** rewrite equality
    *
    * This method returns a formula that is equivalent to the equality between
index a564c82e16a136f6fc47befb0038e57a3e5bbc41..a325108e446a0970ea41906c6d48b863121685ce 100644 (file)
@@ -16,6 +16,7 @@
 
 #include "theory/strings/theory_strings_utils.h"
 
+#include "options/strings_options.h"
 #include "theory/rewriter.h"
 
 using namespace CVC4::kind;
@@ -25,6 +26,17 @@ namespace theory {
 namespace strings {
 namespace utils {
 
+uint32_t getAlphabetCardinality()
+{
+  if (options::stdPrintASCII())
+  {
+    Assert(128 <= String::num_codes());
+    return 128;
+  }
+  Assert(256 <= String::num_codes());
+  return 256;
+}
+
 Node mkAnd(const std::vector<Node>& a)
 {
   std::vector<Node> au;
index ccdac8edf3e732ae42ad4c96ab834d65c8bb78d2..51fe8cfc7650d2e1b77515c82b65712676b3e276 100644 (file)
@@ -28,6 +28,9 @@ namespace theory {
 namespace strings {
 namespace utils {
 
+/** get the cardinality of the alphabet used, based on the options */
+uint32_t getAlphabetCardinality();
+
 /**
  * Make the conjunction of nodes in a. Removes duplicate conjuncts, returns
  * true if a is empty, and a single literal if a has size 1.
index 4218d4ce5510aa1aea0bb1c26335b3c93978b68a..0171effaf7b15266168cb5b503bfe7ea34032ed4 100644 (file)
@@ -24,6 +24,7 @@
 #include "expr/kind.h"
 #include "expr/type_node.h"
 #include "theory/strings/theory_strings_rewriter.h"
+#include "theory/strings/theory_strings_utils.h"
 #include "theory/type_enumerator.h"
 #include "util/regexp.h"
 
@@ -33,7 +34,7 @@ namespace strings {
 
 class StringEnumerator : public TypeEnumeratorBase<StringEnumerator> {
   std::vector< unsigned > d_data;
-  unsigned d_cardinality;
+  uint32_t d_cardinality;
   Node d_curr;
   void mkCurr() {
     //make constant from d_data
@@ -46,7 +47,7 @@ class StringEnumerator : public TypeEnumeratorBase<StringEnumerator> {
   {
     Assert(type.getKind() == kind::TYPE_CONSTANT
            && type.getConst<TypeConstant>() == STRING_TYPE);
-    d_cardinality = TheoryStringsRewriter::getAlphabetCardinality();
+    d_cardinality = utils::getAlphabetCardinality();
     mkCurr();
   }
   Node operator*() override { return d_curr; }
@@ -85,7 +86,7 @@ class StringEnumerator : public TypeEnumeratorBase<StringEnumerator> {
 
 class StringEnumeratorLength {
  private:
-  unsigned d_cardinality;
+  uint32_t d_cardinality;
   std::vector< unsigned > d_data;
   Node d_curr;
   void mkCurr() {
@@ -94,7 +95,9 @@ class StringEnumeratorLength {
   }
 
  public:
-  StringEnumeratorLength(unsigned length, unsigned card = 256) : d_cardinality(card) {
+  StringEnumeratorLength(uint32_t length, uint32_t card = 256)
+      : d_cardinality(card)
+  {
     for( unsigned i=0; i<length; i++ ){
       d_data.push_back( 0 );
     }