Optimizations for datatypes theory. There seems to be a bug in trans_closure, curren...
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 8 Oct 2013 14:01:13 +0000 (09:01 -0500)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 8 Oct 2013 14:01:13 +0000 (09:01 -0500)
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h
src/util/datatype.cpp
src/util/datatype.h
src/util/trans_closure.cpp
src/util/trans_closure.h

index c827a8f07dd229a196cefed3e4a370720a336bc7..fb0ac59235008af9335eaf346dcce34940c5bc9f 100644 (file)
@@ -118,6 +118,12 @@ void TheoryDatatypes::check(Effort e) {
   }
 
   if( e == EFFORT_FULL ) {
+    //check for cycles
+    checkCycles();
+    if( d_conflict ){
+      return;
+    }
+    //check for splits
     Debug("datatypes-split") << "Check for splits " << e << endl;
     bool addedFact = false;
     do {
@@ -159,20 +165,19 @@ void TheoryDatatypes::check(Effort e) {
                   }
                 }
               }
+              /*
               if( !needSplit && mustSpecifyAssignment() ){
-                // &&
                 //for the sake of termination, we must choose the constructor of a ground term
                 //NEED GUARENTEE: groundTerm should not contain any subterms of the same type
-                //** TODO: this is probably not good enough, actually need fair enumeration strategy
-                /*
+                // TODO: this is probably not good enough, actually need fair enumeration strategy
                 Node groundTerm = n.getType().mkGroundTerm();
                 int index = Datatype::indexOf( groundTerm.getOperator().toExpr() );
                 if( pcons[index] ){
                   consIndex = index;
                 }
                 needSplit = true;
-                */
               }
+              */
               if( needSplit && consIndex!=-1 ) {
                 Node test = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( dt[consIndex].getTester() ), n );
                 Trace("dt-split") << "*************Split for possible constructor " << dt[consIndex] << " for " << n <<  endl;
@@ -446,10 +451,16 @@ void TheoryDatatypes::explain(TNode literal, std::vector<TNode>& assumptions){
   Debug("datatypes-explain") << "Explain " << literal << std::endl;
   bool polarity = literal.getKind() != kind::NOT;
   TNode atom = polarity ? literal : literal[0];
+  std::vector<TNode> tassumptions;
   if (atom.getKind() == kind::EQUAL || atom.getKind() == kind::IFF) {
-    d_equalityEngine.explainEquality(atom[0], atom[1], polarity, assumptions);
+    d_equalityEngine.explainEquality(atom[0], atom[1], polarity, tassumptions);
   } else {
-    d_equalityEngine.explainPredicate(atom, polarity, assumptions);
+    d_equalityEngine.explainPredicate(atom, polarity, tassumptions);
+  }
+  for( unsigned i=0; i<tassumptions.size(); i++ ){
+    if( std::find( assumptions.begin(), assumptions.end(), tassumptions[i] )==assumptions.end() ){
+      assumptions.push_back( tassumptions[i] );
+    }
   }
 }
 
@@ -472,7 +483,7 @@ void TheoryDatatypes::conflict(TNode a, TNode b){
   } else {
     d_conflictNode = explain( a.eqNode(b) );
   }
-  Debug("datatypes-conflict") << "CONFLICT: Eq engine conflict : " << d_conflictNode << std::endl;
+  Trace("dt-conflict") << "CONFLICT: Eq engine conflict : " << d_conflictNode << std::endl;
   d_out->conflict( d_conflictNode );
   d_conflict = true;
 }
