Compute care graph for datatypes. Preliminary results show 20x speed up on larger...
authorAndrew Reynolds <reynolds@larapc05.epfl.ch>
Thu, 24 Apr 2014 12:27:38 +0000 (14:27 +0200)
committerAndrew Reynolds <reynolds@larapc05.epfl.ch>
Thu, 24 Apr 2014 12:27:38 +0000 (14:27 +0200)
src/theory/datatypes/datatypes_rewriter.h
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h

index 75d1f2b2e4f943425f114f517169fcfbb8915122..dc85d0cd672fef20b2573980455afd84bff030bc 100644 (file)
@@ -211,6 +211,8 @@ public:
       }else if( rew.size()==1 && rew[0]!=in ){
         Trace("datatypes-rewrite") << "Rewrite equality " << in << " to " << rew[0] << std::endl;
         return RewriteResponse(REWRITE_AGAIN_FULL, rew[0] );
+      }else{
+        Trace("datatypes-rewrite-debug") << "Did not rewrite equality " << in << " " << in[0].getKind() << " " << in[1].getKind() << std::endl;
       }
     }
 
@@ -240,8 +242,12 @@ public:
         }
       }
     }else if( n1!=n2 ){
-      Node eq = NodeManager::currentNM()->mkNode( n1.getType().isBoolean() ? kind::IFF : kind::EQUAL, n1, n2 );
-      rew.push_back( eq );
+      if( n1.isConst() && n2.isConst() ){
+        return true;        
+      }else{
+        Node eq = NodeManager::currentNM()->mkNode( n1.getType().isBoolean() ? kind::IFF : kind::EQUAL, n1, n2 );
+        rew.push_back( eq );
+      }
     }
     return false;
   }
