From 45b0ba984fde882d3cd762076de0f9ddce2485c7 Mon Sep 17 00:00:00 2001 From: ajreynol Date: Thu, 5 Feb 2015 14:50:09 +0100 Subject: [PATCH] Working version of sygus solution reconstruction from single inv cegqi. Heuristics to fit syntax. --- .../quantifiers/ce_guided_single_inv.cpp | 376 +++++++++++++----- src/theory/quantifiers/ce_guided_single_inv.h | 17 +- src/theory/quantifiers/term_database.cpp | 82 ++++ src/theory/quantifiers/term_database.h | 8 + 4 files changed, 373 insertions(+), 110 deletions(-) diff --git a/src/theory/quantifiers/ce_guided_single_inv.cpp b/src/theory/quantifiers/ce_guided_single_inv.cpp index e06c384c3..1dd60e583 100644 --- a/src/theory/quantifiers/ce_guided_single_inv.cpp +++ b/src/theory/quantifiers/ce_guided_single_inv.cpp @@ -541,6 +541,9 @@ void CegConjectureSingleInv::check( QuantifiersEngine * qe, std::vector< Node >& }else if( const_n.isNull() ){ const_n = n; } + //TODO : fairness + if( !slv_n.isNull() && !const_n.isNull() ){ + } } } ++eqc_i; @@ -717,8 +720,14 @@ Node CegConjectureSingleInv::getSolution( QuantifiersEngine * qe, unsigned i, Ty s = Rewriter::rewrite( s ); Trace("cegqi-si-debug-sol") << "Solution (post-rewrite): " << s << std::endl; d_solution = s; + reconstructed = 0; if( options::cegqiSingleInvReconstruct() && !stn.isNull() ){ - collectReconstructNodes( qe->getTermDatabaseSygus(), d_solution, stn, Node::null(), TypeNode::null(), false ); + int status; + d_templ_solution = collectReconstructNodes( qe->getTermDatabaseSygus(), d_solution, stn, status ); + if( status==1 ){ + setNeedsReconstruction( d_templ_solution, stn, Node::null(), TypeNode::null() ); + } + Trace("cegqi-si-debug-sol") << "Induced solution template is : " << d_templ_solution << std::endl; std::vector< TypeNode > rcons_tn; for( std::map< TypeNode, std::map< Node, bool > >::iterator it = d_rcons_to_process.begin(); it != d_rcons_to_process.end(); ++it ){ TypeNode tn = it->first; @@ -739,22 +748,39 @@ Node CegConjectureSingleInv::getSolution( QuantifiersEngine * qe, unsigned i, Ty success = true; std::vector< TypeNode > to_erase; for( std::map< TypeNode, std::map< Node, bool > >::iterator it = d_rcons_to_process.begin(); it != d_rcons_to_process.end(); ++it ){ + TypeNode stn = it->first; if( it->second.empty() ){ - to_erase.push_back( it->first ); + to_erase.push_back( stn ); }else{ - Node ns = qe->getTermDatabase()->getEnumerateTerm( it->first, index ); + Node ns = qe->getTermDatabase()->getEnumerateTerm( stn, index ); if( ns.isNull() ){ - to_erase.push_back( it->first ); - incomplete.push_back( it->first ); + to_erase.push_back( stn ); + incomplete.push_back( stn ); }else{ - Node nb = qe->getTermDatabaseSygus()->sygusToBuiltin( ns, it->first ); - Node nr = Rewriter::rewrite( nb );//qe->getTermDatabaseSygus()->getNormalized( it->first, nb, false, false ); - Trace("cegqi-si-rcons-debug2") << " - try " << ns << " -> " << nr << " for " << it->first << " " << nr.getKind() << std::endl; + Node nb = qe->getTermDatabaseSygus()->sygusToBuiltin( ns, stn ); + Node nr = Rewriter::rewrite( nb );//qe->getTermDatabaseSygus()->getNormalized( stn, nb, false, false ); + Trace("cegqi-si-rcons-debug2") << " - try " << ns << " -> " << nr << " for " << stn << " " << nr.getKind() << std::endl; if( it->second.find( nr )!=it->second.end() ){ Trace("cegqi-si-rcons") << "...reconstructed " << ns << " for term " << nr << std::endl; - d_reconstructed[nr][it->first] = ns; - d_reconstructed_op[nr][it->first] = false; - setReconstructed( nr, it->first ); + d_reconstructed[nr][stn] = ns; + d_reconstructed_op[nr][stn] = false; + Assert( d_rewrite_to_rcons.find( nr )!=d_rewrite_to_rcons.end() ); + Node nrr = d_rewrite_to_rcons[nr]; + setReconstructed( nrr, stn ); + }else{ + /* + // look if it has a kind of a term that we need to reconstruct TODO + Kind nrk = nr.getKind(); + std::map< Kind, std::map< Node, bool > >::iterator itk = d_rcons_kinds_to_process[stn].find( nrk ); + if( itk!=d_rcons_kinds_to_process[stn].end() ){ + Trace("cegqi-si-rcons") << "Term " << ns << " -> " << nr << " did not match, but has same operator " << nrk; + Trace("cegqi-si-rcons") << " as " << itk->second.size() << " waiting terms." << std::endl; + Assert( !itk->second.empty() ); + for( std::map< Node, bool >::iterator itkn = itk->second.begin(); itkn != itk->second.end(); ++itkn ){ + + } + } + */ } success = false; } @@ -769,11 +795,14 @@ Node CegConjectureSingleInv::getSolution( QuantifiersEngine * qe, unsigned i, Ty Trace("cegqi-si-rcons-stats") << "Tried " << index << " for each type." << std::endl; } if( success ){ - reconstructed = incomplete.empty() ? 1 : 0; + if( incomplete.empty() ){ + reconstructed = 1; + Trace("cegqi-si-debug-sol") << "Reconstructing sygus solution..." << std::endl; + d_sygus_solution = getReconstructedSolution( qe->getTermDatabaseSygus(), stn, d_templ_solution ); + Trace("cegqi-si-debug-sol") << "Sygus solution is : " << d_sygus_solution << std::endl; + } } }while( !success ); - }else{ - reconstructed = 0; } if( Trace.isOn("cegqi-si-debug-sol") ){ //debug solution @@ -794,9 +823,18 @@ Node CegConjectureSingleInv::getSolution( QuantifiersEngine * qe, unsigned i, Ty num_ite = 0; debugTermSize( d_solution, t_size, num_ite ); //Trace("cegqi-stats") << "simplified size " << t_size << " #ite " << num_ite << std::endl; + Trace("cegqi-stats") << t_size << " " << num_ite << " "; + t_size = 0; + num_ite = 0; + debugTermSize( d_templ_solution, t_size, num_ite ); + //Trace("cegqi-stats") << "sygus size " << t_size << " #ite " << num_ite << std::endl; Trace("cegqi-stats") << t_size << " " << num_ite << std::endl; } - return d_solution; + if( reconstructed==1 ){ + return d_sygus_solution; + }else{ + return d_solution; + } } bool CegConjectureSingleInv::debugSolution( Node sol ) { @@ -1120,11 +1158,13 @@ Node CegConjectureSingleInv::simplifySolution( QuantifiersEngine * qe, Node sol, std::map< Node, bool >::iterator it = atoms.find( atom ); if( it==atoms.end() ){ atoms[atom] = pol; - if( status==0 && ( atom.getKind()==IFF || atom.getKind()==EQUAL ) && ( pol==( sol.getKind()==AND ) ) ){ - Trace("csi-simp") << " ...equality." << std::endl; - if( getAssignEquality( qe, atom, vars, new_vars, new_subs, args ) ){ - children.push_back( sol[i] ); - do_exc = true; + if( status==0 && ( atom.getKind()==IFF || atom.getKind()==EQUAL ) ){ + if( pol==( sol.getKind()==AND ) ){ + Trace("csi-simp") << " ...equality." << std::endl; + if( getAssignEquality( qe, atom, vars, new_vars, new_subs, args ) ){ + children.push_back( sol[i] ); + do_exc = true; + } } } }else{ @@ -1181,7 +1221,39 @@ Node CegConjectureSingleInv::simplifySolution( QuantifiersEngine * qe, Node sol, } } } - return children.size()==0 ? NodeManager::currentNM()->mkConst( sol.getKind()==AND ) : ( children.size()==1 ? children[0] : NodeManager::currentNM()->mkNode( sol.getKind(), children ) ); + // now, remove all equalities that are implied + std::vector< Node > final_children; + for( unsigned i=0; i tmp_vars; + std::vector< Node > tmp_subs; + if( getAssignEquality( qe, atom, vars, tmp_vars, tmp_subs, args ) ){ + Trace("csi-simp-debug") << "Check if " << children[i] << " is redundant in " << sol << std::endl; + for( unsigned j=0; ji || std::find( final_children.begin(), final_children.end(), children[j] )!=final_children.end() ) ){ + Node sj = children[j].substitute( tmp_vars.begin(), tmp_vars.end(), tmp_subs.begin(), tmp_subs.end() ); + sj = Rewriter::rewrite( sj ); + if( sj==qe->getTermDatabase()->d_true ){ + Trace("csi-simp") << "--- " << children[i].negate() << " is implied by " << children[j].negate() << std::endl; + red = true; + break; + } + } + } + } + } + } + if( !red ){ + final_children.push_back( children[i] ); + } + } + + return final_children.size()==0 ? NodeManager::currentNM()->mkConst( sol.getKind()==AND ) : + ( final_children.size()==1 ? final_children[0] : NodeManager::currentNM()->mkNode( sol.getKind(), final_children ) ); }else{ //generic simplification std::vector< Node > children; @@ -1202,54 +1274,8 @@ Node CegConjectureSingleInv::simplifySolution( QuantifiersEngine * qe, Node sol, } } -Node CegConjectureSingleInv::getSolutionTemplate( TermDbSygus * tds, Node n, TypeNode stn, Node parent, int arg ) { - Assert( datatypes::DatatypesRewriter::isTypeDatatype( stn ) ); - const Datatype& dt = ((DatatypeType)(stn).toType()).getDatatype(); - Assert( dt.isSygus() ); - tds->registerSygusType( stn ); - int karg = tds->getKindArg( stn, n.getKind() ); - if( karg!=-1 ){ - std::vector< Node > children; - if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){ - children.push_back( n.getOperator() ); - } - if( n.getNumChildren()!=dt[karg].getNumArgs() ){ - if( n.getNumChildren()>dt[karg].getNumArgs() && tds->isAssoc( n.getKind() ) && dt[karg].getNumArgs()==2 ){ - // make n-ary applications into binary ones - Node n1 = getSolutionTemplate( tds, n[0], tds->getArgType( dt[karg], 0 ), n, 0 ); - for( unsigned i=1; imkNode( n.getKind(), children ); - n2 = getSolutionTemplate( tds, n2, tds->getArgType( dt[karg], 1 ), Node::null(), -1 ); - return NodeManager::currentNM()->mkNode( n.getKind(), n1, n2 ); - } - }else{ - bool childChanged = false; - for( unsigned i=0; igetArgType( dt[karg], i ), n, i ); - children.push_back( nc ); - childChanged = childChanged || nc!=n[i]; - } - if( !childChanged || n.getNumChildren()==0 ){ - return n; - }else{ - return NodeManager::currentNM()->mkNode( n.getKind(), children ); - } - } - }else{ - int carg = tds->getOpArg( stn, n ); - if( carg!=-1 ){ - return n; - }else if( n.isConst() ){ - // check if constant exists in grammar TODO - - } - } - return n; -} - -void CegConjectureSingleInv::collectReconstructNodes( TermDbSygus * tds, Node t, TypeNode stn, Node parent, TypeNode pstn, bool ignoreBoolean ) { +Node CegConjectureSingleInv::collectReconstructNodes( TermDbSygus * tds, Node t, TypeNode stn, int& status ) { + /* if( ignoreBoolean && t.getType().isBoolean() ){ if( t.getKind()==OR || t.getKind()==AND || t.getKind()==IFF || t.getKind()==ITE || t.getKind()==NOT ){ //FIXME for( unsigned i=0; i::iterator it = d_rcons_processed[t].find( stn ); + if( it==d_rcons_processed[t].end() ){ TypeNode tn = t.getType(); - d_rcons_processed[t][stn][parent].push_back( pstn ); Assert( datatypes::DatatypesRewriter::isTypeDatatype( stn ) ); const Datatype& dt = ((DatatypeType)(stn).toType()).getDatatype(); Assert( dt.isSygus() ); Trace("cegqi-si-rcons-debug") << "Check reconstruct " << t << " type " << tn << ", sygus type " << dt.getName() << std::endl; tds->registerSygusType( stn ); - int arg = tds->getKindArg( stn, t.getKind() ); - bool processed = false; - if( arg!=-1 ){ - if( t.getNumChildren()==dt[arg].getNumArgs() ){ + Node ret; + std::vector< Node > children; + if( t.getMetaKind() == kind::metakind::PARAMETERIZED ){ + children.push_back( t.getOperator() ); + } + bool childChanged = false; + std::vector< Node > rcons_child; + std::vector< TypeNode > rcons_child_tn; + Node rcons; + bool rcons_op; + bool rcons_set = false; + Kind tk = t.getKind(); + int karg = tds->getKindArg( stn, tk ); + //preprocessing to fit syntax + Node orig_t = t; + if( karg==-1 && t.getNumChildren()>0 ){ + Node new_t; + Kind dk; + if( tds->isAntisymmetric( tk, dk ) ){ + if( tds->hasKind( stn, dk ) ){ + new_t = NodeManager::currentNM()->mkNode( dk, t[1], t[0] ); + } + } + if( new_t.isNull() ){ + for( unsigned i=0; igetGenericBase( stn, i ); + if( g.getKind()==t.getKind() ){ + Trace("cegqi-si-rcons-debug") << "Possible match ? " << g << " " << t << " for " << dt[i].getName() << std::endl; + std::map< int, Node > sigma; + if( tds->getMatch( g, t, sigma ) ){ + //we found an exact match + bool success = true; + for( unsigned j=0; j var_count; + new_t = tds->mkGeneric( dt, i, var_count, sigma ); + Trace("cegqi-si-rcons-debug") << "Rewrote to : " << new_t << std::endl; + break; + } + } + } + } + } + if( !new_t.isNull() ){ + t = new_t; + } + }else{ + //flatten ITEs if necessary + if( t.getKind()==ITE ){ + TypeNode cstn = tds->getArgType( dt[karg], 0 ); + tds->registerSygusType( cstn ); + if( !tds->hasKind( cstn, t[0].getKind() ) ){ + t = flattenITEs( t, false ); + Assert( t.getKind()==ITE ); + } + } + } + if( t!=orig_t ){ + karg = tds->getKindArg( stn, t.getKind() ); + } + if( karg!=-1 ){ + if( t.getNumChildren()==dt[karg].getNumArgs() ){ Trace("cegqi-si-rcons-debug") << " Type has kind " << t.getKind() << ", recurse." << std::endl; for( unsigned i=0; igetArgType( dt[arg], i ); - collectReconstructNodes( tds, t[i], stnc, t, stn, ignB ); + TypeNode cstn = tds->getArgType( dt[karg], i ); + int status; + Node tc = collectReconstructNodes( tds, t[i], cstn, status ); + if( status==1 ){ + rcons_child.push_back( tc ); + rcons_child_tn.push_back( cstn ); + } + children.push_back( tc ); + childChanged = childChanged || tc!=t[i]; } - d_reconstructed[t][stn] = Node::fromExpr( dt[arg].getSygusOp() ); - d_reconstructed_op[t][stn] = true; - processed = true; + rcons = Node::fromExpr( dt[karg].getConstructor() ); + rcons_op = true; + rcons_set = true; }else{ - Trace("cegqi-si-rcons-debug") << " Type has kind " << t.getKind() << ", but argument mismatch, with parent " << parent << std::endl; + Trace("cegqi-si-rcons-debug") << " Type has kind " << t.getKind() << ", but argument mismatch " << std::endl; + if( t.getNumChildren()>dt[karg].getNumArgs() && tds->isAssoc( t.getKind() ) && dt[karg].getNumArgs()==2 ){ + // make n-ary applications into binary ones + TypeNode cstn = tds->getArgType( dt[karg], 0 ); + int status; + Node t1 = collectReconstructNodes( tds, t[0], cstn, status ); + children.push_back( t1 ); + if( status==1 ){ + rcons_child.push_back( t1 ); + rcons_child_tn.push_back( cstn ); + } + std::vector< Node > rem_children; + for( unsigned i=1; imkNode( t.getKind(), rem_children ); + cstn = tds->getArgType( dt[karg], 1 ); + t2 = collectReconstructNodes( tds, t2, cstn, status ); + children.push_back( t2 ); + if( status==1 ){ + rcons_child.push_back( t2 ); + rcons_child_tn.push_back( cstn ); + } + childChanged = true; + rcons = Node::fromExpr( dt[karg].getConstructor() ); + rcons_op = true; + rcons_set = true; + } } } - if( !processed ){ + if( !rcons_set ){ int carg = tds->getOpArg( stn, t ); if( carg==-1 ){ - Trace("cegqi-si-rcons") << "...Reconstruction needed for " << t << " sygus type " << dt.getName() << " with parent " << parent << std::endl; + if( t.isConst() ){ + ret = tds->builtinToSygusConst( t, stn ); + } + if( ret.isNull() ){ + Trace("cegqi-si-rcons") << "...Reconstruction needed for " << t << " sygus type " << dt.getName() << std::endl; + } }else{ - d_reconstructed[t][stn] = Node::fromExpr( dt[carg].getSygusOp() ); - d_reconstructed_op[t][stn] = false; - processed = true; + rcons = NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, Node::fromExpr( dt[carg].getConstructor() ) ); + rcons_op = false; + rcons_set = true; Trace("cegqi-si-rcons-debug") << " Type has constant." << std::endl; } } - //add to parent if necessary - if( !processed || !d_rcons_graph[0][t][stn].empty() ){ - d_rcons_graph[0][parent][pstn][t][stn] = true; - if( !parent.isNull() ){ - d_rcons_to_process[pstn][parent] = true; + if( ret.isNull() ){ + if( !childChanged ){ + ret = t; + }else{ + Trace("cegqi-si-rcons-debug") << "Make return node " << t.getKind() << " with " << children.size() << " children." << std::endl; + ret = NodeManager::currentNM()->mkNode( t.getKind(), children ); } - d_rcons_graph[1][t][stn][parent][pstn] = true; - d_rcons_to_process[stn][t] = true; } + // now, construct d_rcons_graph + for( unsigned i=0; isecond; } } +void CegConjectureSingleInv::setNeedsReconstruction( Node t, TypeNode stn, Node parent, TypeNode pstn ) { + Trace("cegqi-si-rcons-debug") << "Set reconstruction for " << t << " " << stn << " " << parent << " " << pstn << std::endl; + d_rcons_graph[0][parent][pstn][t][stn] = true; + if( !parent.isNull() ){ + Node parentr = Rewriter::rewrite( parent ); + d_rewrite_to_rcons[parentr] = parent; + d_rcons_to_rewrite[parent] = parentr; + d_rcons_to_process[pstn][parentr] = true; + } + d_rcons_graph[1][t][stn][parent][pstn] = true; + Node tr = Rewriter::rewrite( t ); + d_rewrite_to_rcons[tr] = t; + d_rcons_to_rewrite[t] = tr; + d_rcons_to_process[stn][tr] = true; +} + void CegConjectureSingleInv::setReconstructed( Node t, TypeNode stn ) { + Assert( !t.isNull() ); if( Trace.isOn("cegqi-si-rcons-debug") ){ const Datatype& dt = ((DatatypeType)(stn).toType()).getDatatype(); Trace("cegqi-si-rcons-debug") << "set rcons : " << t << " in syntax " << dt.getName() << std::endl; @@ -1317,19 +1478,20 @@ void CegConjectureSingleInv::setReconstructed( Node t, TypeNode stn ) { for( unsigned r=0; r<2; r++){ unsigned ro = r==0 ? 1 : 0; for( std::map< Node, std::map< TypeNode, bool > >::iterator it = d_rcons_graph[r][t][stn].begin(); it != d_rcons_graph[r][t][stn].end(); ++it ){ + Node curr = it->first; TypeNode stnc; for( std::map< TypeNode, bool >::iterator it2 = it->second.begin(); it2 != it->second.end(); ++it2 ){ stnc = it2->first; - d_rcons_graph[ro][it->first][stnc][t].erase( stn ); - if( d_rcons_graph[ro][it->first][stnc][t].empty() ){ - d_rcons_graph[ro][it->first][stnc].erase( t ); + d_rcons_graph[ro][curr][stnc][t].erase( stn ); + if( d_rcons_graph[ro][curr][stnc][t].empty() ){ + d_rcons_graph[ro][curr][stnc].erase( t ); }else{ - Trace("cegqi-si-rcons-debug") << " " << ( r==0 ? "child" : "parent" ) << " " << it->first << " now has " << d_rcons_graph[ro][it->first][stnc][t].size() << std::endl; + Trace("cegqi-si-rcons-debug") << " " << ( r==0 ? "child" : "parent" ) << " " << curr << " now has " << d_rcons_graph[ro][curr][stnc][t].size() << std::endl; } } - if( d_rcons_graph[ro][it->first][stnc].empty() ){ - if( !it->first.isNull() ){ - to_set.push_back( it->first ); + if( d_rcons_graph[ro][curr][stnc].empty() ){ + if( !curr.isNull() ){ + to_set.push_back( curr ); to_set_tn.push_back( stnc ); } } @@ -1338,20 +1500,26 @@ void CegConjectureSingleInv::setReconstructed( Node t, TypeNode stn ) { for( unsigned r=0; r<2; r++){ d_rcons_graph[r].erase( t ); } - d_rcons_to_process[stn].erase( t ); + d_rcons_to_process[stn].erase( d_rcons_to_rewrite[t] ); for( unsigned i=0; i::iterator it = d_reconstructed[t].find( stn ); if( it!=d_reconstructed[t].end() ){ if( d_reconstructed_op[t][stn] ){ + Assert( datatypes::DatatypesRewriter::isTypeDatatype( stn ) ); + const Datatype& dt = ((DatatypeType)(stn).toType()).getDatatype(); + Assert( dt.isSygus() ); std::vector< Node > children; children.push_back( it->second ); + int c = tds->getKindArg( stn, t.getKind() ); + Assert( c!=-1 ); for( unsigned i=0; igetArgType( dt[c], i ); + Node nc = getReconstructedSolution( tds, stnc, t[i] ); children.push_back( nc ); } return NodeManager::currentNM()->mkNode( APPLY_CONSTRUCTOR, children ); diff --git a/src/theory/quantifiers/ce_guided_single_inv.h b/src/theory/quantifiers/ce_guided_single_inv.h index 1aba90468..44a8ed6e4 100644 --- a/src/theory/quantifiers/ce_guided_single_inv.h +++ b/src/theory/quantifiers/ce_guided_single_inv.h @@ -68,6 +68,8 @@ public: // solution Node d_orig_solution; Node d_solution; + Node d_templ_solution; + Node d_sygus_solution; public: //get the single invocation lemma Node getSingleInvLemma( Node guard ); @@ -93,18 +95,21 @@ private: bool getAssignEquality( QuantifiersEngine * qe, Node eq, std::vector< Node >& vars, std::vector< Node >& new_vars, std::vector< Node >& new_subs, std::vector< Node >& args ); //solution reconstruction private: - std::map< Node, std::map< TypeNode, std::map< Node, std::vector< TypeNode > > > > d_rcons_processed; + std::map< Node, std::map< TypeNode, Node > > d_rcons_processed; + std::map< Node, std::map< TypeNode, int > > d_rcons_processed_status; std::map< Node, std::map< TypeNode, Node > > d_reconstructed; std::map< Node, std::map< TypeNode, bool > > d_reconstructed_op; std::map< Node, std::map< TypeNode, std::map< Node, std::map< TypeNode, bool > > > > d_rcons_graph[2]; std::map< TypeNode, std::map< Node, bool > > d_rcons_to_process; - // term t with sygus type st - Node getSolutionTemplate( TermDbSygus * tds, Node n, TypeNode stn, Node parent, int arg ); - void collectReconstructNodes( TermDbSygus * tds, Node t, TypeNode stn, Node parent, TypeNode pstn, bool ignoreBoolean ); + std::map< Node, Node > d_rewrite_to_rcons; + std::map< Node, Node > d_rcons_to_rewrite; + // term t with sygus type st, returns inducted templated form of t + Node collectReconstructNodes( TermDbSygus * tds, Node t, TypeNode stn, int& status ); // set reconstructed - void setReconstructed( Node t, TypeNode stn ); + void setNeedsReconstruction( Node t, TypeNode stn, Node parent, TypeNode pstn ); + void setReconstructed( Node tr, TypeNode stn ); // get solution - Node getReconstructedSolution( TypeNode stn, Node t ); + Node getReconstructedSolution( TermDbSygus * tds, TypeNode stn, Node t ); }; } diff --git a/src/theory/quantifiers/term_database.cpp b/src/theory/quantifiers/term_database.cpp index c5a3cec4d..a7385f027 100644 --- a/src/theory/quantifiers/term_database.cpp +++ b/src/theory/quantifiers/term_database.cpp @@ -1339,6 +1339,7 @@ TNode TermDbSygus::getVar( TypeNode tn, int i ) { Assert( !vtn.isNull() ); Node v = NodeManager::currentNM()->mkSkolem( ss.str(), vtn, "for sygus normal form testing" ); d_fv_stype[v] = tn; + d_fv_num[v] = i; d_fv[tn].push_back( v ); } return d_fv[tn][i]; @@ -1361,6 +1362,48 @@ TypeNode TermDbSygus::getSygusType( Node v ) { return d_fv_stype[v]; } +bool TermDbSygus::getMatch( Node p, Node n, std::map< int, Node >& s ) { + std::vector< int > new_s; + return getMatch2( p, n, s, new_s ); +} + +bool TermDbSygus::getMatch2( Node p, Node n, std::map< int, Node >& s, std::vector< int >& new_s ) { + std::map< Node, int >::iterator it = d_fv_num.find( p ); + if( it!=d_fv_num.end() ){ + Node prev = s[it->second]; + s[it->second] = n; + if( prev.isNull() ){ + new_s.push_back( it->second ); + } + return prev.isNull() || prev==n; + }else if( n.getNumChildren()==0 ){ + return p==n; + }else if( n.getKind()==p.getKind() && n.getNumChildren()==p.getNumChildren() ){ + //try both ways? + unsigned rmax = isComm( n.getKind() ) && n.getNumChildren()==2 ? 2 : 1; + std::vector< int > new_tmp; + for( unsigned r=0; r& var_count, std::map< int, Node >& pre ) { Assert( c>=0 && c<(int)dt.getNumConstructors() ); Assert( dt.isSygus() ); @@ -1400,6 +1443,24 @@ Node TermDbSygus::mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int } } +Node TermDbSygus::getGenericBase( TypeNode tn, int c ) { + std::map< int, Node >::iterator it = d_generic_base[tn].find( c ); + if( it==d_generic_base[tn].end() ){ + Assert( datatypes::DatatypesRewriter::isTypeDatatype( tn ) ); + const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + Assert( dt.isSygus() ); + registerSygusType( tn ); + std::map< TypeNode, int > var_count; + std::map< int, Node > pre; + Node g = mkGeneric( dt, c, var_count, pre ); + Node gr = Rewriter::rewrite( g ); + d_generic_base[tn][c] = gr; + return gr; + }else{ + return it->second; + } +} + Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) { std::map< Node, Node >::iterator it = d_sygus_to_builtin[tn].find( n ); if( it==d_sygus_to_builtin[tn].end() ){ @@ -1420,6 +1481,27 @@ Node TermDbSygus::sygusToBuiltin( Node n, TypeNode tn ) { } } +Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn ) { + std::map< Node, Node >::iterator it = d_builtin_const_to_sygus[tn].find( c ); + if( it==d_builtin_const_to_sygus[tn].end() ){ + Assert( c.isConst() ); + Assert( datatypes::DatatypesRewriter::isTypeDatatype( tn ) ); + const Datatype& dt = ((DatatypeType)(tn).toType()).getDatatype(); + Assert( dt.isSygus() ); + Node sc; + int carg = getOpArg( tn, c ); + if( carg!=-1 ){ + sc = Node::fromExpr( dt[carg].getSygusOp() ); + }else{ + //TODO + } + d_builtin_const_to_sygus[tn][c] = sc; + return sc; + }else{ + return it->second; + } +} + Node TermDbSygus::getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ) { return n; if( n.getKind()==SKOLEM ){ diff --git a/src/theory/quantifiers/term_database.h b/src/theory/quantifiers/term_database.h index 0f730929c..5a3419cee 100644 --- a/src/theory/quantifiers/term_database.h +++ b/src/theory/quantifiers/term_database.h @@ -368,10 +368,14 @@ class TermDbSygus { private: std::map< TypeNode, std::vector< Node > > d_fv; std::map< Node, TypeNode > d_fv_stype; + std::map< Node, int > d_fv_num; public: TNode getVar( TypeNode tn, int i ); TNode getVarInc( TypeNode tn, std::map< TypeNode, int >& var_count ); bool isVar( Node n ) { return d_fv_stype.find( n )!=d_fv_stype.end(); } + int getVarNum( Node n ) { return d_fv_num[n]; } + bool getMatch( Node p, Node n, std::map< int, Node >& s ); + bool getMatch2( Node p, Node n, std::map< int, Node >& s, std::vector< int >& new_s ); private: //information for sygus types std::map< TypeNode, TypeNode > d_register; //stores sygus type @@ -389,6 +393,8 @@ private: //normalized map std::map< TypeNode, std::map< Node, Node > > d_normalized; std::map< TypeNode, std::map< Node, Node > > d_sygus_to_builtin; + std::map< TypeNode, std::map< Node, Node > > d_builtin_const_to_sygus; + std::map< TypeNode, std::map< int, Node > > d_generic_base; public: TermDbSygus(){} bool isRegistered( TypeNode tn ); @@ -426,7 +432,9 @@ public: Node getTypeMaxValue( TypeNode tn ); TypeNode getSygusType( Node v ); Node mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ); + Node getGenericBase( TypeNode tn, int c ); Node sygusToBuiltin( Node n, TypeNode tn ); + Node builtinToSygusConst( Node c, TypeNode tn ); Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ); Node getNormalized( TypeNode t, Node prog, bool do_pre_norm = false, bool do_post_norm = true ); int getTermSize( Node n ); -- 2.30.2