From ed8f4388c859595178e303f65393105e99d4eb59 Mon Sep 17 00:00:00 2001 From: yoni206 Date: Fri, 19 Jul 2019 14:19:17 -0700 Subject: [PATCH] SyGuS grammar refactor (#3100) --- src/theory/quantifiers/sygus/cegis_unif.cpp | 10 ++- .../quantifiers/sygus/sygus_grammar_cons.cpp | 87 ++++++++++++++----- .../quantifiers/sygus/sygus_grammar_cons.h | 53 +++++++---- 3 files changed, 108 insertions(+), 42 deletions(-) diff --git a/src/theory/quantifiers/sygus/cegis_unif.cpp b/src/theory/quantifiers/sygus/cegis_unif.cpp index e34669425..b72e50484 100644 --- a/src/theory/quantifiers/sygus/cegis_unif.cpp +++ b/src/theory/quantifiers/sygus/cegis_unif.cpp @@ -455,16 +455,20 @@ Node CegisUnifEnumDecisionStrategy::mkLiteral(unsigned n) std::stringstream ss; ss << "_virtual_enum_grammar"; std::string virtualEnumName(ss.str()); - std::map> extra_cons; - std::map> exclude_cons; + std::map> extra_cons; + std::map> + exclude_cons; + std::map> + include_cons; // do not include "-", which is included by default for integers - exclude_cons[intTn].push_back(nm->operatorOf(MINUS)); + exclude_cons[intTn].insert(nm->operatorOf(MINUS)); std::unordered_set term_irrelevant; TypeNode vtn = CegGrammarConstructor::mkSygusDefaultType(intTn, bvl, virtualEnumName, extra_cons, exclude_cons, + include_cons, term_irrelevant); d_virtual_enum = nm->mkSkolem("_ve", vtn); d_tds->registerEnumerator( diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp index 263c88d15..bcd0d709e 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.cpp @@ -57,7 +57,10 @@ bool CegGrammarConstructor::hasSyntaxRestrictions(Node q) return false; } -void CegGrammarConstructor::collectTerms( Node n, std::map< TypeNode, std::vector< Node > >& consts ){ +void CegGrammarConstructor::collectTerms( + Node n, + std::map>& consts) +{ std::unordered_map visited; std::unordered_map::iterator it; std::stack visit; @@ -78,7 +81,7 @@ void CegGrammarConstructor::collectTerms( Node n, std::map< TypeNode, std::vecto } if( std::find( consts[tn].begin(), consts[tn].end(), c )==consts[tn].end() ){ Trace("cegqi-debug") << "...consider const : " << c << std::endl; - consts[tn].push_back( c ); + consts[tn].insert(c); } } // recurse @@ -97,12 +100,13 @@ Node CegGrammarConstructor::process(Node q, // now, construct the grammar Trace("cegqi") << "SynthConjecture : convert to deep embedding..." << std::endl; - std::map< TypeNode, std::vector< Node > > extra_cons; + std::map> extra_cons; if( options::sygusAddConstGrammar() ){ Trace("cegqi") << "SynthConjecture : collect constants..." << std::endl; collectTerms( q[1], extra_cons ); } - std::map> exc_cons; + std::map> exc_cons; + std::map> inc_cons; NodeManager* nm = NodeManager::currentNM(); @@ -150,8 +154,13 @@ Node CegGrammarConstructor::process(Node q, } // make the default grammar - tn = mkSygusDefaultType( - preGrammarType, sfvl, ss.str(), extra_cons, exc_cons, term_irlv); + tn = mkSygusDefaultType(preGrammarType, + sfvl, + ss.str(), + extra_cons, + exc_cons, + inc_cons, + term_irlv); } // sfvl may be null for constant synthesis functions Trace("cegqi-debug") << "...sygus var list associated with " << sf << " is " @@ -419,8 +428,10 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( TypeNode range, Node bvl, const std::string& fun, - std::map>& extra_cons, - std::map>& exc_cons, + std::map>& extra_cons, + std::map>& exc_cons, + const std::map>& + inc_cons, std::unordered_set& term_irrelevant, std::vector& datatypes, std::set& unres) @@ -516,13 +527,17 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( //add constants std::vector< Node > consts; mkSygusConstantsForType( types[i], consts ); - std::map< TypeNode, std::vector< Node > >::iterator itec = extra_cons.find( types[i] ); + std::map>::iterator + itec = extra_cons.find(types[i]); if( itec!=extra_cons.end() ){ - //consts.insert( consts.end(), itec->second.begin(), itec->second.end() ); - for (unsigned j = 0, size_j = itec->second.size(); j < size_j; ++j) + for (std::unordered_set::iterator set_it = + itec->second.begin(); + set_it != itec->second.end(); + set_it++) { - if( std::find( consts.begin(), consts.end(), itec->second[j] )==consts.end() ){ - consts.push_back( itec->second[j] ); + if (std::find(consts.begin(), consts.end(), *set_it) == consts.end()) + { + consts.push_back(*set_it); } } } @@ -551,9 +566,10 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( if (types[i].isReal()) { - for (unsigned j = 0; j < 2; j++) + // Add PLUS, MINUS + Kind kinds[2] = {PLUS, MINUS}; + for (const Kind k : kinds) { - Kind k = j == 0 ? PLUS : MINUS; Trace("sygus-grammar-def") << "...add for " << k << std::endl; ops[i].push_back(nm->operatorOf(k).toExpr()); cnames[i].push_back(kindToString(k)); @@ -760,18 +776,35 @@ void CegGrammarConstructor::mkSygusDefaultGrammar( { Trace("sygus-grammar-def") << "...make datatype " << datatypes[i] << std::endl; datatypes[i].setSygus( types[i].toType(), bvl.toExpr(), true, true ); - std::map>::iterator itexc = - exc_cons.find(types[i]); + std::map>::iterator + itexc = exc_cons.find(types[i]); + std::map>::const_iterator itinc = + inc_cons.find(types[i]); for (unsigned j = 0, size = ops[i].size(); j < size; ++j) { - // add the constructor if it is not excluded + // add the constructor if it is not excluded, + // and it is in inc_cons, in case it is not empty Node opn = Node::fromExpr(ops[i][j]); + Trace("sygus-grammar-def") + << "...considering " << opn.toString() << " of kind " << opn.getKind() + << " and of type " << opn.getType() << " and of kind of type " + << opn.getType().getKind() << " of metakind " << opn.getMetaKind() + << std::endl; if (itexc == exc_cons.end() || std::find(itexc->second.begin(), itexc->second.end(), opn) == itexc->second.end()) { - datatypes[i].addSygusConstructor( - ops[i][j], cnames[i][j], cargs[i][j], pcs[i][j], weights[i][j]); + Trace("sygus-grammar-def") << "......not excluded " << std::endl; + if ((opn.isVar()) || (opn.getType().getKind() != kind::TYPE_CONSTANT) + || (itinc == inc_cons.end()) + || (std::find(itinc->second.begin(), itinc->second.end(), opn) + != itinc->second.end())) + { + Trace("sygus-grammar-def") << "......included " << std::endl; + datatypes[i].addSygusConstructor( + ops[i][j], cnames[i][j], cargs[i][j], pcs[i][j], weights[i][j]); + } } } Trace("sygus-grammar-def") << "...built datatype " << datatypes[i] << " "; @@ -936,12 +969,19 @@ TypeNode CegGrammarConstructor::mkSygusDefaultType( TypeNode range, Node bvl, const std::string& fun, - std::map>& extra_cons, - std::map>& exclude_cons, + std::map>& extra_cons, + std::map>& + exclude_cons, + std::map>& + include_cons, std::unordered_set& term_irrelevant) { Trace("sygus-grammar-def") << "*** Make sygus default type " << range << ", make datatypes..." << std::endl; - for( std::map< TypeNode, std::vector< Node > >::iterator it = extra_cons.begin(); it != extra_cons.end(); ++it ){ + for (std::map>::iterator + it = extra_cons.begin(); + it != extra_cons.end(); + ++it) + { Trace("sygus-grammar-def") << " ...using " << it->second.size() << " extra constants for " << it->first << std::endl; } std::set unres; @@ -951,6 +991,7 @@ TypeNode CegGrammarConstructor::mkSygusDefaultType( fun, extra_cons, exclude_cons, + include_cons, term_irrelevant, datatypes, unres); diff --git a/src/theory/quantifiers/sygus/sygus_grammar_cons.h b/src/theory/quantifiers/sygus/sygus_grammar_cons.h index 7dfa9b478..c01731d1b 100644 --- a/src/theory/quantifiers/sygus/sygus_grammar_cons.h +++ b/src/theory/quantifiers/sygus/sygus_grammar_cons.h @@ -63,30 +63,44 @@ public: /** is the syntax restricted? */ bool isSyntaxRestricted() { return d_is_syntax_restricted; } /** make the default sygus datatype type corresponding to builtin type range - * bvl is the set of free variables to include in the grammar - * fun is for naming - * extra_cons is a set of extra constant symbols to include in the grammar, - * exclude_cons is used to exclude operators from the grammar, - * term_irrelevant is a set of terms that should not be included in the - * grammar. - */ + * arguments: + * - bvl: the set of free variables to include in the grammar + * - fun: used for naming + * - extra_cons: a set of extra constant symbols to include in the grammar, + * regardless of their inclusion in the default grammar. + * - exclude_cons: used to exclude operators from the grammar, + * - term_irrelevant: a set of terms that should not be included in the + * grammar. + * - include_cons: a set of operators such that if this set is not empty, + * its elements that are in the default grammar (and only them) + * will be included. + */ static TypeNode mkSygusDefaultType( TypeNode range, Node bvl, const std::string& fun, - std::map >& extra_cons, - std::map >& exclude_cons, + std::map>& extra_cons, + std::map>& + exclude_cons, + std::map>& + include_cons, std::unordered_set& term_irrelevant); /** make the default sygus datatype type corresponding to builtin type range */ static TypeNode mkSygusDefaultType(TypeNode range, Node bvl, const std::string& fun) { - std::map > extra_cons; - std::map > exclude_cons; + std::map> extra_cons; + std::map> exclude_cons; + std::map> include_cons; std::unordered_set term_irrelevant; - return mkSygusDefaultType( - range, bvl, fun, extra_cons, exclude_cons, term_irrelevant); + return mkSygusDefaultType(range, + bvl, + fun, + extra_cons, + exclude_cons, + include_cons, + term_irrelevant); } /** make the sygus datatype type that encodes the solution space (lambda * templ_arg. templ[templ_arg]) where templ_arg @@ -139,7 +153,9 @@ public: /** is the syntax restricted? */ bool d_is_syntax_restricted; /** collect terms */ - void collectTerms( Node n, std::map< TypeNode, std::vector< Node > >& consts ); + void collectTerms( + Node n, + std::map>& consts); //---------------- grammar construction // helper for mkSygusDefaultGrammar (makes unresolved type for mutually recursive datatype construction) static TypeNode mkUnresolvedType(const std::string& name, std::set& unres); @@ -155,11 +171,16 @@ public: TypeNode range, Node bvl, const std::string& fun, - std::map >& extra_cons, - std::map >& exclude_cons, + std::map>& + extra_cons, + std::map>& + exclude_cons, + const std::map>& + include_cons, std::unordered_set& term_irrelevant, std::vector& datatypes, std::set& unres); + // helper function for mkSygusTemplateType static TypeNode mkSygusTemplateTypeRec( Node templ, Node templ_arg, TypeNode templ_arg_sygus_type, Node bvl, const std::string& fun, unsigned& tcount ); -- 2.30.2