From 40449557777db2d1170cb86274f83b431b5fef04 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 28 Apr 2011 23:32:16 +0000 Subject: [PATCH] more fixes/improvements to datatypes theory and transitive closure --- src/theory/datatypes/datatypes_rewriter.h | 3 +- src/theory/datatypes/theory_datatypes.cpp | 602 ++++++++++------------ src/theory/datatypes/theory_datatypes.h | 43 +- src/util/datatype.cpp | 10 + src/util/datatype.h | 8 + src/util/trans_closure.cpp | 30 +- src/util/trans_closure.h | 25 +- 7 files changed, 344 insertions(+), 377 deletions(-) diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index 9bfbaf12e..eea15d6b6 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -22,7 +22,6 @@ #define __CVC4__THEORY__DATATYPES__DATATYPES_REWRITER_H #include "theory/rewriter.h" -#include "theory/datatypes/theory_datatypes.h" namespace CVC4 { namespace theory { @@ -37,7 +36,7 @@ public: if(in.getKind() == kind::APPLY_TESTER) { if(in[0].getKind() == kind::APPLY_CONSTRUCTOR) { - bool result = TheoryDatatypes::checkTrivialTester(in); + bool result = Datatype::indexOf(in.getOperator().toExpr()) == Datatype::indexOf(in[0].getOperator().toExpr()); Debug("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " << "Rewrite trivial tester " << in << " " << result << std::endl; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index 9bc195aed..88667de8d 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -25,6 +25,7 @@ #include +//#define USE_TRANSITIVE_CLOSURE using namespace std; using namespace CVC4; @@ -33,23 +34,31 @@ using namespace CVC4::context; using namespace CVC4::theory; using namespace CVC4::theory::datatypes; -bool TheoryDatatypes::isConstructorFinite( Node cons ){ +const Datatype::Constructor& TheoryDatatypes::getConstructor( Node cons ) +{ Expr consExpr = cons.toExpr(); - size_t consIndex = Datatype::indexOf(consExpr); - const Datatype& dt = Datatype::datatypeOf(consExpr); - const Datatype::Constructor& c = dt[consIndex]; - Debug("datatypes-fin-check") << cons << " is "; - if( !c.isFinite() ){ - Debug("datatypes-fin-check") << "not "; + return Datatype::datatypeOf(consExpr)[ Datatype::indexOf(consExpr) ]; +} + +Node TheoryDatatypes::getConstructorForSelector( Node sel ) +{ + size_t selIndex = Datatype::indexOf( sel.toExpr() ); + const Datatype& dt = ((DatatypeType)((sel.getType()[0]).toType())).getDatatype(); + for( unsigned i = 0; iselIndex && + Node::fromExpr( dt[i][selIndex].getSelector() )==sel ){ + return Node::fromExpr( dt[i].getConstructor() ); + } } - Debug("datatypes-fin-check") << "finite." << std::endl; - return c.isFinite(); + Assert( false ); + return Node::null(); } + TheoryDatatypes::TheoryDatatypes(Context* c, OutputChannel& out, Valuation valuation) : Theory(THEORY_DATATYPES, c, out, valuation), - //d_currAsserts(c), - //d_currEqualities(c), + d_currAsserts(c), + d_currEqualities(c), d_drv_map(c), d_axioms(c), d_selectors(c), @@ -58,6 +67,7 @@ TheoryDatatypes::TheoryDatatypes(Context* c, OutputChannel& out, Valuation valua d_equivalence_class(c), d_inst_map(c), d_cycle_check(c), + d_hasSeenCycle(c, false), d_labels(c), d_ccChannel(this), d_cc(c, &d_ccChannel), @@ -65,42 +75,20 @@ TheoryDatatypes::TheoryDatatypes(Context* c, OutputChannel& out, Valuation valua d_disequalities(c), d_equalities(c), d_conflict(), - d_noMerge(false) { - + d_noMerge(false), + d_inCheck(false){ + + ////bug test for transitive closure + //TransitiveClosure tc( c ); + //Debug("datatypes-tc") << "1 -> 0 : " << tc.addEdge( 1, 0 ) << std::endl; + //Debug("datatypes-tc") << "32 -> 1 : " << tc.addEdge( 32, 1 ) << std::endl; + //tc.debugPrintMatrix(); } TheoryDatatypes::~TheoryDatatypes() { } - -void TheoryDatatypes::addDatatypeDefinitions(TypeNode dttn) { - AssertArgument(dttn.getKind() == DATATYPE_TYPE, dttn, "expected a datatype"); - - Debug("datatypes") << "TheoryDatatypes::addDataTypeDefinitions(): " - << dttn.getConst().getName() << endl; - if(d_addedDatatypes.find(dttn) != d_addedDatatypes.end()) { - // already have incorporated this datatype definition - Debug("datatypes") << "+ can skip" << endl; - return; - } - - const Datatype& dt = dttn.getConst(); - Debug("datatypes") << dt << endl; - for(Datatype::const_iterator it = dt.begin(); it != dt.end(); ++it) { - Node constructor = Node::fromExpr((*it).getConstructor()); - d_cons[dttn].push_back(constructor); - d_testers[dttn].push_back(Node::fromExpr((*it).getTester())); - for(Datatype::Constructor::const_iterator itc = (*it).begin(); itc != (*it).end(); ++itc) { - Node selector = Node::fromExpr((*itc).getSelector()); - d_sels[constructor].push_back(selector); - d_sel_cons[selector] = constructor; - } - } - d_addedDatatypes.insert(dttn); -} - - void TheoryDatatypes::addSharedTerm(TNode t) { Debug("datatypes") << "TheoryDatatypes::addSharedTerm(): " << t << endl; @@ -129,6 +117,10 @@ void TheoryDatatypes::notifyCongruent(TNode lhs, TNode rhs) { Debug("datatypes-debug") << "TheoryDatatypes::notifyCongruent(): done." << endl; } +void TheoryDatatypes::preRegisterTerm(TNode n) { + Debug("datatypes-prereg") << "TheoryDatatypes::preRegisterTerm() " << n << endl; +} + void TheoryDatatypes::presolve() { Debug("datatypes") << "TheoryDatatypes::presolve()" << endl; @@ -136,99 +128,92 @@ void TheoryDatatypes::presolve() { void TheoryDatatypes::check(Effort e) { - //for( int i=0; i<(int)d_currAsserts.size(); i++ ) { - // Debug("datatypes") << "currAsserts[" << i << "] = " << d_currAsserts[i] << endl; - //} - //for( int i=0; i<(int)d_currEqualities.size(); i++ ) { - // Debug("datatypes") << "currEqualities[" << i << "] = " << d_currEqualities[i] << endl; - //} - //for( BoolMap::iterator i = d_inst_map.begin(); i != d_inst_map.end(); i++ ) { - // Debug("datatypes") << "inst_map = " << (*i).first << endl; - //} - //for( EqListsN::iterator i = d_selector_eq.begin(); i != d_selector_eq.end(); i++ ) { - // EqListN* m = (*i).second; - // Debug("datatypes") << "selector_eq " << (*i).first << ":" << endl; - // for( EqListN::const_iterator j = m->begin(); j != m->end(); j++ ) { - // Debug("datatypes") << " : " << (*j) << endl; - // } - //} + for( int i=0; i<(int)d_currAsserts.size(); i++ ) { + Debug("datatypes") << "currAsserts[" << i << "] = " << d_currAsserts[i] << endl; + } + for( int i=0; i<(int)d_currEqualities.size(); i++ ) { + Debug("datatypes") << "currEqualities[" << i << "] = " << d_currEqualities[i] << endl; + } while(!done()) { Node assertion = get(); - if( Debug.isOn("datatypes") || Debug.isOn("datatypes-split") ) { + if( Debug.isOn("datatypes") || Debug.isOn("datatypes-split") || Debug.isOn("datatypes-cycles") ) { cout << "*** TheoryDatatypes::check(): " << assertion << endl; + d_currAsserts.push_back( assertion ); } - //d_currAsserts.push_back( assertion ); //clear from the derived map - if( !d_drv_map[assertion].get().isNull() ) { - Debug("datatypes") << "Assertion has already been derived" << endl; - d_drv_map[assertion] = Node::null(); - } else { - collectTerms( assertion ); - switch(assertion.getKind()) { - case kind::EQUAL: - case kind::IFF: - addEquality(assertion); - break; - case kind::APPLY_TESTER: - checkTester( assertion ); - break; - case kind::NOT: - { - switch( assertion[0].getKind()) { - case kind::EQUAL: - case kind::IFF: - { - Node a = assertion[0][0]; - Node b = assertion[0][1]; - addDisequality(assertion[0]); - d_cc.addTerm(a); - d_cc.addTerm(b); - if(Debug.isOn("datatypes")) { - Debug("datatypes") << " a == > " << a << endl - << " b == > " << b << endl - << " find(a) == > " << debugFind(a) << endl - << " find(b) == > " << debugFind(b) << endl; - } - // There are two ways to get a conflict here. - if(d_conflict.isNull()) { - if(find(a) == find(b)) { - // We get a conflict this way if we WERE previously watching - // a, b and were notified previously (via notifyCongruent()) - // that they were congruent. - NodeBuilder<> nb(kind::AND); - nb << d_cc.explain( assertion[0][0], assertion[0][1] ); - nb << assertion; - d_conflict = nb; - Debug("datatypes") << "Disequality conflict " << d_conflict << endl; - } else { - - // If we get this far, there should be nothing conflicting due - // to this disequality. - Assert(!d_cc.areCongruent(a, b)); + d_inCheck = true; + collectTerms( assertion ); + if( d_conflict.isNull() ){ + if( !d_drv_map[assertion].get().isNull() ) { + Debug("datatypes") << "Assertion has already been derived" << endl; + d_drv_map[assertion] = Node::null(); + } else { + switch(assertion.getKind()) { + case kind::EQUAL: + case kind::IFF: + addEquality(assertion); + break; + case kind::APPLY_TESTER: + checkTester( assertion ); + break; + case kind::NOT: + { + switch( assertion[0].getKind()) { + case kind::EQUAL: + case kind::IFF: + { + Node a = assertion[0][0]; + Node b = assertion[0][1]; + addDisequality(assertion[0]); + d_cc.addTerm(a); + d_cc.addTerm(b); + if(Debug.isOn("datatypes")) { + Debug("datatypes") << " a == > " << a << endl + << " b == > " << b << endl + << " find(a) == > " << debugFind(a) << endl + << " find(b) == > " << debugFind(b) << endl; + } + // There are two ways to get a conflict here. + if(d_conflict.isNull()) { + if(find(a) == find(b)) { + // We get a conflict this way if we WERE previously watching + // a, b and were notified previously (via notifyCongruent()) + // that they were congruent. + NodeBuilder<> nb(kind::AND); + nb << d_cc.explain( assertion[0][0], assertion[0][1] ); + nb << assertion; + d_conflict = nb; + Debug("datatypes") << "Disequality conflict " << d_conflict << endl; + } else { + // If we get this far, there should be nothing conflicting due + // to this disequality. + Assert(!d_cc.areCongruent(a, b)); + } } } + break; + case kind::APPLY_TESTER: + checkTester( assertion ); + break; + default: + Unhandled(assertion[0].getKind()); + break; } - break; - case kind::APPLY_TESTER: - checkTester( assertion ); - break; - default: - Unhandled(assertion[0].getKind()); - break; } + break; + default: + Unhandled(assertion.getKind()); + break; } - break; - default: - Unhandled(assertion.getKind()); - break; - } - if(!d_conflict.isNull()) { - throwConflict(); - return; } } + d_inCheck = false; + if(!d_conflict.isNull()) { + throwConflict(); + return; + } } if( e == FULL_EFFORT ) { @@ -237,20 +222,14 @@ void TheoryDatatypes::check(Effort e) { for( EqLists::iterator i = d_labels.begin(); i != d_labels.end(); i++ ) { Node sf = find( (*i).first ); if( sf == (*i).first || sf.getKind() != APPLY_CONSTRUCTOR ) { - Debug("datatypes-split") << "Check for splitting " << (*i).first << ", "; EqList* lbl = (sf == (*i).first) ? (*i).second : (*d_labels.find( sf )).second; - if( lbl->empty() ) { - Debug("datatypes-split") << "empty label" << endl; - } else { - Debug("datatypes-split") << "label size = " << lbl->size() << endl; - } + Debug("datatypes-split") << "Check for splitting " << (*i).first << ", "; + Debug("datatypes-split") << "label size = " << lbl->size() << endl; Node cons = getPossibleCons( (*i).first, false ); if( !cons.isNull() ) { + const Datatype::Constructor& cn = getConstructor( cons ); Debug("datatypes-split") << "*************Split for possible constructor " << cons << endl; - TypeNode typ = (*i).first.getType(); - int cIndex = Datatype::indexOf( cons.toExpr() ); - Assert( cIndex != -1 ); - Node test = NodeManager::currentNM()->mkNode( APPLY_TESTER, d_testers[typ][cIndex], (*i).first ); + Node test = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( cn.getTester() ), (*i).first ); NodeBuilder<> nb(kind::OR); nb << test << test.notNode(); Node lemma = nb; @@ -272,6 +251,7 @@ void TheoryDatatypes::checkTester( Node assertion, bool doAdd ) { Debug("datatypes") << "Check tester " << assertion << endl; Node tassertion = ( assertion.getKind() == NOT ) ? assertion[0] : assertion; + const Datatype& dt = Datatype::datatypeOf( tassertion.getOperator().toExpr() ); //add the term into congruence closure consideration d_cc.addTerm( tassertion[0] ); @@ -317,7 +297,7 @@ void TheoryDatatypes::checkTester( Node assertion, bool doAdd ) { //check if empty label (no possible constructors for term) bool add = true; - int notCount = 0; + unsigned int notCount = 0; if( !lbl->empty() ) { for( EqList::const_iterator i = lbl->begin(); i != lbl->end(); i++ ) { Node leqn = (*i); @@ -351,8 +331,7 @@ void TheoryDatatypes::checkTester( Node assertion, bool doAdd ) { } } if( add ) { - //Assert( (int)d_cons[ tRep.getType() ].size()== Datatype::datatypeOf(tassertion.getOperator).getNumConstructors() ); - if( assertionRep.getKind() == NOT && notCount == (int)d_cons[ tRep.getType() ].size()-1 ) { + if( assertionRep.getKind() == NOT && notCount == dt.getNumConstructors()-1 ) { NodeBuilder<> nb(kind::AND); if( !lbl->empty() ) { for( EqList::const_iterator i = lbl->begin(); i != lbl->end(); i++ ) { @@ -409,20 +388,16 @@ void TheoryDatatypes::checkInstantiate( Node t ) { //there is one remaining constructor if( !cons.isNull() && lbl_i != d_labels.end() ) { EqList* lbl = (*lbl_i).second; + const Datatype::Constructor& cn = Datatype::datatypeOf( cons.toExpr() )[ Datatype::indexOf( cons.toExpr() ) ]; + //only one constructor possible for term (label is singleton), apply instantiation rule - bool consFinite = isConstructorFinite( cons ); - Debug("datatypes-fin-check") << "checkInst: " << cons << " is "; - if( !consFinite ){ - Debug("datatypes-fin-check") << "not "; - } - Debug("datatypes-fin-check") << "finite. " << std::endl; //find if selectors have been applied to t vector< Node > selectorVals; selectorVals.push_back( cons ); NodeBuilder<> justifyEq(kind::AND); bool foundSel = false; - for( int i=0; i<(int)d_sels[cons].size(); i++ ) { - Node s = NodeManager::currentNM()->mkNode( APPLY_SELECTOR, d_sels[cons][i], te ); + for( unsigned int i=0; imkNode( APPLY_SELECTOR, Node::fromExpr( cn[i].getSelector() ), te ); Debug("datatypes") << "Selector[" << i << "] = " << s << endl; if( d_selectors.find( s ) != d_selectors.end() ) { Node sf = find( s ); @@ -434,7 +409,7 @@ void TheoryDatatypes::checkInstantiate( Node t ) { } selectorVals.push_back( s ); } - if( consFinite || foundSel ) { + if( cn.isFinite() || foundSel ) { d_inst_map[ te ] = true; //instantiate, add equality Node val = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, selectorVals ); @@ -450,13 +425,9 @@ void TheoryDatatypes::checkInstantiate( Node t ) { } } } - Node jeq; - if( justifyEq.getNumChildren() == 1 ) { - jeq = justifyEq.getChild( 0 ); - } else { - jeq = justifyEq; - } + Node jeq = ( justifyEq.getNumChildren() == 1 ) ? justifyEq.getChild( 0 ) : justifyEq; Debug("datatypes-split") << "Instantiate " << newEq << endl; + preRegisterTerm( val ); addDerivedEquality( newEq, jeq ); return; } @@ -477,15 +448,18 @@ Node TheoryDatatypes::getPossibleCons( Node t, bool checkInst ) { EqList* lbl = (*lbl_i).second; TypeNode typ = t.getType(); + const Datatype& dt = ((DatatypeType)typ.toType()).getDatatype(); + //if ended by one positive tester if( !lbl->empty() && (*lbl)[ lbl->size()-1 ].getKind() != NOT ) { if( checkInst ) { - return d_cons[typ][ Datatype::indexOf( (*lbl)[ lbl->size()-1 ].getOperator().toExpr() ) ]; + size_t testerIndex = Datatype::indexOf( (*lbl)[ lbl->size()-1 ].getOperator().toExpr() ); + return Node::fromExpr( dt[ testerIndex ].getConstructor() ); } //if (n-1) negative testers - } else if( !checkInst || (int)lbl->size() == (int)d_cons[ t.getType() ].size()-1 ) { + } else if( !checkInst || lbl->size() == dt.getNumConstructors()-1 ) { vector< bool > possibleCons; - possibleCons.resize( (int)d_cons[ t.getType() ].size(), true ); + possibleCons.resize( dt.getNumConstructors(), true ); if( !lbl->empty() ) { for( EqList::const_iterator i = lbl->begin(); i != lbl->end(); i++ ) { TNode leqn = (*i); @@ -493,13 +467,13 @@ Node TheoryDatatypes::getPossibleCons( Node t, bool checkInst ) { } } Node cons = Node::null(); - for( int i=0; i<(int)possibleCons.size(); i++ ) { + for( unsigned int i=0; imkNode( APPLY_SELECTOR, d_sels[cons][i], tf ); + for( unsigned int j=0; jmkNode( APPLY_SELECTOR, Node::fromExpr( dt[i][j].getSelector() ), tf ); if( d_selectors.find( s ) != d_selectors.end() ) { Debug("datatypes") << " getPosCons: found selector " << s << endl; return cons; @@ -510,9 +484,9 @@ Node TheoryDatatypes::getPossibleCons( Node t, bool checkInst ) { } if( !checkInst ) { for( int i=0; i<(int)possibleCons.size(); i++ ) { - if( possibleCons[i] && !isConstructorFinite( d_cons[typ][ i ] ) ) { + if( possibleCons[i] && !dt[ i ].isFinite() ) { Debug("datatypes") << "Did not find selector for " << tf; - Debug("datatypes") << " and " << d_cons[typ][ i ] << " is not finite." << endl; + Debug("datatypes") << " and " << dt[ i ].getConstructor() << " is not finite." << endl; return Node::null(); } } @@ -544,12 +518,17 @@ Node TheoryDatatypes::getValue(TNode n) { void TheoryDatatypes::merge(TNode a, TNode b) { if( d_noMerge ) { - Debug("datatypes") << "Append to merge pending list " << d_merge_pending.size() << endl; + //Debug("datatypes") << "Append to merge pending list " << d_merge_pending.size() << endl; d_merge_pending[d_merge_pending.size()-1].push_back( pair< Node, Node >( a, b ) ); return; } Assert(d_conflict.isNull()); - Debug("datatypes") << "Merge "<< a << " " << b << endl; + a = find(a); + b = find(b); + if( a == b) { + return; + } + Debug("datatypes-cycles") << "Merge "<< a << " " << b << endl; // make "a" the one with shorter diseqList EqLists::iterator deq_ia = d_disequalities.find(a); @@ -564,14 +543,6 @@ void TheoryDatatypes::merge(TNode a, TNode b) { } } - a = find(a); - b = find(b); - - //Debug("datatypes") << "After find: "<< a << " " << b << endl; - - if( a == b) { - return; - } //if b is a selector, swap a and b if( b.getKind() == APPLY_SELECTOR && a.getKind() != APPLY_SELECTOR ) { TNode tmp = a; @@ -591,23 +562,28 @@ void TheoryDatatypes::merge(TNode a, TNode b) { b = tmp; } - + //check for clash NodeBuilder<> explanation(kind::AND); - if( checkClash( a, b, explanation ) ) { + if( a.getKind() == kind::APPLY_CONSTRUCTOR && b.getKind() == kind::APPLY_CONSTRUCTOR + && a.getOperator()!=b.getOperator() ){ explanation << d_cc.explain( a, b ); d_conflict = explanation.getNumChildren() == 1 ? explanation.getChild( 0 ) : explanation; Debug("datatypes") << "Clash " << a << " " << b << endl; Debug("datatypes") << "Conflict is " << d_conflict << endl; - return; + return; } Debug("datatypes-debug") << "Done clash" << endl; Debug("datatypes") << "Set canon: "<< a << " " << b << endl; - // b becomes the canon of a d_unionFind.setCanon(a, b); d_reps[a] = false; d_reps[b] = true; +#ifdef USE_TRANSITIVE_CLOSURE + bool result = d_cycle_check.addEdgeNode( a, b ); + d_hasSeenCycle.set( d_hasSeenCycle.get() || result ); +#endif + //merge equivalence classes initializeEqClass( a ); initializeEqClass( b ); @@ -617,8 +593,6 @@ void TheoryDatatypes::merge(TNode a, TNode b) { eqc_b->push_back( *i ); } - //Debug("datatypes") << "After check 1" << endl; - deq_ia = d_disequalities.find(a); map alreadyDiseqs; if(deq_ia != d_disequalities.end()) { @@ -689,17 +663,37 @@ void TheoryDatatypes::merge(TNode a, TNode b) { } //Debug("datatypes-debug") << "Done clash" << endl; - //if( d_cycle_check.addEdgeNode( a, b ) ){ +#ifdef USE_TRANSITIVE_CLOSURE + Debug("datatypes-cycles") << "Equal " << a << " -> " << b << " " << d_hasSeenCycle.get() << endl; + if( d_hasSeenCycle.get() ){ + checkCycles(); + if( !d_conflict.isNull() ){ + return; + } + }else{ + checkCycles(); + if( !d_conflict.isNull() ){ + Debug("datatypes-cycles") << "Cycle is " << d_conflict << std::endl; + for( int i=0; i<(int)d_currEqualities.size(); i++ ) { + Debug("datatypes-cycles") << "currEqualities[" << i << "] = " << d_currEqualities[i] << endl; + } + d_cycle_check.debugPrint(); + Assert( false ); + } + } +#else checkCycles(); - //Assert( !d_conflict.isNull() ); - if( !d_conflict.isNull() ) { + if( !d_conflict.isNull() ){ return; } - //} +#endif Debug("datatypes-debug") << "Done cycles" << endl; //merge selector lists updateSelectors( a ); + if( !d_conflict.isNull() ){ + return; + } Debug("datatypes-debug") << "Done collapse" << endl; //merge labels @@ -718,18 +712,15 @@ void TheoryDatatypes::merge(TNode a, TNode b) { Debug("datatypes-debug") << "Done merge labels" << endl; //do unification - if( d_conflict.isNull() ) { - if( a.getKind() == APPLY_CONSTRUCTOR && b.getKind() == APPLY_CONSTRUCTOR && - a.getOperator() == b.getOperator() ) { - Debug("datatypes") << "Unification: " << a << " and " << b << "." << endl; - for( int i=0; i<(int)a.getNumChildren(); i++ ) { - if( find( a[i] ) != find( b[i] ) ) { - Node newEq = NodeManager::currentNM()->mkNode( EQUAL, a[i], b[i] ); - Node jEq = d_cc.explain(a, b); - Debug("datatypes-drv") << "UEqual: " << newEq << ", justification: " << jEq << " from " << a << " " << b << endl; - Debug("datatypes-drv") << "UEqual find: " << find( a[i] ) << " " << find( b[i] ) << endl; - addDerivedEquality( newEq, jEq ); - } + Assert( d_conflict.isNull() ); + if( a.getKind() == APPLY_CONSTRUCTOR && b.getKind() == APPLY_CONSTRUCTOR && + a.getOperator() == b.getOperator() ) { + Debug("datatypes") << "Unification: " << a << " and " << b << "." << endl; + for( int i=0; i<(int)a.getNumChildren(); i++ ) { + if( find( a[i] ) != find( b[i] ) ) { + Node newEq = NodeManager::currentNM()->mkNode( EQUAL, a[i], b[i] ); + Node jEq = d_cc.explain(a, b); + addDerivedEquality( newEq, jEq ); } } } @@ -743,19 +734,13 @@ Node TheoryDatatypes::collapseSelector( TNode t, bool useContext ) { TypeNode typ = t[0].getType(); Node sel = t.getOperator(); TypeNode selType = sel.getType(); - Node cons = d_sel_cons[sel]; + Node cons = getConstructorForSelector( sel ); + const Datatype::Constructor& cn = getConstructor( cons ); Node tmp = find( t[0] ); Node retNode = t; if( tmp.getKind() == APPLY_CONSTRUCTOR ) { if( tmp.getOperator() == cons ) { - int selIndex = -1; - for(int i=0; i<(int)d_sels[cons].size(); i++ ) { - if( d_sels[cons][i] == sel ) { - selIndex = i; - break; - } - } - Assert( selIndex != -1 ); + size_t selIndex = Datatype::indexOf( sel.toExpr() ); Debug("datatypes") << "Applied selector " << t << " to correct constructor, index = " << selIndex << endl; Debug("datatypes") << "Return " << tmp[selIndex] << endl; retNode = tmp[selIndex]; @@ -773,10 +758,8 @@ Node TheoryDatatypes::collapseSelector( TNode t, bool useContext ) { } } else { if( useContext ) { - int cIndex = Datatype::indexOf( cons.toExpr() ); - Assert( cIndex != -1 ); //check labels - Node tester = NodeManager::currentNM()->mkNode( APPLY_TESTER, d_testers[typ][cIndex], tmp ); + Node tester = NodeManager::currentNM()->mkNode( APPLY_TESTER, Node::fromExpr( cn.getTester() ), tmp ); checkTester( tester, false ); if( !d_conflict.isNull() ) { Debug("datatypes") << "Applied selector " << t << " to provably wrong constructor." << endl; @@ -853,39 +836,68 @@ void TheoryDatatypes::updateSelectors( Node a ) { } } -void TheoryDatatypes::collectTerms( TNode t ) { - for( int i=0; i<(int)t.getNumChildren(); i++ ) { - collectTerms( t[i] ); -#if 0 - if( t.getKind() == APPLY_CONSTRUCTOR ){ - if( d_cycle_check.addEdgeNode( t, t[i] ) ){ - checkCycles(); - //Assert( !d_conflict.isNull() ); - if( !d_conflict.isNull() ){ - return; - } +void TheoryDatatypes::addTermToLabels( Node t ) { + if( t.getKind() == VARIABLE || t.getKind() == APPLY_SELECTOR ) { + Node tmp = find( t ); + if( tmp == t ) { + //add to labels + EqLists::iterator lbl_i = d_labels.find(t); + if(lbl_i == d_labels.end()) { + EqList* lbl = new(getContext()->getCMM()) EqList(true, getContext(), false, + ContextMemoryAllocator(getContext()->getCMM())); + d_labels.insertDataFromContextMemory(tmp, lbl); } } + } +} + +void TheoryDatatypes::initializeEqClass( Node t ) { + EqListsN::iterator eqc_i = d_equivalence_class.find( t ); + if( eqc_i == d_equivalence_class.end() ) { + EqListN* eqc = new(getContext()->getCMM()) EqListN(true, getContext(), false, + ContextMemoryAllocator(getContext()->getCMM())); + eqc->push_back( t ); + d_equivalence_class.insertDataFromContextMemory(t, eqc); + } +} + +void TheoryDatatypes::collectTerms( Node n ) { + for( int i=0; i<(int)n.getNumChildren(); i++ ) { + collectTerms( n[i] ); + } + if( n.getKind() == APPLY_CONSTRUCTOR ){ +#ifdef USE_TRANSITIVE_CLOSURE + for( int i=0; i<(int)n.getNumChildren(); i++ ) { + Debug("datatypes-cycles") << "Subterm " << n << " -> " << n[i] << endl; + bool result = d_cycle_check.addEdgeNode( n, n[i] ); + //if( result ){ + // for( int i=0; i<(int)d_currEqualities.size(); i++ ) { + // Debug("datatypes-cycles") << "currEqualities[" << i << "] = " << d_currEqualities[i] << endl; + // } + // d_cycle_check.debugPrint(); + //} + Assert( !result ); //this should not create any new cycles (relevant terms should have been recorded before) + } #endif } - if( t.getKind() == APPLY_SELECTOR ) { - if( d_selectors.find( t ) == d_selectors.end() ) { - Debug("datatypes-split") << " Found selector " << t << endl; - d_selectors[ t ] = true; - d_cc.addTerm( t ); - Node tmp = find( t[0] ); + if( n.getKind() == APPLY_SELECTOR ) { + if( d_selectors.find( n ) == d_selectors.end() ) { + Debug("datatypes-split") << " Found selector " << n << endl; + d_selectors[ n ] = true; + d_cc.addTerm( n ); + Node tmp = find( n[0] ); checkInstantiate( tmp ); - Node s = t; - if( tmp != t[0] ) { - s = NodeManager::currentNM()->mkNode( APPLY_SELECTOR, t.getOperator(), tmp ); + Node s = n; + if( tmp != n[0] ) { + s = NodeManager::currentNM()->mkNode( APPLY_SELECTOR, n.getOperator(), tmp ); } Debug("datatypes-split") << " Before collapse: " << s << endl; s = collapseSelector( s, true ); Debug("datatypes-split") << " After collapse: " << s << endl; if( s.getKind() == APPLY_SELECTOR ) { //add selector to selector eq list - Debug("datatypes") << " Add selector to list " << tmp << " " << t << endl; + Debug("datatypes") << " Add selector to list " << tmp << " " << n << endl; EqListsN::iterator sel_i = d_selector_eq.find( tmp ); EqListN* sel; if( sel_i == d_selector_eq.end() ) { @@ -901,32 +913,7 @@ void TheoryDatatypes::collectTerms( TNode t ) { } } } - addTermToLabels( t ); -} - -void TheoryDatatypes::addTermToLabels( Node t ) { - if( t.getKind() == VARIABLE || t.getKind() == APPLY_SELECTOR ) { - Node tmp = find( t ); - if( tmp == t ) { - //add to labels - EqLists::iterator lbl_i = d_labels.find(t); - if(lbl_i == d_labels.end()) { - EqList* lbl = new(getContext()->getCMM()) EqList(true, getContext(), false, - ContextMemoryAllocator(getContext()->getCMM())); - d_labels.insertDataFromContextMemory(tmp, lbl); - } - } - } -} - -void TheoryDatatypes::initializeEqClass( Node t ) { - EqListsN::iterator eqc_i = d_equivalence_class.find( t ); - if( eqc_i == d_equivalence_class.end() ) { - EqListN* eqc = new(getContext()->getCMM()) EqListN(true, getContext(), false, - ContextMemoryAllocator(getContext()->getCMM())); - eqc->push_back( t ); - d_equivalence_class.insertDataFromContextMemory(t, eqc); - } + addTermToLabels( n ); } void TheoryDatatypes::appendToDiseqList(TNode of, TNode eq) { @@ -950,27 +937,6 @@ void TheoryDatatypes::appendToDiseqList(TNode of, TNode eq) { //} } -void TheoryDatatypes::appendToEqList(TNode of, TNode eq) { - Debug("datatypes") << "appending " << eq << endl - << " to eq list of " << of << endl; - Assert(eq.getKind() == kind::EQUAL || - eq.getKind() == kind::IFF); - Assert(of == debugFind(of)); - EqLists::iterator eq_i = d_equalities.find(of); - EqList* eql; - if(eq_i == d_equalities.end()) { - eql = new(getContext()->getCMM()) EqList(true, getContext(), false, - ContextMemoryAllocator(getContext()->getCMM())); - d_equalities.insertDataFromContextMemory(of, eql); - } else { - eql = (*eq_i).second; - } - eql->push_back(eq); - //if(Debug.isOn("uf")) { - // Debug("uf") << " size is now " << eql->size() << endl; - //} -} - void TheoryDatatypes::addDerivedEquality(TNode eq, TNode jeq) { Debug("datatypes-drv") << "Justification for " << eq << "is: " << jeq << "." << endl; d_drv_map[eq] = jeq; @@ -982,21 +948,32 @@ void TheoryDatatypes::addEquality(TNode eq) { eq.getKind() == kind::IFF); if( eq[0] != eq[1] ) { Debug("datatypes") << "Add equality " << eq << "." << endl; + + //setup merge pending list d_merge_pending.push_back( vector< pair< Node, Node > >() ); bool prevNoMerge = d_noMerge; d_noMerge = true; + d_cc.addTerm(eq[0]); d_cc.addTerm(eq[1]); d_cc.addEquality(eq); - //d_currEqualities.push_back(eq); + if( Debug.isOn("datatypes") || Debug.isOn("datatypes-cycles") ){ + d_currEqualities.push_back(eq); + } + + //record which nodes are waiting to be merged d_noMerge = prevNoMerge; - unsigned int mpi = d_merge_pending.size()-1; vector< pair< Node, Node > > mp; - mp.insert( mp.begin(), d_merge_pending[mpi].begin(), d_merge_pending[mpi].end() ); + mp.insert( mp.begin(), + d_merge_pending[d_merge_pending.size()-1].begin(), + d_merge_pending[d_merge_pending.size()-1].end() ); d_merge_pending.pop_back(); + + //merge original nodes if( d_conflict.isNull() ) { merge(eq[0], eq[1]); } + //merge nodes waiting to be merged for( int i=0; i<(int)mp.size(); i++ ) { if( d_conflict.isNull() ) { merge( mp[i].first, mp[i].second ); @@ -1016,40 +993,16 @@ void TheoryDatatypes::addDisequality(TNode eq) { appendToDiseqList(find(b), eq); } -void TheoryDatatypes::registerEqualityForPropagation(TNode eq) { - // should NOT be in search at this point, this must be called during - // preregistration - - // FIXME with lemmas on demand, this could miss future propagations, - // since we are not necessarily at context level 0, but are updating - // context-sensitive structures. - - Assert(eq.getKind() == kind::EQUAL || - eq.getKind() == kind::IFF); - - TNode a = eq[0]; - TNode b = eq[1]; - - appendToEqList(find(a), eq); - appendToEqList(find(b), eq); -} - void TheoryDatatypes::throwConflict() { Debug("datatypes") << "Convert conflict : " << d_conflict << endl; NodeBuilder<> nb(kind::AND); convertDerived( d_conflict, nb ); - if( nb.getNumChildren() == 1 ) { - d_conflict = nb.getChild( 0 ); - } else { - d_conflict = nb; - } - if( Debug.isOn("datatypes") || Debug.isOn("datatypes-split") ) { + d_conflict = ( nb.getNumChildren() == 1 ) ? nb.getChild( 0 ) : nb; + if( Debug.isOn("datatypes") || Debug.isOn("datatypes-split") || Debug.isOn("datatypes-cycles") ) { cout << "Conflict constructed : " << d_conflict << endl; } - //if( d_conflict.getKind() != kind::AND ) { - // NodeBuilder<> nb(kind::AND); - // nb << d_conflict << d_conflict; - // d_conflict = nb; + //if( d_conflict.getKind()!=kind::AND ){ + // d_conflict = NodeManager::currentNM()->mkNode(kind::AND, d_conflict, d_conflict); //} d_out->conflict( d_conflict, false ); d_conflict = Node::null(); @@ -1095,8 +1048,23 @@ bool TheoryDatatypes::searchForCycle( Node n, Node on, if( visited.find( nn ) == visited.end() ) { visited[nn] = true; if( nn == on || searchForCycle( nn, on, visited, explanation ) ) { + if( !d_cycle_check.isConnectedNode( n, n[i] ) ){ + Debug("datatypes-cycles") << "Cycle subterm: " << n << " is not -> " << n[i] << "!!!!" << std::endl; + } if( nn != n[i] ) { - explanation << d_cc.explain( nn, n[i] ); + Node e = d_cc.explain( nn, n[i] ); + if( !d_cycle_check.isConnectedNode( n[i], nn ) ){ + Debug("datatypes-cycles") << "Cycle equality: " << n[i] << " is not -> " << nn << "!!!!" << std::endl; + Debug("datatypes-cycles") << "Explanation: " << e << std::endl; + if( e.getKind()==kind::AND ){ + for( int a=0; a& explanation ) { - //Debug("datatypes") << "Check clash " << n1 << " " << n2 << endl; - Node n1f = find( n1 ); - Node n2f = find( n2 ); - bool retVal = false; - if( n1f != n2f ) { - if( n1f.getKind() == kind::APPLY_CONSTRUCTOR && n2f.getKind() == kind::APPLY_CONSTRUCTOR ) { - if( n1f.getOperator() != n2f.getOperator() ) { - retVal =true; - } else { - Assert( n1f.getNumChildren() == n2f.getNumChildren() ); - for( int i=0; i<(int)n1f.getNumChildren(); i++ ) { - if( checkClash( n1f[i], n2f[i], explanation ) ) { - retVal = true; - break; - } - } - } - } - if( retVal ) { - if( n1f != n1 ) { - explanation << d_cc.explain( n1f, n1 ); - } - if( n2f != n2 ) { - explanation << d_cc.explain( n2f, n2 ); - } - } - } - return retVal; -} diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index d6fc837fd..1a944a6e0 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -44,20 +44,9 @@ private: typedef context::CDMap EqListsN; typedef context::CDMap< Node, bool, NodeHashFunction > BoolMap; - std::hash_set d_addedDatatypes; - - //context::CDList d_currAsserts; - //context::CDList d_currEqualities; - - //TODO: the following 4 maps can be eliminated - /** a list of types with the list of constructors for that type */ - std::map > d_cons; - /** a list of types with the list of constructors for that type */ - std::map > d_testers; - /** a list of constructors with the list of selectors */ - std::map > d_sels; - /** map from selectors to the constructors they are for */ - std::map d_sel_cons; + /** for debugging */ + context::CDList d_currAsserts; + context::CDList d_currEqualities; /** map from equalties and the equalities they are derived from */ context::CDMap< Node, Node, NodeHashFunction > d_drv_map; @@ -75,8 +64,12 @@ private: BoolMap d_inst_map; /** transitive closure to record equivalence/subterm relation. */ TransitiveClosureNode d_cycle_check; - /** check whether constructor is finite */ - bool isConstructorFinite( Node cons ); + /** has seen cycle */ + context::CDO< bool > d_hasSeenCycle; + /** get the constructor for the node */ + const Datatype::Constructor& getConstructor( Node cons ); + /** get the constructor for the selector */ + Node getConstructorForSelector( Node sel ); /** * map from terms to testers asserted for that term @@ -140,15 +133,11 @@ private: */ bool d_noMerge; std::vector< std::vector< std::pair< Node, Node > > > d_merge_pending; + bool d_inCheck; public: TheoryDatatypes(context::Context* c, OutputChannel& out, Valuation valuation); ~TheoryDatatypes(); - void preRegisterTerm(TNode n) { - TypeNode type = n.getType(); - if(type.getKind() == kind::DATATYPE_TYPE) { - addDatatypeDefinitions(type); - } - } + void preRegisterTerm(TNode n); void presolve(); void addSharedTerm(TNode t); @@ -158,30 +147,26 @@ public: void shutdown() { } std::string identify() const { return std::string("TheoryDatatypes"); } - void addDatatypeDefinitions(TypeNode dttn); - private: /* Helper methods */ void checkTester( Node assertion, bool doAdd = true ); - static bool checkTrivialTester(Node assertion); + bool checkTrivialTester(Node assertion); void checkInstantiate( Node t ); Node getPossibleCons( Node t, bool checkInst = false ); Node collapseSelector( TNode t, bool useContext = false ); void updateSelectors( Node a ); - void collectTerms( TNode t ); void addTermToLabels( Node t ); void initializeEqClass( Node t ); + void collectTerms( Node n ); /* from uf_morgan */ void merge(TNode a, TNode b); inline TNode find(TNode a); inline TNode debugFind(TNode a) const; void appendToDiseqList(TNode of, TNode eq); - void appendToEqList(TNode of, TNode eq); void addDisequality(TNode eq); void addDerivedEquality(TNode eq, TNode jeq); void addEquality(TNode eq); - void registerEqualityForPropagation(TNode eq); void convertDerived(Node n, NodeBuilder<>& nb); void throwConflict(); @@ -189,8 +174,6 @@ private: bool searchForCycle( Node n, Node on, std::map< Node, bool >& visited, NodeBuilder<>& explanation ); - bool checkClash( Node n1, Node n2, NodeBuilder<>& explanation ); - friend class DatatypesRewriter;// for access to checkTrivialTester(); };/* class TheoryDatatypes */ inline TNode TheoryDatatypes::find(TNode a) { diff --git a/src/util/datatype.cpp b/src/util/datatype.cpp index 20d63995f..2a3f69fd6 100644 --- a/src/util/datatype.cpp +++ b/src/util/datatype.cpp @@ -399,6 +399,10 @@ void Datatype::Constructor::resolve(ExprManager* em, DatatypeType self, d_tester = em->mkVar(d_name.substr(d_name.find('\0') + 1), em->mkTesterType(self)); d_name.resize(d_name.find('\0')); d_constructor = em->mkVar(d_name, em->mkConstructorType(*this, self)); + //associate constructor with all selectors + for(iterator i = begin(), i_end = end(); i != i_end; ++i) { + (*i).d_constructor = d_constructor; + } } Datatype::Constructor::Constructor(std::string name, std::string tester) : @@ -605,6 +609,12 @@ Expr Datatype::Constructor::Arg::getSelector() const { return d_selector; } +Expr Datatype::Constructor::Arg::getConstructor() const { + CheckArgument(isResolved(), this, + "cannot get a associated constructor for argument of an unresolved datatype constructor"); + return d_constructor; +} + bool Datatype::Constructor::Arg::isUnresolvedSelf() const throw() { return d_selector.isNull() && d_name.size() == d_name.find('\0') + 1; } diff --git a/src/util/datatype.h b/src/util/datatype.h index df7dd1814..75da1405f 100644 --- a/src/util/datatype.h +++ b/src/util/datatype.h @@ -147,6 +147,8 @@ public: std::string d_name; Expr d_selector; + /** the constructor associated with this selector */ + Expr d_constructor; bool d_resolved; explicit Arg(std::string name, Expr selector); @@ -166,6 +168,12 @@ public: */ Expr getSelector() const; + /** + * Get the associated constructor for this constructor argument; this call is + * only permitted after resolution. + */ + Expr getConstructor() const; + /** * Get the name of the type of this constructor argument * (Datatype field). Can be used for not-yet-resolved Datatypes diff --git a/src/util/trans_closure.cpp b/src/util/trans_closure.cpp index a31dc3378..43c8735ad 100644 --- a/src/util/trans_closure.cpp +++ b/src/util/trans_closure.cpp @@ -74,6 +74,14 @@ bool TransitiveClosure::addEdge(unsigned i, unsigned j) return false; } +bool TransitiveClosure::isConnected(unsigned i, unsigned j) +{ + if( i>=adjMatrix.size() || j>adjMatrix.size() ){ + return false; + }else{ + return adjMatrix[i] != NULL && adjMatrix[i]->read(j); + } +} void TransitiveClosure::debugPrintMatrix() { @@ -89,16 +97,24 @@ void TransitiveClosure::debugPrintMatrix() } } -unsigned TransitiveClosureNode::d_counter = 0; - unsigned TransitiveClosureNode::getId( Node i ){ - std::map< Node, unsigned >::iterator it = nodeMap.find( i ); + context::CDMap< Node, unsigned, NodeHashFunction >::iterator it = nodeMap.find( i ); if( it==nodeMap.end() ){ - nodeMap[i] = d_counter; - d_counter++; - return d_counter-1; + unsigned c = d_counter.get(); + nodeMap[i] = c; + d_counter.set( c + 1 ); + return c; + } + return (*it).second; +} + +void TransitiveClosureNode::debugPrint(){ + for( int i=0; i<(int)currEdges.size(); i++ ){ + cout << "currEdges[ " << i << " ] = " + << currEdges[i].first << " -> " << currEdges[i].second; + //<< "(" << getId( currEdges[i].first ) << " -> " << getId( currEdges[i].second ) << ")"; + cout << std::endl; } - return it->second; } diff --git a/src/util/trans_closure.h b/src/util/trans_closure.h index 4d811d0c9..af16d2e13 100644 --- a/src/util/trans_closure.h +++ b/src/util/trans_closure.h @@ -23,6 +23,10 @@ #include "expr/node.h" #include +#include "context/cdlist.h" +#include "context/cdmap.h" +#include "context/cdo.h" + namespace CVC4 { /* @@ -65,14 +69,14 @@ public: } bool read(unsigned index) { - if (index < 64) return (d_data & (unsigned(1) << index)) != 0; + if (index < 64) return (d_data & (uint64_t(1) << index)) != 0; else if (d_next == NULL) return false; else return d_next->read(index - 64); } void write(unsigned index) { if (index < 64) { - unsigned mask = unsigned(1) << index; + unsigned mask = uint64_t(1) << index; if ((d_data & mask) != 0) return; makeCurrent(); d_data = d_data | mask; @@ -111,6 +115,8 @@ public: /* Add an edge from node i to node j. Return false if successful, true if this edge would create a cycle */ bool addEdge(unsigned i, unsigned j); + /** whether node i is connected to node j */ + bool isConnected(unsigned i, unsigned j); void debugPrintMatrix(); }; @@ -119,17 +125,26 @@ public: * */ class TransitiveClosureNode : public TransitiveClosure{ - static unsigned d_counter; - std::map< Node, unsigned > nodeMap; + context::CDO< unsigned > d_counter; + context::CDMap< Node, unsigned, NodeHashFunction > nodeMap; unsigned getId( Node i ); + //for debugging + context::CDList< std::pair< Node, Node > > currEdges; public: - TransitiveClosureNode(context::Context* context) : TransitiveClosure(context) {} + TransitiveClosureNode(context::Context* context) : + TransitiveClosure(context), d_counter( context, 0 ), nodeMap( context ), currEdges(context) {} ~TransitiveClosureNode(){} /* Add an edge from node i to node j. Return false if successful, true if this edge would create a cycle */ bool addEdgeNode(Node i, Node j) { + currEdges.push_back( std::pair< Node, Node >( i, j ) ); return addEdge( getId( i ), getId( j ) ); } + /** whether node i is connected to node j */ + bool isConnectedNode(Node i, Node j) { + return isConnected( getId( i ), getId( j ) ); + } + void debugPrint(); }; }/* CVC4 namespace */ -- 2.30.2