From dab7f460511bf0f36c286eaf456a4be11f4fea4b Mon Sep 17 00:00:00 2001 From: ajreynol Date: Fri, 5 Feb 2016 10:24:49 -0600 Subject: [PATCH] Add two optimizations for datatypes, currently disabled. Bug fix rewriter for selectors applied to codatatype values. --- src/options/datatypes_options | 4 + src/theory/datatypes/datatypes_rewriter.h | 143 ++++++++-- src/theory/datatypes/datatypes_sygus.cpp | 60 +++-- src/theory/datatypes/datatypes_sygus.h | 6 +- src/theory/datatypes/theory_datatypes.cpp | 304 ++++++++++++++-------- src/theory/datatypes/theory_datatypes.h | 9 +- src/theory/quantifiers_engine.cpp | 3 + 7 files changed, 364 insertions(+), 165 deletions(-) diff --git a/src/options/datatypes_options b/src/options/datatypes_options index ba700a594..b44a36e2a 100644 --- a/src/options/datatypes_options +++ b/src/options/datatypes_options @@ -15,6 +15,10 @@ option dtForceAssignment --dt-force-assignment bool :default false :read-write force the datatypes solver to give specific values to all datatypes terms before answering sat option dtBinarySplit --dt-binary-split bool :default false do binary splits for datatype constructor types +option dtRefIntro --dt-ref-sk-intro bool :default false + introduce reference skolems for shorter explanations +option dtUseTesters --dt-use-testers bool :default true + do not preprocess away tester predicates option cdtBisimilar --cdt-bisimilar bool :default true do bisimilarity check for co-datatypes option dtCyclic --dt-cyclic bool :default true diff --git a/src/theory/datatypes/datatypes_rewriter.h b/src/theory/datatypes/datatypes_rewriter.h index dc57f6b47..da2282a2c 100644 --- a/src/theory/datatypes/datatypes_rewriter.h +++ b/src/theory/datatypes/datatypes_rewriter.h @@ -78,8 +78,7 @@ public: Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " << "Rewrite trivial tester " << in << " " << result << std::endl; - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(result)); + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(result)); } else { const Datatype& dt = DatatypeType(in[0].getType().toType()).getDatatype(); if(dt.getNumConstructors() == 1) { @@ -88,8 +87,16 @@ public: << "only one ctor for " << dt.getName() << " and that is " << dt[0].getName() << std::endl; - return RewriteResponse(REWRITE_DONE, - NodeManager::currentNM()->mkConst(true)); + return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true)); + } + //TODO : else if( dt.getNumConstructors()==2 && Datatype::indexOf(in.getOperator())==1 ){ + else if( !options::dtUseTesters() ){ + unsigned tindex = Datatype::indexOf(in.getOperator().toExpr()); + Trace("datatypes-rewrite-debug") << "Convert " << in << " to equality " << in[0] << " " << tindex << std::endl; + Node neq = mkTester( in[0], tindex, dt ); + Assert( neq!=in ); + Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: Rewrite tester " << in << " to " << neq << std::endl; + return RewriteResponse(REWRITE_AGAIN_FULL, neq); } } } @@ -130,8 +137,7 @@ public: Debug("tuprec") << "==> returning " << in[0][selectorIndex] << std::endl; return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]); } - if(in.getKind() == kind::APPLY_SELECTOR_TOTAL && - in[0].getKind() == kind::APPLY_CONSTRUCTOR) { + if(in.getKind() == kind::APPLY_SELECTOR_TOTAL && in[0].getKind() == kind::APPLY_CONSTRUCTOR) { // Have to be careful not to rewrite well-typed expressions // where the selector doesn't match the constructor, // e.g. "pred(zero)". @@ -143,12 +149,20 @@ public: size_t constructorIndex = Datatype::indexOf(constructorExpr); const Datatype& dt = Datatype::datatypeOf(constructorExpr); const DatatypeConstructor& c = dt[constructorIndex]; - if(c.getNumArgs() > selectorIndex && - c[selectorIndex].getSelector() == selectorExpr) { - Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " - << "Rewrite trivial selector " << in - << std::endl; - return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]); + if(c.getNumArgs() > selectorIndex && c[selectorIndex].getSelector() == selectorExpr) { + if( dt.isCodatatype() && in[0][selectorIndex].isConst() ){ + //must replace all debruijn indices with self + Node sub = replaceDebruijn( in[0][selectorIndex], in[0], in[0].getType(), 0 ); + Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " + << "Rewrite trivial codatatype selector " << in << " to " << sub << std::endl; + if( sub!=in ){ + return RewriteResponse(REWRITE_AGAIN_FULL, sub ); + } + }else{ + Trace("datatypes-rewrite") << "DatatypesRewriter::postRewrite: " + << "Rewrite trivial selector " << in << std::endl; + return RewriteResponse(REWRITE_DONE, in[0][selectorIndex]); + } }else{ //typically should not be called TypeNode tn = in.getType(); @@ -320,6 +334,7 @@ public: } /** get instantiate cons */ static Node getInstCons( Node n, const Datatype& dt, int index ) { + Assert( index>=0 && index<(int)dt.getNumConstructors() ); Type tspec; if( dt.isParametric() ){ tspec = dt[index].getSpecializedConstructorType(n.getType().toType()); @@ -343,30 +358,85 @@ public: n_ic = NodeManager::currentNM()->mkNode( kind::APPLY_CONSTRUCTOR, children ); Assert( n_ic.getType()==n.getType() ); } - Assert( isInstCons( n_ic, dt, index ) ); + Assert( isInstCons( n, n_ic, dt )==index ); //n_ic = Rewriter::rewrite( n_ic ); return n_ic; } - static bool isInstCons( Node n, const Datatype& dt, int index ){ + + static int isInstCons( Node t, Node n, const Datatype& dt ){ if( n.getKind()==kind::APPLY_CONSTRUCTOR ){ + int index = Datatype::indexOf( n.getOperator().toExpr() ); const DatatypeConstructor& c = dt[index]; - if( n.getOperator()==Node::fromExpr( c.getConstructor() ) ){ - for( unsigned i=0; i=0; i-- ){ + if( n[i].getKind()==kind::APPLY_CONSTRUCTOR ){ + const Datatype& dt = Datatype::datatypeOf(n[i].getOperator().toExpr()); + int ic = isInstCons( n[1-i], n[i], dt ); + if( ic!=-1 ){ + a = n[1-i]; + return ic; + } } } - return true; } } - return false; + return -1; + } + static int isTester( Node n ) { + if( options::dtUseTesters() ){ + if( n.getKind()==kind::APPLY_TESTER ){ + return Datatype::indexOf( n.getOperator().toExpr() ); + } + }else{ + if( n.getKind()==kind::EQUAL ){ + for( int i=1; i>=0; i-- ){ + if( n[i].getKind()==kind::APPLY_CONSTRUCTOR ){ + const Datatype& dt = Datatype::datatypeOf(n[i].getOperator().toExpr()); + int ic = isInstCons( n[1-i], n[i], dt ); + if( ic!=-1 ){ + return ic; + } + } + } + } + } + return -1; } + static Node mkTester( Node n, int i, const Datatype& dt ){ - //Node ret = n.eqNode( DatatypesRewriter::getInstCons( n, dt, i ) ); - //Assert( isTester( ret )==i ); - Node ret = NodeManager::currentNM()->mkNode( kind::APPLY_TESTER, Node::fromExpr( dt[i].getTester() ), n ); - return ret; + if( options::dtUseTesters() ){ + return NodeManager::currentNM()->mkNode( kind::APPLY_TESTER, Node::fromExpr( dt[i].getTester() ), n ); + }else{ +#ifdef CVC4_ASSERTIONS + Node ret = n.eqNode( DatatypesRewriter::getInstCons( n, dt, i ) ); + Node a; + int ii = isTester( ret, a ); + Assert( ii==i ); + Assert( a==n ); + return ret; +#else + return n.eqNode( DatatypesRewriter::getInstCons( n, dt, i ) ); +#endif + } } static bool isNullaryApplyConstructor( Node n ){ Assert( n.getKind()==kind::APPLY_CONSTRUCTOR ); @@ -486,6 +556,29 @@ private: } } } + static Node replaceDebruijn( Node n, Node orig, TypeNode orig_tn, unsigned depth ) { + if( n.getKind()==kind::UNINTERPRETED_CONSTANT && n.getType()==orig_tn ){ + unsigned index = n.getConst().getIndex().toUnsignedInt(); + if( index==depth ){ + return orig; + } + }else if( n.getNumChildren()>0 ){ + std::vector< Node > children; + bool childChanged = false; + for( unsigned i=0; imkNode( n.getKind(), children ); + } + } + return n; + } public: static Node normalizeCodatatypeConstant( Node n ){ Trace("dt-nconst") << "Normalize " << n << std::endl; diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index 5f0466d30..07fb60e57 100644 --- a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -643,10 +643,10 @@ SygusSymBreak::SygusSymBreak( quantifiers::TermDbSygus * tds, context::Context* } -void SygusSymBreak::addTester( Node tst ) { +void SygusSymBreak::addTester( int tindex, Node n, Node exp ) { if( options::sygusNormalFormGlobal() ){ - Node a = getAnchor( tst[0] ); - Trace("sygus-sym-break-debug") << "Add tester " << tst << " for " << a << std::endl; + Node a = getAnchor( n ); + Trace("sygus-sym-break-debug") << "Add tester " << tindex << " " << n << " for " << a << std::endl; std::map< Node, ProgSearch * >::iterator it = d_prog_search.find( a ); ProgSearch * ps; if( it==d_prog_search.end() ){ @@ -664,7 +664,7 @@ void SygusSymBreak::addTester( Node tst ) { ps = it->second; } if( ps ){ - ps->addTester( tst ); + ps->addTester( tindex, n, exp ); } } } @@ -677,34 +677,39 @@ Node SygusSymBreak::getAnchor( Node n ) { } } -void SygusSymBreak::ProgSearch::addTester( Node tst ) { - NodeMap::const_iterator it = d_testers.find( tst[0] ); +void SygusSymBreak::ProgSearch::addTester( int tindex, Node n, Node exp ) { +#ifdef CVC4_ASSERTIONS + Node a; + int teindex = DatatypesRewriter::isTester( exp, a ); + Assert( teindex==tindex ); + Assert( a==n ); +#endif + NodeMap::const_iterator it = d_testers.find( n ); if( it==d_testers.end() ){ - d_testers[tst[0]] = tst; - if( tst[0]==d_anchor ){ - assignTester( tst, 0 ); + d_testers[n] = exp; + if( n==d_anchor ){ + assignTester( tindex, n, 0 ); }else{ - IntMap::const_iterator it = d_watched_terms.find( tst[0] ); + IntMap::const_iterator it = d_watched_terms.find( n ); if( it!=d_watched_terms.end() ){ - assignTester( tst, (*it).second ); + assignTester( tindex, n, (*it).second ); }else{ - Trace("sygus-sym-break-debug2") << "...add to wait list " << tst << " for " << d_anchor << std::endl; + Trace("sygus-sym-break-debug2") << "...add to wait list " << tindex << " " << n << " for " << d_anchor << std::endl; } } }else{ - Trace("sygus-sym-break-debug2") << "...already seen " << tst << " for " << d_anchor << std::endl; + Trace("sygus-sym-break-debug2") << "...already seen " << tindex << " " << n << " for " << d_anchor << std::endl; } } -bool SygusSymBreak::ProgSearch::assignTester( Node tst, int depth ) { - Trace("sygus-sym-break-debug") << "SymBreak : Assign tester : " << tst << ", depth = " << depth << " of " << d_anchor << std::endl; - int tindex = Datatype::indexOf( tst.getOperator().toExpr() ); - TypeNode tn = tst[0].getType(); +bool SygusSymBreak::ProgSearch::assignTester( int tindex, Node n, int depth ) { + Trace("sygus-sym-break-debug") << "SymBreak : Assign tester : " << tindex << " " << n << ", depth = " << depth << " of " << d_anchor << std::endl; + TypeNode tn = n.getType(); Assert( DatatypesRewriter::isTypeDatatype( tn ) ); const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); std::vector< Node > tst_waiting; for( unsigned i=0; imkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[tindex][i].getSelector() ), tst[0] ); + Node sel = NodeManager::currentNM()->mkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[tindex][i].getSelector() ), n ); NodeMap::const_iterator it = d_testers.find( sel ); if( it!=d_testers.end() ){ tst_waiting.push_back( (*it).second ); @@ -727,11 +732,14 @@ bool SygusSymBreak::ProgSearch::assignTester( Node tst, int depth ) { d_watched_count[depth] = d_watched_count[depth] - 1; } //determine if any subprograms on the current path are redundant - if( processSubprograms( tst[0], depth, depth ) ){ + if( processSubprograms( n, depth, depth ) ){ if( processProgramDepth( depth ) ){ //assign preexisting testers for( unsigned i=0; i1 ){ bool finished = false; const Datatype & pdt = ((DatatypeType)(at).toType()).getDatatype(); - int pc = Datatype::indexOf( testers[0].getOperator().toExpr() ); + int pc = DatatypesRewriter::isTester( testers[0] );//Datatype::indexOf( testers[0].getOperator().toExpr() ); + Assert( pc!=-1 ); // [1] determine a minimal subset of the arguments that the rewriting depended on //quick checks based on constants for( unsigned i=0; iregisterSygusType( tn ); Node op = d_tds->getArgOp( tn, tindex ); if( op!=rop ){ diff --git a/src/theory/datatypes/datatypes_sygus.h b/src/theory/datatypes/datatypes_sygus.h index 415bd6e4b..b00fade36 100644 --- a/src/theory/datatypes/datatypes_sygus.h +++ b/src/theory/datatypes/datatypes_sygus.h @@ -91,7 +91,7 @@ private: std::vector< Node >& testers, std::map< Node, std::vector< Node > >& testers_u ); bool processProgramDepth( int depth ); bool processSubprograms( Node n, int depth, int odepth ); - bool assignTester( Node tst, int depth ); + bool assignTester( int tindex, Node n, int depth ); public: ProgSearch( SygusSymBreak * p, Node a, context::Context* c ) : d_parent( p ), d_anchor( a ), d_testers( c ), d_watched_terms( c ), d_watched_count( c ), d_prog_depth( c, 0 ) { @@ -103,7 +103,7 @@ private: IntIntMap d_watched_count; TypeNode d_anchor_type; context::CDO d_prog_depth; - void addTester( Node tst ); + void addTester( int tindex, Node n, Node exp ); }; std::map< Node, ProgSearch * > d_prog_search; std::map< TypeNode, std::map< Node, Node > > d_normalized_to_orig; @@ -130,7 +130,7 @@ private: public: SygusSymBreak( quantifiers::TermDbSygus * tds, context::Context* c ); /** add tester */ - void addTester( Node tst ); + void addTester( int tindex, Node n, Node exp ); /** lemmas we have generated */ std::vector< Node > d_lemmas; }; diff --git a/src/theory/datatypes/theory_datatypes.cpp b/src/theory/datatypes/theory_datatypes.cpp index e33b4d05c..439fd0cfb 100644 --- a/src/theory/datatypes/theory_datatypes.cpp +++ b/src/theory/datatypes/theory_datatypes.cpp @@ -46,6 +46,7 @@ TheoryDatatypes::TheoryDatatypes(Context* c, UserContext* u, OutputChannel& out, d_hasSeenCycle(c, false), d_infer(c), d_infer_exp(c), + d_term_sk( u ), d_notify( *this ), d_equalityEngine(d_notify, c, "theory::datatypes::TheoryDatatypes", true), d_labels( c ), @@ -173,7 +174,7 @@ void TheoryDatatypes::check(Effort e) { }while( addedFact ); //check for splits - Debug("datatypes-split") << "Check for splits " << e << endl; + Trace("datatypes-debug") << "Check for splits " << e << endl; addedFact = false; do { std::map< TypeNode, Node > rec_singletons; @@ -303,7 +304,8 @@ void TheoryDatatypes::check(Effort e) { Assert( !children.empty() ); Node lemma = children.size()==1 ? children[0] : NodeManager::currentNM()->mkNode( kind::OR, children ); Trace("dt-split-debug") << "Split lemma is : " << lemma << std::endl; - doSendLemma( lemma ); + //doSendLemma( lemma ); + d_out->lemma( lemma, false, false, true ); } return; } @@ -332,7 +334,7 @@ void TheoryDatatypes::check(Effort e) { } } */ - }while( !d_conflict && addedFact ); + }while( !d_conflict && !d_addedLemma && addedFact ); Trace("datatypes-debug") << "Finished. " << d_conflict << std::endl; if( !d_conflict ){ Trace("dt-model-debug") << std::endl; @@ -348,35 +350,43 @@ void TheoryDatatypes::check(Effort e) { void TheoryDatatypes::flushPendingFacts(){ doPendingMerges(); - if( !d_pending.empty() ){ + if( !d_pending_lem.empty() ){ int i = 0; - while( !d_conflict && i<(int)d_pending.size() ){ - Node fact = d_pending[i]; - Node exp = d_pending_exp[ fact ]; - //check to see if we have to communicate it to the rest of the system - if( mustCommunicateFact( fact, exp ) ){ - Trace("dt-lemma-debug") << "Assert fact " << fact << " with explanation " << exp << std::endl; - Node lem = fact; - if( exp.isNull() || exp==d_true ){ - Trace("dt-lemma-debug") << "Trivial explanation." << std::endl; - }else{ - Trace("dt-lemma-debug") << "Get explanation..." << std::endl; - Node ee_exp = explain( exp ); - Trace("dt-lemma-debug") << "Explanation : " << ee_exp << std::endl; - lem = NodeManager::currentNM()->mkNode( OR, ee_exp.negate(), fact ); - lem = Rewriter::rewrite( lem ); - } - Trace("dt-lemma") << "Datatypes lemma : " << lem << std::endl; - doSendLemma( lem ); - d_addedLemma = true; + while( i<(int)d_pending_lem.size() ){ + doSendLemma( d_pending_lem[i] ); + i++; + } + d_pending_lem.clear(); + doPendingMerges(); + } + int i = 0; + while( !d_conflict && i<(int)d_pending.size() ){ + Node fact = d_pending[i]; + Node exp = d_pending_exp[ fact ]; + Trace("datatypes-debug") << "Assert fact (#" << (i+1) << "/" << d_pending.size() << ") " << fact << " with explanation " << exp << std::endl; + //check to see if we have to communicate it to the rest of the system + if( mustCommunicateFact( fact, exp ) ){ + Node lem = fact; + if( exp.isNull() || exp==d_true ){ + Trace("dt-lemma-debug") << "Trivial explanation." << std::endl; }else{ - assertFact( fact, exp ); + Trace("dt-lemma-debug") << "Get explanation..." << std::endl; + Node ee_exp = explain( exp ); + Trace("dt-lemma-debug") << "Explanation : " << ee_exp << std::endl; + lem = NodeManager::currentNM()->mkNode( OR, ee_exp.negate(), fact ); + lem = Rewriter::rewrite( lem ); } - i++; + Trace("dt-lemma") << "Datatypes lemma : " << lem << std::endl; + doSendLemma( lem ); + d_addedLemma = true; + }else{ + assertFact( fact, exp ); } - d_pending.clear(); - d_pending_exp.clear(); + Trace("datatypes-debug") << "Finished fact " << fact << ", now = " << d_conflict << " " << d_pending.size() << std::endl; + i++; } + d_pending.clear(); + d_pending_exp.clear(); } void TheoryDatatypes::doPendingMerges(){ @@ -394,6 +404,7 @@ void TheoryDatatypes::doPendingMerges(){ void TheoryDatatypes::doSendLemma( Node lem ) { if( d_lemmas_produced_c.find( lem )==d_lemmas_produced_c.end() ){ + Trace("dt-lemma-send") << "TheoryDatatypes::doSendLemma : " << lem << std::endl; d_lemmas_produced_c[lem] = true; d_out->lemma( lem ); } @@ -401,6 +412,7 @@ void TheoryDatatypes::doSendLemma( Node lem ) { void TheoryDatatypes::assertFact( Node fact, Node exp ){ Assert( d_pending_merge.empty() ); + Trace("datatypes-debug") << "TheoryDatatypes::assertFact : " << fact << std::endl; bool polarity = fact.getKind() != kind::NOT; TNode atom = polarity ? fact : fact[0]; if (atom.getKind() == kind::EQUAL) { @@ -410,19 +422,25 @@ void TheoryDatatypes::assertFact( Node fact, Node exp ){ } doPendingMerges(); //add to tester if applicable - if( atom.getKind()==kind::APPLY_TESTER ){ - Node rep = getRepresentative( atom[0] ); + Node t_arg; + int tindex = DatatypesRewriter::isTester( atom, t_arg ); + if( tindex!=-1 ){ + Trace("dt-tester") << "Assert tester : " << atom << " for " << t_arg << std::endl; + Node rep = getRepresentative( t_arg ); EqcInfo* eqc = getOrMakeEqcInfo( rep, true ); - addTester( fact, eqc, rep ); + addTester( tindex, fact, eqc, rep, t_arg ); + Trace("dt-tester") << "Done assert tester." << std::endl; if( !d_conflict && polarity ){ - Trace("dt-tester") << "Assert tester : " << atom << std::endl; if( d_sygus_sym_break ){ //Assert( !d_sygus_util->d_conflict ); - d_sygus_sym_break->addTester( atom ); + Trace("dt-tester") << "Assert tester to sygus : " << atom << std::endl; + d_sygus_sym_break->addTester( tindex, t_arg, atom ); + Trace("dt-tester") << "Done assert tester to sygus." << std::endl; for( unsigned i=0; id_lemmas.size(); i++ ){ Trace("dt-lemma-sygus") << "Sygus symmetry breaking lemma : " << d_sygus_sym_break->d_lemmas[i] << std::endl; doSendLemma( d_sygus_sym_break->d_lemmas[i] ); } + Trace("dt-lemma-sygus") << "No lemmas" << std::endl; d_sygus_sym_break->d_lemmas.clear(); /* if( d_sygus_util->d_conflict ){ @@ -439,8 +457,11 @@ void TheoryDatatypes::assertFact( Node fact, Node exp ){ */ } } + }else{ + Trace("dt-tester-debug") << "Assert (non-tester) : " << atom << std::endl; } doPendingMerges(); + Trace("datatypes-debug") << "TheoryDatatypes::assertFact : finished " << fact << std::endl; } void TheoryDatatypes::preRegisterTerm(TNode n) { @@ -662,6 +683,7 @@ bool TheoryDatatypes::propagate(TNode literal){ // Propagate out bool ok = d_out->propagate(literal); if (!ok) { + Trace("dt-conflict") << "CONFLICT: Eq engine propagate conflict " << std::endl; d_conflict = true; } return ok; @@ -777,7 +799,7 @@ void TheoryDatatypes::eqNotifyPreMerge(TNode t1, TNode t2){ /** called when two equivalance classes have merged */ void TheoryDatatypes::eqNotifyPostMerge(TNode t1, TNode t2){ if( DatatypesRewriter::isTermDatatype( t1 ) ){ - Debug("datatypes-debug") << "NotifyPostMerge : " << t1 << " " << t2 << std::endl; + Trace("datatypes-debug") << "NotifyPostMerge : " << t1 << " " << t2 << std::endl; d_pending_merge.push_back( t1.eqNode( t2 ) ); } } @@ -786,7 +808,7 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ if( !d_conflict ){ TNode trep1 = t1; TNode trep2 = t2; - Debug("datatypes-debug") << "Merge " << t1 << " " << t2 << std::endl; + Trace("datatypes-debug") << "Merge " << t1 << " " << t2 << std::endl; EqcInfo* eqc2 = getOrMakeEqcInfo( t2 ); if( eqc2 ){ bool checkInst = false; @@ -795,7 +817,7 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ } EqcInfo* eqc1 = getOrMakeEqcInfo( t1 ); if( eqc1 ){ - Debug("datatypes-debug") << " merge eqc info " << eqc2 << " into " << eqc1 << std::endl; + Trace("datatypes-debug") << " merge eqc info " << eqc2 << " into " << eqc1 << std::endl; if( !eqc1->d_constructor.get().isNull() ){ trep1 = eqc1->d_constructor.get(); } @@ -804,7 +826,7 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ TNode cons2 = eqc2->d_constructor.get(); //if both have constructor, then either clash or unification if( !cons1.isNull() && !cons2.isNull() ){ - Debug("datatypes-debug") << " constructors : " << cons1 << " " << cons2 << std::endl; + Trace("datatypes-debug") << " constructors : " << cons1 << " " << cons2 << std::endl; Node unifEq = cons1.eqNode( cons2 ); /* std::vector< Node > exp; @@ -856,11 +878,11 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ */ } } - Debug("datatypes-debug") << " instantiated : " << eqc1->d_inst << " " << eqc2->d_inst << std::endl; + Trace("datatypes-debug") << " instantiated : " << eqc1->d_inst << " " << eqc2->d_inst << std::endl; eqc1->d_inst = eqc1->d_inst || eqc2->d_inst; if( !cons2.isNull() ){ if( cons1.isNull() ){ - Debug("datatypes-debug") << " must check if it is okay to set the constructor." << std::endl; + Trace("datatypes-debug") << " must check if it is okay to set the constructor." << std::endl; checkInst = true; addConstructor( eqc2->d_constructor.get(), eqc1, t1 ); if( d_conflict ){ @@ -871,7 +893,7 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ //d_consEqc[t2] = false; } }else{ - Debug("datatypes-debug") << " no eqc info for " << t1 << ", must create" << std::endl; + Trace("datatypes-debug") << " no eqc info for " << t1 << ", must create" << std::endl; //just copy the equivalence class information eqc1 = getOrMakeEqcInfo( t1, true ); eqc1->d_inst.set( eqc2->d_inst ); @@ -883,12 +905,16 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ //merge labels NodeListMap::iterator lbl_i = d_labels.find( t2 ); if( lbl_i != d_labels.end() ){ - Debug("datatypes-debug") << " merge labels from " << eqc2 << " " << t2 << std::endl; + Trace("datatypes-debug") << " merge labels from " << eqc2 << " " << t2 << std::endl; NodeList* lbl = (*lbl_i).second; for( NodeList::const_iterator j = lbl->begin(); j != lbl->end(); ++j ){ - addTester( *j, eqc1, t1 ); + Node tt = (*j).getKind()==kind::NOT ? (*j)[0] : (*j); + Node t_arg; + int tindex = DatatypesRewriter::isTester( tt, t_arg ); + Assert( tindex!=-1 ); + addTester( tindex, *j, eqc1, t1, t_arg ); if( d_conflict ){ - Debug("datatypes-debug") << " conflict!" << std::endl; + Trace("datatypes-debug") << " conflict!" << std::endl; return; } } @@ -900,21 +926,21 @@ void TheoryDatatypes::merge( Node t1, Node t2 ){ } NodeListMap::iterator sel_i = d_selector_apps.find( t2 ); if( sel_i != d_selector_apps.end() ){ - Debug("datatypes-debug") << " merge selectors from " << eqc2 << " " << t2 << std::endl; + Trace("datatypes-debug") << " merge selectors from " << eqc2 << " " << t2 << std::endl; NodeList* sel = (*sel_i).second; for( NodeList::const_iterator j = sel->begin(); j != sel->end(); ++j ){ addSelector( *j, eqc1, t1, eqc2->d_constructor.get().isNull() ); } } if( checkInst ){ - Debug("datatypes-debug") << " checking instantiate" << std::endl; + Trace("datatypes-debug") << " checking instantiate" << std::endl; instantiate( eqc1, t1 ); if( d_conflict ){ return; } } } - Debug("datatypes-debug") << "Finished Merge " << t1 << " " << t2 << std::endl; + Trace("datatypes-debug") << "Finished Merge " << t1 << " " << t2 << std::endl; } } @@ -936,7 +962,7 @@ Node TheoryDatatypes::getLabel( Node n ) { NodeListMap::iterator lbl_i = d_labels.find( n ); if( lbl_i != d_labels.end() ){ NodeList* lbl = (*lbl_i).second; - if( !(*lbl).empty() && (*lbl)[ (*lbl).size() - 1 ].getKind()==kind::APPLY_TESTER ){ + if( !(*lbl).empty() && (*lbl)[ (*lbl).size() - 1 ].getKind()!=kind::NOT ){ return (*lbl)[ (*lbl).size() - 1 ]; } } @@ -947,7 +973,15 @@ int TheoryDatatypes::getLabelIndex( EqcInfo* eqc, Node n ){ if( eqc && !eqc->d_constructor.get().isNull() ){ return Datatype::indexOf( eqc->d_constructor.get().getOperator().toExpr() ); }else{ - return Datatype::indexOf( getLabel( n ).getOperator().toExpr() ); + Node lbl = getLabel( n ); + if( lbl.isNull() ){ + return -1; + }else{ + int tindex = DatatypesRewriter::isTester( lbl ); + Assert( tindex!=-1 ); + return tindex; + //return Datatype::indexOf( getLabel( n ).getOperator().toExpr() ); + } } } @@ -962,16 +996,20 @@ bool TheoryDatatypes::hasTester( Node n ) { void TheoryDatatypes::getPossibleCons( EqcInfo* eqc, Node n, std::vector< bool >& pcons ){ const Datatype& dt = ((DatatypeType)(n.getType()).toType()).getDatatype(); - pcons.resize( dt.getNumConstructors(), !hasLabel( eqc, n ) ); - if( hasLabel( eqc, n ) ){ - pcons[ getLabelIndex( eqc, n ) ] = true; + int lindex = getLabelIndex( eqc, n ); + pcons.resize( dt.getNumConstructors(), lindex==-1 ); + if( lindex!=-1 ){ + pcons[ lindex ] = true; }else{ NodeListMap::iterator lbl_i = d_labels.find( n ); if( lbl_i != d_labels.end() ){ NodeList* lbl = (*lbl_i).second; for( NodeList::const_iterator i = lbl->begin(); i != lbl->end(); i++ ) { Assert( (*i).getKind()==NOT ); - pcons[ Datatype::indexOf( (*i)[0].getOperator().toExpr() ) ] = false; + //pcons[ Datatype::indexOf( (*i)[0].getOperator().toExpr() ) ] = false; + int tindex = DatatypesRewriter::isTester( (*i)[0] ); + Assert( tindex!=-1 ); + pcons[ tindex ] = false; } } } @@ -986,23 +1024,42 @@ void TheoryDatatypes::mkExpDefSkolem( Node sel, TypeNode dt, TypeNode rt ) { } } -void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){ - Debug("datatypes-debug") << "Add tester : " << t << " to eqc(" << n << ")" << std::endl; +Node TheoryDatatypes::getTermSkolemFor( Node n ) { + if( n.getKind()==APPLY_CONSTRUCTOR ){ + NodeMap::const_iterator it = d_term_sk.find( n ); + if( it==d_term_sk.end() ){ + //add purification unit lemma ( k = n ) + Node k = NodeManager::currentNM()->mkSkolem( "k", n.getType(), "reference skolem for datatypes" ); + d_term_sk[n] = k; + Node eq = k.eqNode( n ); + Trace("datatypes-infer") << "DtInfer : ref : " << eq << std::endl; + d_pending_lem.push_back( eq ); + //doSendLemma( eq ); + //d_pending_exp[ eq ] = d_true; + return k; + }else{ + return (*it).second; + } + }else{ + return n; + } +} + +void TheoryDatatypes::addTester( int ttindex, Node t, EqcInfo* eqc, Node n, Node t_arg ){ + Trace("datatypes-debug") << "Add tester : " << t << " to eqc(" << n << ")" << std::endl; Debug("datatypes-labels") << "Add tester " << t << " " << n << " " << eqc << std::endl; bool tpolarity = t.getKind()!=NOT; - Node tt = ( t.getKind() == NOT ) ? t[0] : t; - int ttindex = Datatype::indexOf( tt.getOperator().toExpr() ); Node j, jt; bool makeConflict = false; - if( hasLabel( eqc, n ) ){ + int jtindex0 = getLabelIndex( eqc, n ); + if( jtindex0!=-1 ){ //if we already know the constructor type, check whether it is in conflict or redundant - int jtindex = getLabelIndex( eqc, n ); - if( (jtindex==ttindex)!=tpolarity ){ + if( (jtindex0==ttindex)!=tpolarity ){ if( !eqc->d_constructor.get().isNull() ){ //conflict because equivalence class contains a constructor std::vector< TNode > assumptions; explain( t, assumptions ); - explainEquality( eqc->d_constructor.get(), tt[0], true, assumptions ); + explainEquality( eqc->d_constructor.get(), t_arg, true, assumptions ); d_conflictNode = mkAnd( assumptions ); Trace("dt-conflict") << "CONFLICT: Tester eq conflict : " << d_conflictNode << std::endl; d_out->conflict( d_conflictNode ); @@ -1026,7 +1083,9 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){ Assert( (*i).getKind()==NOT ); j = *i; jt = j[0]; - int jtindex = Datatype::indexOf( jt.getOperator().toExpr() ); + //int jtindex = Datatype::indexOf( jt.getOperator().toExpr() ); + int jtindex = DatatypesRewriter::isTester( jt ); + Assert( jtindex!=-1 ); if( jtindex==ttindex ){ if( tpolarity ){ //we are in conflict makeConflict = true; @@ -1039,7 +1098,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){ if( !makeConflict ){ Debug("datatypes-labels") << "Add to labels " << t << std::endl; lbl->push_back( t ); - const Datatype& dt = ((DatatypeType)(tt[0].getType()).toType()).getDatatype(); + const Datatype& dt = ((DatatypeType)(t_arg.getType()).toType()).getDatatype(); Debug("datatypes-labels") << "Labels at " << lbl->size() << " / " << dt.getNumConstructors() << std::endl; if( tpolarity ){ instantiate( eqc, n ); @@ -1051,7 +1110,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){ std::vector< bool > pcons; getPossibleCons( eqc, n, pcons ); int testerIndex = -1; - for( int i=0; i<(int)pcons.size(); i++ ) { + for( unsigned i=0; i nb(kind::AND); for( NodeList::const_iterator i = lbl->begin(); i != lbl->end(); i++ ) { nb << (*i); - if( std::find( eq_terms.begin(), eq_terms.end(), (*i)[0][0] )==eq_terms.end() ){ - eq_terms.push_back( (*i)[0][0] ); - if( (*i)[0][0]!=tt[0] ){ - nb << (*i)[0][0].eqNode( tt[0] ); + Assert( (*i).getKind()==NOT ); + Node t_arg2; + int tindex = DatatypesRewriter::isTester( (*i)[0], t_arg2 ); + Assert( tindex!=-1 ); + if( std::find( eq_terms.begin(), eq_terms.end(), t_arg2 )==eq_terms.end() ){ + eq_terms.push_back( t_arg2 ); + if( t_arg2!=t_arg ){ + nb << t_arg2.eqNode( t_arg ); } } } - Node t_concl = DatatypesRewriter::mkTester( tt[0], testerIndex, dt ); + Node t_concl = DatatypesRewriter::mkTester( t_arg, testerIndex, dt ); Node t_concl_exp = ( nb.getNumChildren() == 1 ) ? nb.getChild( 0 ) : nb; d_pending.push_back( t_concl ); d_pending_exp[ t_concl ] = t_concl_exp; @@ -1088,7 +1151,7 @@ void TheoryDatatypes::addTester( Node t, EqcInfo* eqc, Node n ){ std::vector< TNode > assumptions; explain( j, assumptions ); explain( t, assumptions ); - explainEquality( jt[0], tt[0], true, assumptions ); + explainEquality( jt[0], t_arg, true, assumptions ); d_conflictNode = mkAnd( assumptions ); Trace("dt-conflict") << "CONFLICT: Tester conflict : " << d_conflictNode << std::endl; d_out->conflict( d_conflictNode ); @@ -1120,7 +1183,7 @@ void TheoryDatatypes::addSelector( Node s, EqcInfo* eqc, Node n, bool assertFact } void TheoryDatatypes::addConstructor( Node c, EqcInfo* eqc, Node n ){ - Debug("datatypes-debug") << "Add constructor : " << c << " to eqc(" << n << ")" << std::endl; + Trace("datatypes-debug") << "Add constructor : " << c << " to eqc(" << n << ")" << std::endl; Assert( eqc->d_constructor.get().isNull() ); //check labels NodeListMap::iterator lbl_i = d_labels.find( n ); @@ -1129,7 +1192,9 @@ void TheoryDatatypes::addConstructor( Node c, EqcInfo* eqc, Node n ){ NodeList* lbl = (*lbl_i).second; for( NodeList::const_iterator i = lbl->begin(); i != lbl->end(); i++ ) { if( (*i).getKind()==NOT ){ - if( Datatype::indexOf( (*i)[0].getOperator().toExpr() )==constructorIndex ){ + int tindex = DatatypesRewriter::isTester( (*i)[0] ); + Assert( tindex!=-1 ); + if( tindex==(int)constructorIndex ){ Node n = *i; std::vector< TNode > assumptions; explain( *i, assumptions ); @@ -1160,6 +1225,15 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) { Trace("dt-collapse-sel") << "collapse selector : " << s << " " << c << std::endl; Node r; bool wrong = false; + Node use_s; + Node eq_exp; + if( options::dtRefIntro() ){ + eq_exp = d_true; + use_s = getTermSkolemFor( c ); + }else{ + eq_exp = c.eqNode( s[0] ); + use_s = s; + } if( s.getKind()==kind::APPLY_SELECTOR_TOTAL ){ //Trace("dt-collapse-sel") << "Indices : " << Datatype::indexOf(c.getOperator().toExpr()) << " " << Datatype::cindexOf(s.getOperator().toExpr()) << std::endl; wrong = Datatype::indexOf(c.getOperator().toExpr())!=Datatype::cindexOf(s.getOperator().toExpr()); @@ -1171,11 +1245,20 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) { // r = NodeManager::currentNM()->mkNode( kind::APPLY_UF, d_exp_def_skolem[s.getOperator().toExpr()], s[0] ); //}else{ r = NodeManager::currentNM()->mkNode( kind::APPLY_SELECTOR_TOTAL, s.getOperator(), c ); + if( options::dtRefIntro() ){ + use_s = NodeManager::currentNM()->mkNode( kind::APPLY_SELECTOR_TOTAL, s.getOperator(), use_s ); + } }else{ if( s.getKind()==DT_SIZE ){ r = NodeManager::currentNM()->mkNode( DT_SIZE, c ); + if( options::dtRefIntro() ){ + use_s = NodeManager::currentNM()->mkNode( DT_SIZE, use_s ); + } }else if( s.getKind()==DT_HEIGHT_BOUND ){ r = NodeManager::currentNM()->mkNode( DT_HEIGHT_BOUND, c, s[1] ); + if( options::dtRefIntro() ){ + use_s = NodeManager::currentNM()->mkNode( DT_HEIGHT_BOUND, use_s, s[1] ); + } if( r==d_true ){ return; } @@ -1183,13 +1266,17 @@ void TheoryDatatypes::collapseSelector( Node s, Node c ) { } if( !r.isNull() ){ Node rr = Rewriter::rewrite( r ); - if( s!=rr ){ - Node eq_exp = c.eqNode( s[0] ); - Node eq = rr.getType().isBoolean() ? s.iffNode( rr ) : s.eqNode( rr ); + if( use_s!=rr ){ + Node eq = rr.getType().isBoolean() ? use_s.iffNode( rr ) : use_s.eqNode( rr ); + Node eq_exp; + if( options::dtRefIntro() ){ + eq_exp = d_true; + }else{ + eq_exp = c.eqNode( s[0] ); + } Trace("datatypes-infer") << "DtInfer : collapse sel"; Trace("datatypes-infer") << ( wrong ? " wrong" : ""); Trace("datatypes-infer") << " : " << eq << " by " << eq_exp << std::endl; - d_pending.push_back( eq ); d_pending_exp[ eq ] = eq_exp; d_infer.push_back( eq ); @@ -1564,16 +1651,7 @@ Node TheoryDatatypes::getInstantiateCons( Node n, const Datatype& dt, int index return it->second; }else{ //add constructor to equivalence class - Node k = n; - if( n.getKind()==APPLY_CONSTRUCTOR ){ - //must construct variable to refer to n, add lemma immediately - k = NodeManager::currentNM()->mkSkolem( "k", n.getType(), "for dt instantiation" ); - Node eq = k.eqNode( n ); - Trace("datatypes-infer") << "DtInfer : instantiation ref : " << eq << std::endl; - //doSendLemma( eq ); - d_pending.push_back( eq ); - d_pending_exp[ eq ] = d_true; - } + Node k = getTermSkolemFor( n ); Node n_ic = DatatypesRewriter::getInstCons( k, dt, index ); //Assert( n_ic==Rewriter::rewrite( n_ic ) ); n_ic = Rewriter::rewrite( n_ic ); @@ -1587,34 +1665,38 @@ Node TheoryDatatypes::getInstantiateCons( Node n, const Datatype& dt, int index void TheoryDatatypes::instantiate( EqcInfo* eqc, Node n ){ //add constructor to equivalence class if not done so already - if( hasLabel( eqc, n ) && !eqc->d_inst ){ - Node exp; - Node tt; - if( !eqc->d_constructor.get().isNull() ){ - exp = d_true; - tt = eqc->d_constructor; + int index = getLabelIndex( eqc, n ); + if( index!=-1 && !eqc->d_inst ){ + if( options::dtUseTesters() ){ + Node exp; + Node tt; + if( !eqc->d_constructor.get().isNull() ){ + exp = d_true; + tt = eqc->d_constructor; + }else{ + exp = getLabel( n ); + tt = exp[0]; + } + const Datatype& dt = ((DatatypeType)(tt.getType()).toType()).getDatatype(); + //must be finite or have a selector + //if( eqc->d_selectors || dt[ index ].isFinite() ){ + //instantiate this equivalence class + eqc->d_inst = true; + Node tt_cons = getInstantiateCons( tt, dt, index ); + Node eq; + if( tt!=tt_cons ){ + eq = tt.eqNode( tt_cons ); + Debug("datatypes-inst") << "DtInstantiate : " << eqc << " " << eq << std::endl; + d_pending.push_back( eq ); + d_pending_exp[ eq ] = exp; + Trace("datatypes-infer-debug") << "inst : " << eqc << " " << n << std::endl; + Trace("datatypes-infer") << "DtInfer : instantiate : " << eq << " by " << exp << std::endl; + //eqc->d_inst.set( eq ); + d_infer.push_back( eq ); + d_infer_exp.push_back( exp ); + } }else{ - exp = getLabel( n ); - tt = exp[0]; - } - int index = getLabelIndex( eqc, n ); - const Datatype& dt = ((DatatypeType)(tt.getType()).toType()).getDatatype(); - //must be finite or have a selector - //if( eqc->d_selectors || dt[ index ].isFinite() ){ - //instantiate this equivalence class - eqc->d_inst = true; - Node tt_cons = getInstantiateCons( tt, dt, index ); - Node eq; - if( tt!=tt_cons ){ - eq = tt.eqNode( tt_cons ); - Debug("datatypes-inst") << "DtInstantiate : " << eqc << " " << eq << std::endl; - d_pending.push_back( eq ); - d_pending_exp[ eq ] = exp; - Trace("datatypes-infer-debug") << "inst : " << eqc << " " << n << std::endl; - Trace("datatypes-infer") << "DtInfer : instantiate : " << eq << " by " << exp << std::endl; - //eqc->d_inst.set( eq ); - d_infer.push_back( eq ); - d_infer_exp.push_back( exp ); + eqc->d_inst = true; } //} //else{ diff --git a/src/theory/datatypes/theory_datatypes.h b/src/theory/datatypes/theory_datatypes.h index 82306b863..4dd621c86 100644 --- a/src/theory/datatypes/theory_datatypes.h +++ b/src/theory/datatypes/theory_datatypes.h @@ -39,6 +39,7 @@ private: typedef context::CDChunkList NodeList; typedef context::CDHashMap NodeListMap; typedef context::CDHashMap< Node, bool, NodeHashFunction > BoolMap; + typedef context::CDHashMap< Node, Node, NodeHashFunction > NodeMap; /** transitive closure to record equivalence/subterm relation. */ //TransitiveClosureNode d_cycle_check; @@ -131,7 +132,10 @@ private: /** get the possible constructors for n */ void getPossibleCons( EqcInfo* eqc, Node n, std::vector< bool >& cons ); /** mkExpDefSkolem */ - void mkExpDefSkolem( Node sel, TypeNode dt, TypeNode rt ); + void mkExpDefSkolem( Node sel, TypeNode dt, TypeNode rt ); + /** skolems for terms */ + NodeMap d_term_sk; + Node getTermSkolemFor( Node n ); private: /** The notify class */ NotifyClass d_notify; @@ -166,6 +170,7 @@ private: /** cache for which terms we have called collectTerms(...) on */ BoolMap d_collectTermsCache; /** pending assertions/merges */ + std::vector< Node > d_pending_lem; std::vector< Node > d_pending; std::map< Node, Node > d_pending_exp; std::vector< Node > d_pending_merge; @@ -257,7 +262,7 @@ public: void printModelDebug( const char* c ); private: /** add tester to equivalence class info */ - void addTester( Node t, EqcInfo* eqc, Node n ); + void addTester( int ttindex, Node t, EqcInfo* eqc, Node n, Node t_arg ); /** add selector to equivalence class info */ void addSelector( Node s, EqcInfo* eqc, Node n, bool assertFacts = true ); /** add constructor */ diff --git a/src/theory/quantifiers_engine.cpp b/src/theory/quantifiers_engine.cpp index e7a87927a..aee3ac4b8 100644 --- a/src/theory/quantifiers_engine.cpp +++ b/src/theory/quantifiers_engine.cpp @@ -929,6 +929,9 @@ bool QuantifiersEngine::addInstantiation( Node q, std::vector< Node >& terms, bo } Trace("inst-add-debug") << " -> " << terms[i] << std::endl; Assert( !terms[i].isNull() ); +#ifdef CVC4_ASSERTIONS + Assert( !quantifiers::TermDb::containsUninterpretedConstant( terms[i] ) ); +#endif } //check based on instantiation level -- 2.30.2