Fix relevant domain for parametric operators (#7198)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 16 Sep 2021 17:14:07 +0000 (12:14 -0500)
committerGitHub <noreply@github.com>
Thu, 16 Sep 2021 17:14:07 +0000 (17:14 +0000)
Fixes #6531.

This issue also occurs when using `--full-saturate-quant` on facebook benchmarks that combine multiple sequence types.

It does some cleanup on relevant domain in the process.

src/theory/quantifiers/relevant_domain.cpp
src/theory/quantifiers/relevant_domain.h
src/theory/quantifiers/term_database.cpp
src/theory/quantifiers/term_tuple_enumerator.cpp

index f4eb954692c83a8c780aae8e95fbd1e492d6b072..a531d88b79c4dcd722c29f9d6d2a46ac3744adac 100644 (file)
@@ -85,20 +85,23 @@ RelevantDomain::RelevantDomain(Env& env,
 }
 
 RelevantDomain::~RelevantDomain() {
-  for( std::map< Node, std::map< int, RDomain * > >::iterator itr = d_rel_doms.begin(); itr != d_rel_doms.end(); ++itr ){
-    for( std::map< int, RDomain * >::iterator itr2 = itr->second.begin(); itr2 != itr->second.end(); ++itr2 ){
-      RDomain * current = (*itr2).second;
+  for (auto& r : d_rel_doms)
+  {
+    for (auto& rr : r.second)
+    {
+      RDomain* current = rr.second;
       Assert(current != NULL);
       delete current;
     }
   }
 }
 
-RelevantDomain::RDomain * RelevantDomain::getRDomain( Node n, int i, bool getParent ) {
+RelevantDomain::RDomain* RelevantDomain::getRDomain(Node n,
+                                                    size_t i,
+                                                    bool getParent)
+{
   if( d_rel_doms.find( n )==d_rel_doms.end() || d_rel_doms[n].find( i )==d_rel_doms[n].end() ){
     d_rel_doms[n][i] = new RDomain;
-    d_rn_map[d_rel_doms[n][i]] = n;
-    d_ri_map[d_rel_doms[n][i]] = i;
   }
   return getParent ? d_rel_doms[n][i]->getParent() : d_rel_doms[n][i];
 }
@@ -112,9 +115,11 @@ void RelevantDomain::registerQuantifier(Node q) {}
 void RelevantDomain::compute(){
   if( !d_is_computed ){
     d_is_computed = true;
-    for( std::map< Node, std::map< int, RDomain * > >::iterator it = d_rel_doms.begin(); it != d_rel_doms.end(); ++it ){
-      for( std::map< int, RDomain * >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
-        it2->second->reset();
+    for (auto& r : d_rel_doms)
+    {
+      for (auto& rr : r.second)
+      {
+        rr.second->reset();
       }
     }
     FirstOrderModel* fm = d_treg.getModel();
@@ -144,21 +149,37 @@ void RelevantDomain::compute(){
       }
     }
     //print debug
-    for( std::map< Node, std::map< int, RDomain * > >::iterator it = d_rel_doms.begin(); it != d_rel_doms.end(); ++it ){
-      Trace("rel-dom") << "Relevant domain for " << it->first << " : " << std::endl;
-      for( std::map< int, RDomain * >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){
-        Trace("rel-dom") << "   " << it2->first << " : ";
-        RDomain * r = it2->second;
+    for (std::pair<const Node, std::map<size_t, RDomain*> >& d : d_rel_doms)
+    {
+      Trace("rel-dom") << "Relevant domain for " << d.first << " : "
+                       << std::endl;
+      for (std::pair<const size_t, RDomain*>& dd : d.second)
+      {
+        Trace("rel-dom") << "   " << dd.first << " : ";
+        RDomain* r = dd.second;
         RDomain * rp = r->getParent();
         if( r==rp ){
           r->removeRedundantTerms(d_qs);
-          for( unsigned i=0; i<r->d_terms.size(); i++ ){
-            Trace("rel-dom") << r->d_terms[i] << " ";
-          }
+          Trace("rel-dom") << r->d_terms;
         }else{
-          Trace("rel-dom") << "Dom( " << d_rn_map[rp] << ", " << d_ri_map[rp] << " ) ";
+          Trace("rel-dom") << "Dom( " << d.first << ", " << dd.first << " ) ";
         }
         Trace("rel-dom") << std::endl;
+        if (Configuration::isAssertionBuild())
+        {
+          if (d.first.getKind() == FORALL)
+          {
+            TypeNode expectedType = d.first[0][dd.first].getType();
+            for (const Node& t : r->d_terms)
+            {
+              if (!t.getType().isComparableTo(expectedType))
+              {
+                Unhandled() << "Relevant domain: bad type " << t.getType()
+                            << ", expected " << expectedType;
+              }
+            }
+          }
+        }
       }
     }
   }
@@ -212,7 +233,10 @@ void RelevantDomain::computeRelevantDomainNode(Node q,
 {
   Trace("rel-dom-debug") << "Compute relevant domain " << n << "..." << std::endl;
   Node op = d_treg.getTermDatabase()->getMatchOperator(n);
-  if (!op.isNull())
+  // Relevant domain only makes sense for non-parametric operators, thus we
+  // check op==n.getOperator() here. This otherwise would lead to bad types
+  // for terms in the relevant domain.
+  if (!op.isNull() && op == n.getOperator())
   {
     for (size_t i = 0, nchild = n.getNumChildren(); i < nchild; i++)
     {
@@ -230,19 +254,24 @@ void RelevantDomain::computeRelevantDomainNode(Node q,
   if( ( ( n.getKind()==EQUAL && !n[0].getType().isBoolean() ) || n.getKind()==GEQ ) && TermUtil::hasInstConstAttr( n ) ){
     //compute the information for what this literal does
     computeRelevantDomainLit( q, hasPol, pol, n );
-    if( d_rel_dom_lit[hasPol][pol][n].d_merge ){
-      Assert(d_rel_dom_lit[hasPol][pol][n].d_rd[0] != NULL
-             && d_rel_dom_lit[hasPol][pol][n].d_rd[1] != NULL);
-      RDomain * rd1 = d_rel_dom_lit[hasPol][pol][n].d_rd[0]->getParent();
-      RDomain * rd2 = d_rel_dom_lit[hasPol][pol][n].d_rd[1]->getParent();
+    RDomainLit& rdl = d_rel_dom_lit[hasPol][pol][n];
+    if (rdl.d_merge)
+    {
+      Assert(rdl.d_rd[0] != nullptr && rdl.d_rd[1] != nullptr);
+      RDomain* rd1 = rdl.d_rd[0]->getParent();
+      RDomain* rd2 = rdl.d_rd[1]->getParent();
       if( rd1!=rd2 ){
         rd1->merge( rd2 );
       }
-    }else{
-      if( d_rel_dom_lit[hasPol][pol][n].d_rd[0]!=NULL ){
-        RDomain * rd = d_rel_dom_lit[hasPol][pol][n].d_rd[0]->getParent();
-        for( unsigned i=0; i<d_rel_dom_lit[hasPol][pol][n].d_val.size(); i++ ){
-          rd->addTerm( d_rel_dom_lit[hasPol][pol][n].d_val[i] );
+    }
+    else
+    {
+      if (rdl.d_rd[0] != nullptr)
+      {
+        RDomain* rd = rdl.d_rd[0]->getParent();
+        for (unsigned i = 0; i < rdl.d_val.size(); i++)
+        {
+          rd->addTerm(rdl.d_val[i]);
         }
       }
     }
@@ -254,7 +283,7 @@ void RelevantDomain::computeRelevantDomainOpCh( RDomain * rf, Node n ) {
   if( n.getKind()==INST_CONSTANT ){
     Node q = TermUtil::getInstConstAttr(n);
     //merge the RDomains
-    unsigned id = n.getAttribute(InstVarNumAttribute());
+    size_t id = n.getAttribute(InstVarNumAttribute());
     Assert(q[0][id].getType() == n.getType());
     Trace("rel-dom-debug") << n << " is variable # " << id << " for " << q;
     Trace("rel-dom-debug") << " with body : " << d_qreg.getInstConstantBody(q)
@@ -272,7 +301,8 @@ void RelevantDomain::computeRelevantDomainOpCh( RDomain * rf, Node n ) {
 
 void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, Node n ) {
   if( d_rel_dom_lit[hasPol][pol].find( n )==d_rel_dom_lit[hasPol][pol].end() ){
-    d_rel_dom_lit[hasPol][pol][n].d_merge = false;
+    RDomainLit& rdl = d_rel_dom_lit[hasPol][pol][n];
+    rdl.d_merge = false;
     int varCount = 0;
     int varCh = -1;
     for( unsigned i=0; i<n.getNumChildren(); i++ ){
@@ -281,24 +311,24 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No
         // different from q
         Node qi = TermUtil::getInstConstAttr(n[i]);
         unsigned id = n[i].getAttribute(InstVarNumAttribute());
-        d_rel_dom_lit[hasPol][pol][n].d_rd[i] = getRDomain(qi, id, false);
+        rdl.d_rd[i] = getRDomain(qi, id, false);
         varCount++;
         varCh = i;
       }else{
-        d_rel_dom_lit[hasPol][pol][n].d_rd[i] = NULL;
+        rdl.d_rd[i] = nullptr;
       }
     }
     
     Node r_add;
     bool varLhs = true;
     if( varCount==2 ){
-      d_rel_dom_lit[hasPol][pol][n].d_merge = true;
+      rdl.d_merge = true;
     }else{
       if( varCount==1 ){
         r_add = n[1-varCh];
         varLhs = (varCh==0);
-        d_rel_dom_lit[hasPol][pol][n].d_rd[0] = d_rel_dom_lit[hasPol][pol][n].d_rd[varCh];
-        d_rel_dom_lit[hasPol][pol][n].d_rd[1] = NULL;
+        rdl.d_rd[0] = rdl.d_rd[varCh];
+        rdl.d_rd[1] = nullptr;
       }else{
         //solve the inequality for one/two variables, if possible
         if( n[0].getType().isReal() ){
@@ -323,7 +353,10 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No
                 hasNonVar = true;
               }
             }
+            Trace("rel-dom") << "Process lit " << n << ", var/var2=" << var
+                             << "/" << var2 << std::endl;
             if( !var.isNull() ){
+              Assert(var.hasAttribute(InstVarNumAttribute()));
               if( var2.isNull() ){
                 //single variable solve
                 Node veq_c;
@@ -334,47 +367,54 @@ void RelevantDomain::computeRelevantDomainLit( Node q, bool hasPol, bool pol, No
                   if( veq_c.isNull() ){
                     r_add = val;
                     varLhs = (ires==1);
-                    d_rel_dom_lit[hasPol][pol][n].d_rd[0] = getRDomain( q, var.getAttribute(InstVarNumAttribute()), false );
-                    d_rel_dom_lit[hasPol][pol][n].d_rd[1] = NULL;
+                    rdl.d_rd[0] = getRDomain(
+                        q, var.getAttribute(InstVarNumAttribute()), false);
+                    rdl.d_rd[1] = nullptr;
                   }
                 }
               }else if( !hasNonVar ){
+                Assert(var2.hasAttribute(InstVarNumAttribute()));
                 //merge the domains
-                d_rel_dom_lit[hasPol][pol][n].d_rd[0] = getRDomain( q, var.getAttribute(InstVarNumAttribute()), false );
-                d_rel_dom_lit[hasPol][pol][n].d_rd[1] = getRDomain( q, var2.getAttribute(InstVarNumAttribute()), false );
-                d_rel_dom_lit[hasPol][pol][n].d_merge = true;
+                rdl.d_rd[0] = getRDomain(
+                    q, var.getAttribute(InstVarNumAttribute()), false);
+                rdl.d_rd[1] = getRDomain(
+                    q, var2.getAttribute(InstVarNumAttribute()), false);
+                rdl.d_merge = true;
               }
             }
           }
         }
       }
     }
-    if( d_rel_dom_lit[hasPol][pol][n].d_merge ){
+    if (rdl.d_merge)
+    {
       //do not merge if constant negative polarity
       if( hasPol && !pol ){
-        d_rel_dom_lit[hasPol][pol][n].d_merge = false;
+        rdl.d_merge = false;
       }
-    }else if( !r_add.isNull() && !TermUtil::hasInstConstAttr( r_add ) ){
+    }
+    else if (!r_add.isNull() && !TermUtil::hasInstConstAttr(r_add))
+    {
       Trace("rel-dom-debug") << "...add term " << r_add << ", pol = " << pol << ", kind = " << n.getKind() << std::endl;
       //the negative occurrence adds the term to the domain
       if( !hasPol || !pol ){
-        d_rel_dom_lit[hasPol][pol][n].d_val.push_back( r_add );
+        rdl.d_val.push_back(r_add);
       }
       //the positive occurence adds other terms
       if( ( !hasPol || pol ) && n[0].getType().isInteger() ){
         if( n.getKind()==EQUAL ){
           for( unsigned i=0; i<2; i++ ){
-            d_rel_dom_lit[hasPol][pol][n].d_val.push_back(
-                ArithMSum::offset(r_add, i == 0 ? 1 : -1));
+            rdl.d_val.push_back(ArithMSum::offset(r_add, i == 0 ? 1 : -1));
           }
         }else if( n.getKind()==GEQ ){
-          d_rel_dom_lit[hasPol][pol][n].d_val.push_back(
-              ArithMSum::offset(r_add, varLhs ? 1 : -1));
+          rdl.d_val.push_back(ArithMSum::offset(r_add, varLhs ? 1 : -1));
         }
       }
-    }else{
-      d_rel_dom_lit[hasPol][pol][n].d_rd[0] = NULL;
-      d_rel_dom_lit[hasPol][pol][n].d_rd[1] = NULL;
+    }
+    else
+    {
+      rdl.d_rd[0] = nullptr;
+      rdl.d_rd[1] = nullptr;
     }
   }
 }
index 3b44b226301c56062b3174d8750d50d5901df6e0..ab1a0e0641f5c5de184665ae0c71dfcb1e22804a 100644 (file)
@@ -108,21 +108,13 @@ class RelevantDomain : public QuantifiersUtil
    * of the equivalence class of relevant domain objects,
    * which is computed as a union find (see RDomain::d_parent).
    */
-  RDomain* getRDomain(Node n, int i, bool getParent = true);
+  RDomain* getRDomain(Node n, size_t i, bool getParent = true);
 
  private:
   /** the relevant domains for each quantified formula and function,
    * for each variable # and argument #.
    */
-  std::map< Node, std::map< int, RDomain * > > d_rel_doms;
-  /** stores the function or quantified formula associated with
-   * each relevant domain object.
-   */
-  std::map< RDomain *, Node > d_rn_map;
-  /** stores the argument or variable number associated with
-   * each relevant domain object.
-   */
-  std::map< RDomain *, int > d_ri_map;
+  std::map<Node, std::map<size_t, RDomain*> > d_rel_doms;
   /** Reference to the quantifiers state object */
   QuantifiersState& d_qs;
   /** Reference to the quantifiers registry */
index 7126d35670caf4c8899b66ef3753a82fdc4d8369..6644f2b2787817a63243aea64fbc6ec48a27a4ac 100644 (file)
@@ -178,7 +178,7 @@ Node TermDb::getMatchOperator( Node n ) {
   if (k == SELECT || k == STORE || k == UNION || k == INTERSECTION
       || k == SUBSET || k == SETMINUS || k == MEMBER || k == SINGLETON
       || k == APPLY_SELECTOR_TOTAL || k == APPLY_SELECTOR || k == APPLY_TESTER
-      || k == SEP_PTO || k == HO_APPLY || k == SEQ_NTH)
+      || k == SEP_PTO || k == HO_APPLY || k == SEQ_NTH || k == STRING_LENGTH)
   {
     //since it is parametric, use a particular one as op
     TypeNode tn = n[0].getType();
index f505d27741244845bc7ee996b6e153fb52acac5b..6982fc806f3cb3134166bed00810584f03247886 100644 (file)
@@ -292,9 +292,10 @@ void TermTupleEnumeratorBase::next(/*out*/ std::vector<Node>& terms)
                        : getTerm(variableIx, d_termIndex[variableIx]);
     terms[variableIx] = t;
     Trace("inst-alg-rd") << t << "  ";
-    Assert(terms[variableIx].isNull()
-           || terms[variableIx].getType().isComparableTo(
-               d_quantifier[0][variableIx].getType()));
+    Assert(t.isNull()
+           || t.getType().isComparableTo(d_quantifier[0][variableIx].getType()))
+        << "Bad type: " << t << " " << t.getType() << " "
+        << d_quantifier[0][variableIx].getType();
   }
   Trace("inst-alg-rd") << std::endl;
 }