Split collect model info by types in strings (#3847)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 2 Mar 2020 17:57:26 +0000 (11:57 -0600)
committerGitHub <noreply@github.com>
Mon, 2 Mar 2020 17:57:26 +0000 (11:57 -0600)
Towards a theory of sequences.

We will need to do similar splits per type for most of the functions throughout strings.

src/expr/type_node.cpp
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h
src/theory/strings/type_enumerator.h

index b003a7861aaf1d07ac5e300bf9210bc20b14dbd5..945462dd6cfe097bcc08de1100ae25538b2282ef 100644 (file)
@@ -298,7 +298,11 @@ Node TypeNode::mkGroundValue() const
   return *te;
 }
 
-bool TypeNode::isStringLike() const { return isString(); }
+bool TypeNode::isStringLike() const
+{
+  // TODO (cvc4-projects #23): sequence here
+  return isString();
+}
 
 bool TypeNode::isSubtypeOf(TypeNode t) const {
   if(*this == t) {
index 2fbf1655253ca4d68d2c109a13ffb8c1ed0c57e7..a5604925c53d77b5cdcaf1cec230409fa10373f1 100644 (file)
@@ -259,18 +259,36 @@ bool TheoryStrings::collectModelInfo(TheoryModel* m)
     return false;
   }
 
-  std::unordered_set<Node, NodeHashFunction> repSet;
-  NodeManager* nm = NodeManager::currentNM();
+  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction> > repSet;
   // Generate model
   // get the relevant string equivalence classes
   for (const Node& s : termSet)
   {
-    if (s.getType().isString())
+    TypeNode tn = s.getType();
+    if (tn.isStringLike())
     {
       Node r = d_state.getRepresentative(s);
-      repSet.insert(r);
+      repSet[tn].insert(r);
+    }
+  }
+  for (const std::pair<const TypeNode,
+                       std::unordered_set<Node, NodeHashFunction> >& rst :
+       repSet)
+  {
+    if (!collectModelInfoType(rst.first, rst.second, m))
+    {
+      return false;
     }
   }
+  return true;
+}
+
+bool TheoryStrings::collectModelInfoType(
+    TypeNode tn,
+    const std::unordered_set<Node, NodeHashFunction>& repSet,
+    TheoryModel* m)
+{
+  NodeManager* nm = NodeManager::currentNM();
   std::vector<Node> nodes(repSet.begin(), repSet.end());
   std::map< Node, Node > processed;
   std::vector< std::vector< Node > > col;
@@ -394,7 +412,9 @@ bool TheoryStrings::collectModelInfo(TheoryModel* m)
       //use type enumerator
       Assert(lts_values[i].getConst<Rational>() <= Rational(String::maxSize()))
           << "Exceeded UINT32_MAX in string model";
-      StringEnumeratorLength sel(lts_values[i].getConst<Rational>().getNumerator().toUnsignedInt());
+      StringEnumeratorLength sel(
+          tn,
+          lts_values[i].getConst<Rational>().getNumerator().toUnsignedInt());
       for (const Node& eqc : pure_eq)
       {
         Node c;
@@ -490,7 +510,7 @@ bool TheoryStrings::collectModelInfo(TheoryModel* m)
         nc.push_back(r.isConst() ? r : processed[r]);
       }
       Node cc = utils::mkNConcat(nc);
-      Assert(cc.getKind() == kind::CONST_STRING);
+      Assert(cc.isConst());
       Trace("strings-model") << "*** Determined constant " << cc << " for " << nodes[i] << std::endl;
       processed[nodes[i]] = cc;
       if (!m->assertEquality(nodes[i], cc, true))
index 79681a5f95f92829ac12d2a16a50a236396e60a3..84c9e60c6ceec133a22b60018deba9647d0859c1 100644 (file)
@@ -121,6 +121,11 @@ class TheoryStrings : public Theory {
                               std::vector<Node>& vars,
                               std::vector<Node>& subs,
                               std::map<Node, std::vector<Node> >& exp) override;
+  /**
+   * Get all relevant information in this theory regarding the current
+   * model. Return false if a contradiction is discovered.
+   */
+  bool collectModelInfo(TheoryModel* m) override;
 
   // NotifyClass for equality engine
   class NotifyClass : public eq::EqualityEngineNotify {
@@ -231,11 +236,6 @@ private:
 
   std::map< Node, Node > d_eqc_to_len_term;
 
-  /////////////////////////////////////////////////////////////////////////////
-  // MODEL GENERATION
-  /////////////////////////////////////////////////////////////////////////////
- public:
-  bool collectModelInfo(TheoryModel* m) override;
 
   /////////////////////////////////////////////////////////////////////////////
   // NOTIFICATIONS
@@ -298,6 +298,18 @@ private:
    */
   bool areCareDisequal(TNode x, TNode y);
 
+  /** Collect model info for type tn
+   *
+   * Assigns model values (in m) to all relevant terms of the string-like type
+   * tn in the current context, which are stored in repSet.
+   *
+   * Returns false if a conflict is discovered while doing this assignment.
+   */
+  bool collectModelInfoType(
+      TypeNode tn,
+      const std::unordered_set<Node, NodeHashFunction>& repSet,
+      TheoryModel* m);
+
   /** assert pending fact
    *
    * This asserts atom with polarity to the equality engine of this class,
index 0171effaf7b15266168cb5b503bfe7ea34032ed4..16bfc75a6b87d432a67dcd88c6793fbf33316ba0 100644 (file)
@@ -83,21 +83,16 @@ class StringEnumerator : public TypeEnumeratorBase<StringEnumerator> {
   bool isFinished() override { return d_curr.isNull(); }
 };/* class StringEnumerator */
 
-
+/**
+ * Enumerates string values for a given length.
+ */
 class StringEnumeratorLength {
- private:
-  uint32_t d_cardinality;
-  std::vector< unsigned > d_data;
-  Node d_curr;
-  void mkCurr() {
-    //make constant from d_data
-    d_curr = NodeManager::currentNM()->mkConst( ::CVC4::String( d_data ) );
-  }
-
  public:
-  StringEnumeratorLength(uint32_t length, uint32_t card = 256)
-      : d_cardinality(card)
+  StringEnumeratorLength(TypeNode tn, uint32_t length, uint32_t card = 256)
+      : d_type(tn), d_cardinality(card)
   {
+    // TODO (cvc4-projects #23): sequence here
+    Assert(tn.isString());
     for( unsigned i=0; i<length; i++ ){
       d_data.push_back( 0 );
     }
@@ -125,6 +120,21 @@ class StringEnumeratorLength {
   }
 
   bool isFinished() { return d_curr.isNull(); }
+
+ private:
+  /** The type we are enumerating */
+  TypeNode d_type;
+  /** The cardinality of the alphabet */
+  uint32_t d_cardinality;
+  /** The data (index to members) */
+  std::vector<unsigned> d_data;
+  /** The current term */
+  Node d_curr;
+  /** Make the current term from d_data */
+  void mkCurr()
+  {
+    d_curr = NodeManager::currentNM()->mkConst(::CVC4::String(d_data));
+  }
 };
 
 }/* CVC4::theory::strings namespace */