@@ -523,7 +534,7 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){
           if( cons1.getOperator()!=cons2.getOperator() ){
             //check for clash
             d_conflictNode = explain( cons1.eqNode( cons2 ) );
-            Debug("datatypes-conflict") << "CONFLICT: Clash conflict : " << d_conflictNode << std::endl;
+            Trace("dt-conflict") << "CONFLICT: Clash conflict : " << d_conflictNode << std::endl;
             d_out->conflict( d_conflictNode );
             d_conflict = true;
             return;
@@ -559,8 +570,7 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){
                   explain( *i, assumptions );
                   explain( cons2.eqNode( (*i)[0][0] ), assumptions );
                   d_conflictNode = NodeManager::currentNM()->mkNode( AND, assumptions );
-                  Debug("datatypes-conflict") << "CONFLICT: Tester merge eq conflict : " << d_conflictNode << std::endl;
-                  Debug("datatypes-conflict-temp") << "CONFLICT: Tester merge eq conflict : " << d_conflictNode << std::endl;
+                  Trace("dt-conflict") << "CONFLICT: Tester merge eq conflict : " << d_conflictNode << std::endl;
                   d_out->conflict( d_conflictNode );
                   d_conflict = true;
                   return;
@@ -615,12 +625,16 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){
       newRep = trep2;
     }
     bool result = d_cycle_check.addEdgeNode( oldRep, newRep );
-    d_hasSeenCycle.set( d_hasSeenCycle.get() || result );
-    Debug("datatypes-cycles") << "DtCyc: Equal " << oldRep << " -> " << newRep << " " << d_hasSeenCycle.get() << endl;
-    if( d_hasSeenCycle.get() ){
+    //d_hasSeenCycle.set( d_hasSeenCycle.get() || result );
+    Debug("datatypes-cycles") << "DtCyc: Equal " << oldRep << " -> " << newRep << " " << result << " " << d_hasSeenCycle.get() << endl;
+    if( result ){
       checkCycles();
       if( d_conflict ){
+        Debug("datatypes-cycles-find") << "Cycle found." << std::endl;
         return;
+      }else{
+        Debug("datatypes-cycles-find") << "WARNING : no cycle found." << std::endl;
+        d_cycle_check.debugPrint();
       }
     }
   }
@@ -694,7 +708,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){
           explain( t, assumptions );
           explain( eqc->d_constructor.get().eqNode( tt[0] ), assumptions );
           d_conflictNode = NodeManager::currentNM()->mkNode( AND, assumptions );
-          Debug("datatypes-conflict") << "CONFLICT: Tester eq conflict : " << d_conflictNode << std::endl;
+          Trace("dt-conflict") << "CONFLICT: Tester eq conflict : " << d_conflictNode << std::endl;
           d_out->conflict( d_conflictNode );
           return;
         }else{
@@ -775,7 +789,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){
       explain( t, assumptions );
       explain( jt[0].eqNode( tt[0] ), assumptions );
       d_conflictNode = NodeManager::currentNM()->mkNode( AND, assumptions );
-      Debug("datatypes-conflict") << "CONFLICT: Tester conflict : " << d_conflictNode << std::endl;
+      Trace("dt-conflict") << "CONFLICT: Tester conflict : " << d_conflictNode << std::endl;
       d_out->conflict( d_conflictNode );
     }
   }