index 2fb1a2679fed82374ea464474df6d794b96bf45f..e706d38464896ff86fda1b73cba1356d6827c5e1 100644 (file)
@@ -50,7 +50,9 @@ TheoryDatatypes::TheoryDatatypes(Context* c, UserContext* u, OutputChannel& out,
   d_labels( c ),
   d_selector_apps( c ),
   d_conflict( c, false ),
-  d_collectTermsCache( c ){
+  d_collectTermsCache( c ),
+  d_consTerms( c ),
+  d_selTerms( c ){
 
   // The kinds we are treating as function application in congruence
   d_equalityEngine.addFunctionKind(kind::APPLY_CONSTRUCTOR);
@@ -963,7 +965,45 @@ EqualityStatus TheoryDatatypes::getEqualityStatus(TNode a, TNode b){
 }
 
 void TheoryDatatypes::computeCareGraph(){
-  Theory::computeCareGraph();
+  Trace("ajr-temp") << "Compute graph for dt..." << std::endl;
+  vector< pair<TNode, TNode> > currentPairs;
+  for( unsigned r=0; r<2; r++ ){
+    unsigned functionTerms = r==0 ? d_consTerms.size() : d_selTerms.size();
+    for( unsigned i=0; i<functionTerms; i++ ){
+      TNode f1 = r==0 ? d_consTerms[i] : d_selTerms[i];
+      for( unsigned j=i+1; j<functionTerms; j++ ){
+        TNode f2 = r==0 ? d_consTerms[j] : d_selTerms[j];
+        if( f1.getOperator()==f2.getOperator() && !areEqual( f1, f2 ) ){
+          Trace("ajr-temp") << "Check " << f1 << " and " << f2 << std::endl;
+          bool somePairIsDisequal = false;
+          currentPairs.clear();
+          for (unsigned k = 0; k < f1.getNumChildren(); ++ k) {
+            TNode x = f1[k];
+            TNode y = f2[k];
+            if( areDisequal(x, y) ){
+              somePairIsDisequal = true;
+              break;
+            }else if( !areEqual( x, y ) &&
+                      d_equalityEngine.isTriggerTerm(x, THEORY_UF) && 
+                      d_equalityEngine.isTriggerTerm(y, THEORY_UF) ){
+              Trace("ajr-temp") << "Arg #" << k << " is " << x << " " << y << std::endl;
+              TNode x_shared = d_equalityEngine.getTriggerTermRepresentative(x, THEORY_DATATYPES);
+              TNode y_shared = d_equalityEngine.getTriggerTermRepresentative(y, THEORY_DATATYPES);
+              Trace("ajr-temp") << "Arg #" << k << " shared term is " << x_shared << " " << y_shared << std::endl;
+              currentPairs.push_back(make_pair(x_shared, y_shared));
+            }
+          } 
+          if (!somePairIsDisequal) {
+            for (unsigned c = 0; c < currentPairs.size(); ++ c) {
+              addCarePair(currentPairs[c].first, currentPairs[c].second);
+            }
+          }
+        }
+      }
+    }
+  }
+  Trace("ajr-temp") << "Done Compute graph for dt." << std::endl;
+  //Theory::computeCareGraph();
 }
 
 void TheoryDatatypes::collectModelInfo( TheoryModel* m, bool fullModel ){
@@ -1143,6 +1183,7 @@ void TheoryDatatypes::collectTerms( Node n ) {
       collectTerms( n[i] );
     }
     if( n.getKind() == APPLY_CONSTRUCTOR ){
+      d_consTerms.push_back( n );
       /*
       //we must take into account subterm relation when checking for cycles
       for( int i=0; i<(int)n.getNumChildren(); i++ ) {
@@ -1158,6 +1199,7 @@ void TheoryDatatypes::collectTerms( Node n ) {
       }
       */
     }else if( n.getKind() == APPLY_SELECTOR_TOTAL ){
+      d_selTerms.push_back( n );
       //we must also record which selectors exist
       Debug("datatypes") << "  Found selector " << n << endl;
       if (n.getType().isBoolean()) {
@@ -1545,11 +1587,11 @@ bool TheoryDatatypes::mustCommunicateFact( Node n, Node exp ){
   return false;
 }
 
-bool TheoryDatatypes::hasTerm( Node a ){
+bool TheoryDatatypes::hasTerm( TNode a ){
   return d_equalityEngine.hasTerm( a );
 }
 
-bool TheoryDatatypes::areEqual( Node a, Node b ){
+bool TheoryDatatypes::areEqual( TNode a, TNode b ){
   if( a==b ){
     return true;
   }else if( hasTerm( a ) && hasTerm( b ) ){
@@ -1559,7 +1601,7 @@ bool TheoryDatatypes::areEqual( Node a, Node b ){
   }
 }
 
-bool TheoryDatatypes::areDisequal( Node a, Node b ){
+bool TheoryDatatypes::areDisequal( TNode a, TNode b ){
   if( a==b ){
     return false;
   }else if( hasTerm( a ) && hasTerm( b ) ){
@@ -1569,7 +1611,7 @@ bool TheoryDatatypes::areDisequal( Node a, Node b ){
   }
 }
 
-Node TheoryDatatypes::getRepresentative( Node a ){
+Node TheoryDatatypes::getRepresentative( TNode a ){
   if( hasTerm( a ) ){
     return d_equalityEngine.getRepresentative( a );
   }else{
index 307e90e9160e1806dbb09da1f20a71d559446cb9..eb86c3f761a7573168211afe5cda012fb33cbc41 100644 (file)
@@ -168,6 +168,10 @@ private:
   std::vector< Node > d_pending_merge;
   /** expand definition skolem functions */
   std::map< Node, Node > d_exp_def_skolem;
+  /** All the constructor terms that the theory has seen */
+  context::CDList<TNode> d_consTerms;
+  /** All the selector terms that the theory has seen */
+  context::CDList<TNode> d_selTerms;
 private:
   /** assert fact */
   void assertFact( Node fact, Node exp );
@@ -261,10 +265,10 @@ private:
   bool mustCommunicateFact( Node n, Node exp );
 private:
   //equality queries
-  bool hasTerm( Node a );
-  bool areEqual( Node a, Node b );
-  bool areDisequal( Node a, Node b );
-  Node getRepresentative( Node a );
+  bool hasTerm( TNode a );
+  bool areEqual( TNode a, TNode b );
+  bool areDisequal( TNode a, TNode b );
+  Node getRepresentative( TNode a );
 public:
   /** get equality engine */
   eq::EqualityEngine* getEqualityEngine() { return &d_equalityEngine; }