@@ -912,8 +926,11 @@ void TheoryDatatypes::collectTerms( Node n ) {
     if( n.getKind() == APPLY_CONSTRUCTOR ){
       //we must take into account subterm relation when checking for cycles
       for( int i=0; i<(int)n.getNumChildren(); i++ ) {
-        Debug("datatypes-cycles") << "DtCyc: Subterm " << n << " -> " << n[i] << endl;
         bool result = d_cycle_check.addEdgeNode( n, n[i] );
+        Debug("datatypes-cycles") << "DtCyc: Subterm " << n << " -> " << n[i] << " " << result << endl;
+        if( result && !d_hasSeenCycle.get() ){
+          Debug("datatypes-cycles") << "FOUND CYCLE" << std::endl;
+        }
         d_hasSeenCycle.set( d_hasSeenCycle.get() || result );
       }
     }else if( n.getKind() == APPLY_SELECTOR ){
@@ -941,7 +958,7 @@ void TheoryDatatypes::processNewTerm( Node n ){
   Trace("dt-terms") << "Created term : " << n << std::endl;
   //see if it is rewritten to be something different
   Node rn = Rewriter::rewrite( n );
-  if( rn!=n ){
+  if( rn!=n && !areEqual( rn, n ) ){
     Node eq = n.getType().isBoolean() ? rn.iffNode( n ) : rn.eqNode( n );
     d_pending.push_back( eq );
     d_pending_exp[ eq ] = NodeManager::currentNM()->mkConst( true );
@@ -1059,21 +1076,21 @@ void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){
     int index = getLabelIndex( eqc, n );
     const Datatype& dt = ((DatatypeType)(tt.getType()).toType()).getDatatype();
     //must be finite or have a selector
-    //if( eqc->d_selectors || dt[ index ].isFinite() || mustSpecifyAssignment() ){
-      //instantiate this equivalence class
-      eqc->d_inst = true;
-      Node tt_cons = getInstantiateCons( tt, dt, index );
-      Node eq;
-      if( tt!=tt_cons ){
-        eq = tt.eqNode( tt_cons );
-        Debug("datatypes-inst") << "DtInstantiate : " << eqc << " " << eq << std::endl;
-        d_pending.push_back( eq );
-        d_pending_exp[ eq ] = exp;
-        Trace("datatypes-infer") << "DtInfer : " << eq << " by " << exp << std::endl;
-        //eqc->d_inst.set( eq );
-        d_infer.push_back( eq );
-        d_infer_exp.push_back( exp );
-      }
+    //if( eqc->d_selectors || dt[ index ].isFinite() ){ // || mustSpecifyAssignment()
+    //instantiate this equivalence class
+    eqc->d_inst = true;
+    Node tt_cons = getInstantiateCons( tt, dt, index );
+    Node eq;
+    if( tt!=tt_cons ){
+      eq = tt.eqNode( tt_cons );
+      Debug("datatypes-inst") << "DtInstantiate : " << eqc << " " << eq << std::endl;
+      d_pending.push_back( eq );
+      d_pending_exp[ eq ] = exp;
+      Trace("datatypes-infer") << "DtInfer : " << eq << " by " << exp << std::endl;
+      //eqc->d_inst.set( eq );
+      d_infer.push_back( eq );
+      d_infer_exp.push_back( exp );
+    }
     //}
     //else{
     //  Debug("datatypes-inst") << "Do not instantiate" << std::endl;
@@ -1086,26 +1103,38 @@ void TheoryDatatypes::checkCycles() {
   eq::EqClassesIterator eqcs_i = eq::EqClassesIterator( &d_equalityEngine );
   while( !eqcs_i.isFinished() ){
     Node eqc = (*eqcs_i);
-    map< Node, bool > visited;
-    std::vector< TNode > expl;
-    if( searchForCycle( eqc, eqc, visited, expl ) ) {
-      Assert( expl.size()>0 );
-      if( expl.size() == 1 ){
-        d_conflictNode = expl[ 0 ];
-      }else{
-        d_conflictNode = NodeManager::currentNM()->mkNode( AND, expl );
+    if( eqc.getType().isDatatype() ) {
+      map< Node, bool > visited;
+      std::vector< TNode > expl;
+      Node cn = searchForCycle( eqc, eqc, visited, expl );
+      //if we discovered a different cycle while searching this one
+      if( !cn.isNull() && cn!=eqc ){
+        visited.clear();
+        expl.clear();
+        Node prev = cn;
+        cn = searchForCycle( cn, cn, visited, expl );
+        Assert( prev==cn );
+      }
+
+      if( !cn.isNull() ) {
+        Assert( expl.size()>0 );
+        if( expl.size() == 1 ){
+          d_conflictNode = expl[ 0 ];
+        }else{
+          d_conflictNode = NodeManager::currentNM()->mkNode( AND, expl );
+        }
+        Trace("dt-conflict") << "CONFLICT: Cycle conflict : " << d_conflictNode << std::endl;
+        d_out->conflict( d_conflictNode );
+        d_conflict = true;
+        return;
       }
-      Debug("datatypes-conflict") << "CONFLICT: Cycle conflict : " << d_conflictNode << std::endl;
-      d_out->conflict( d_conflictNode );
-      d_conflict = true;
-      return;
     }
     ++eqcs_i;
   }
 }
 
 //postcondition: if cycle detected, explanation is why n is a subterm of on
-bool TheoryDatatypes::searchForCycle( Node n, Node on,
+Node TheoryDatatypes::searchForCycle( Node n, Node on,
                                       map< Node, bool >& visited,
                                       std::vector< TNode >& explanation, bool firstTime ) {
   Debug("datatypes-cycle-check") << "Search for cycle " << n << " " << on << endl;
@@ -1116,19 +1145,20 @@ bool TheoryDatatypes::searchForCycle( Node n, Node on,
     if( nn==on ){
       Node lit = NodeManager::currentNM()->mkNode( EQUAL, n, nn );
       explain( lit, explanation );
-      return true;
+      return on;
     }
   }else{
     nn = n;
   }
   if( visited.find( nn ) == visited.end() ) {
     visited[nn] = true;
-    EqcInfo* eqc = getOrMakeEqcInfo( nn );
+    EqcInfo* eqc = getOrMakeEqcInfo( nn, false );
     if( eqc ){
       Node ncons = eqc->d_constructor.get();
       if( !ncons.isNull() ) {
         for( int i=0; i<(int)ncons.getNumChildren(); i++ ) {
-          if( searchForCycle( ncons[i], on, visited, explanation, false ) ) {
+          Node cn = searchForCycle( ncons[i], on, visited, explanation, false );
+          if( cn==on ) {
             if( Debug.isOn("datatypes-cycles") && !d_cycle_check.isConnectedNode( n, ncons[i] ) ){
               Debug("datatypes-cycles") << "Cycle subterm: " << n << " is not -> " << ncons[i] << "!!!!" << std::endl;
             }
@@ -1137,13 +1167,18 @@ bool TheoryDatatypes::searchForCycle( Node n, Node on,
               Node lit = NodeManager::currentNM()->mkNode( EQUAL, n, ncons );
               explain( lit, explanation );
             }
-            return true;
+            return on;
+          }else if( !cn.isNull() ){
+            return cn;
           }
         }
       }
     }
+    visited.erase( nn );
+    return Node::null();
+  }else{
+    return nn;
   }
-  return false;
 }
 
 bool TheoryDatatypes::mustSpecifyAssignment(){
index 203782a793e8d3f2cfe1267cd56feea6cc58c7ef..225139b2d51cbc946ecaf32709565a88c57b97b4 100644 (file)
@@ -219,7 +219,7 @@ private:
   void merge( Node t1, Node t2 );
   /** for checking if cycles exist */
   void checkCycles();
-  bool searchForCycle( Node n, Node on,
+  Node searchForCycle( Node n, Node on,
                        std::map< Node, bool >& visited,
                        std::vector< TNode >& explanation, bool firstTime = true );
   /** collect terms */
index 96e8692f5f3762adecad1755ffe215e4cb3907a3..8db91da69c466573568891bc306e1a7a3e615187 100644 (file)
@@ -107,6 +107,7 @@ void Datatype::resolve(ExprManager* em,
     Node::fromExpr((*i).d_tester).setAttribute(DatatypeIndexAttr(), index++);
   }
   d_self = self;
+  //d_card = getCardinality();
 }
 
 void Datatype::addConstructor(const DatatypeConstructor& c) {
@@ -125,6 +126,10 @@ Cardinality Datatype::getCardinality() const throw(IllegalArgumentException) {
   for(const_iterator i = begin(), i_end = end(); i != i_end; ++i) {
     c += (*i).getCardinality();
   }
+  //if( d_card!=c ){
+    //std::cout << "Bad card " << std::endl;
+  //}
+
   return c;
 }
 
index c46c10c973c88dfe74471754170094a9d4d0a1a6..01558fd30f73209121e36aa389971545d301f43a 100644 (file)
@@ -404,6 +404,7 @@ private:
   std::vector<DatatypeConstructor> d_constructors;
   bool d_resolved;
   Type d_self;
+  Cardinality d_card;
 
   /**
    * Datatypes refer to themselves, recursively, and we have a
@@ -616,7 +617,8 @@ inline Datatype::Datatype(std::string name) :
   d_params(),
   d_constructors(),
   d_resolved(false),
-  d_self() {
+  d_self(),
+  d_card(1) {
 }
 
 inline Datatype::Datatype(std::string name, const std::vector<Type>& params) :
@@ -624,7 +626,8 @@ inline Datatype::Datatype(std::string name, const std::vector<Type>& params) :
   d_params(params),
   d_constructors(),
   d_resolved(false),
-  d_self() {
+  d_self(),
+  d_card(1) {
 }
 
 inline std::string Datatype::getName() const throw() {
index 90b069485f8e69f8a3d930d3a71269a4c9943ec1..970d2542ee6fc2baa6366c2cbbada18858d47056 100644 (file)
@@ -37,16 +37,22 @@ TransitiveClosure::~TransitiveClosure() {
 
 bool TransitiveClosure::addEdge(unsigned i, unsigned j)
 {
+  Debug("trans-closure") << "Add edge " << i << " -> " << j << std::endl;
   // Check for loops
   Assert(i != j, "Cannot add self-loop");
-  if (adjMatrix.size() > j && adjMatrix[j] != NULL && adjMatrix[j]->read(i)) {
+  if (adjIndex.get() > j && adjMatrix[j] != NULL && adjMatrix[j]->read(i)) {
     return true;
   }
 
   // Grow matrix if necessary
   unsigned maxSize = ((i > j) ? i : j) + 1;
-  while (maxSize > adjMatrix.size()) {
-    adjMatrix.push_back(NULL);
+  while (maxSize > adjIndex.get()) {
+    if( maxSize > adjMatrix.size() ){
+      adjMatrix.push_back(NULL);
+    }else if( adjMatrix[adjIndex.get()]!=NULL ){
+      adjMatrix[adjIndex.get()]->clear();
+    }
+    adjIndex.set( adjIndex.get() + 1 );
   }
 
   // Add edge from i to j and everything j can reach
@@ -60,7 +66,7 @@ bool TransitiveClosure::addEdge(unsigned i, unsigned j)
 
   // Add edges from everything that can reach i to j and everything that j can reach
   unsigned k;
-  for (k = 0; k < adjMatrix.size(); ++k) {
+  for (k = 0; k < adjIndex.get(); ++k) {
     if (adjMatrix[k] != NULL && adjMatrix[k]->read(i)) {
       adjMatrix[k]->write(j);
       if (adjMatrix[j] != NULL) {
@@ -74,7 +80,7 @@ bool TransitiveClosure::addEdge(unsigned i, unsigned j)
 
 bool TransitiveClosure::isConnected(unsigned i, unsigned j)
 {
-  if( i>=adjMatrix.size() || j>adjMatrix.size() ){
+  if( i>=adjIndex.get() || j>adjIndex.get() ){//adjMatrix.size() ){
     return false;
   }else{
     return adjMatrix[i] != NULL && adjMatrix[i]->read(j);
@@ -84,15 +90,15 @@ bool TransitiveClosure::isConnected(unsigned i, unsigned j)
 void TransitiveClosure::debugPrintMatrix()
 {
   unsigned i,j;
-  for (i = 0; i < adjMatrix.size(); ++i) {
-    for (j = 0; j < adjMatrix.size(); ++j) {
+  for (i = 0; i < adjIndex.get(); ++i) {
+    for (j = 0; j < adjIndex.get(); ++j) {
       if (adjMatrix[i] != NULL && adjMatrix[i]->read(j)) {
         Debug("trans-closure") << "1 ";
       }
       else Debug("trans-closure") << "0 ";
     }
     Debug("trans-closure") << endl;
-  }      
+  }
 }
 
 unsigned TransitiveClosureNode::getId( Node i ){
@@ -108,10 +114,14 @@ unsigned TransitiveClosureNode::getId( Node i ){
 
 void TransitiveClosureNode::debugPrint(){
   for( int i=0; i<(int)currEdges.size(); i++ ){
-    Debug("trans-closure") << "currEdges[ " << i << " ] = " 
+    Debug("trans-closure") << "currEdges[ " << i << " ] = "
                            << currEdges[i].first << " -> " << currEdges[i].second;
+    int id1 = getId( currEdges[i].first );
+    int id2 = getId( currEdges[i].second );
+    Debug("trans-closure") << " { " << id1 << " -> " << id2 << " } ";
     Debug("trans-closure") << std::endl;
   }
+  debugPrintMatrix();
 }
 
 
index ce846637d6b798d29f20163921267a0097536180..14e7ab95f52126edf39b12d72fd9bebd02e881fb 100644 (file)
@@ -57,17 +57,17 @@ protected:
   CDBV* next() { return d_next; }
 
 public:
-  CDBV(context::Context* context) : 
+  CDBV(context::Context* context) :
     ContextObj(context), d_data(0), d_next(NULL)
   {}
 
-  ~CDBV() { 
+  ~CDBV() {
     if (d_next != NULL) {
       d_next->deleteSelf();
     }
     destroy();
   }
-
+  void clear() { d_data = 0; if( d_next ) d_next->clear(); }
   bool read(unsigned index) {
     if (index < 64) return (d_data & (uint64_t(1) << index)) != 0;
     else if (d_next == NULL) return false;
@@ -88,7 +88,7 @@ public:
       makeCurrent();
       d_next = new(true) CDBV(getContext());
       d_next->write(index - 64);
-    }      
+    }
   }
 
   void merge(CDBV* pcdbv) {
@@ -108,9 +108,10 @@ public:
 class TransitiveClosure {
   context::Context* d_context;
   std::vector<CDBV* > adjMatrix;
+  context::CDO<unsigned> adjIndex;
 
 public:
-  TransitiveClosure(context::Context* context) : d_context(context) {}
+  TransitiveClosure(context::Context* context) : d_context(context), adjIndex(context,0) {}
   virtual ~TransitiveClosure();
 
   /* Add an edge from node i to node j.  Return false if successful, true if this edge would create a cycle */
@@ -130,7 +131,7 @@ class TransitiveClosureNode : public TransitiveClosure{
   //for debugging
   context::CDList< std::pair< Node, Node > > currEdges;
 public:
-  TransitiveClosureNode(context::Context* context) : 
+  TransitiveClosureNode(context::Context* context) :
     TransitiveClosure(context), d_counter( context, 0 ), nodeMap( context ), currEdges(context) {}
   ~TransitiveClosureNode(){}