From: Andrew Reynolds Date: Tue, 14 Nov 2017 23:04:14 +0000 (-0600) Subject: (Refactor) Split sygus term db (#1335) X-Git-Tag: cvc5-1.0.0~5476 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=748e20967ae7698f6b545a5128633865701aeb2d;p=cvc5.git (Refactor) Split sygus term db (#1335) * Split explain utility, invariance tests. * Split extended rewriter. * Remove unused function. * Minor * Move generic term utilities to term_util. * Documentation, minor. * Make arguments private in eval invariance. * Document * More documentation. * Clang format. * Fix, improve. * Format * Address review. * Address missed comment. * Add line breaks * Format --- diff --git a/src/Makefile.am b/src/Makefile.am index c78e75426..7dcf73652 100644 --- a/src/Makefile.am +++ b/src/Makefile.am @@ -379,6 +379,8 @@ libcvc4_la_SOURCES = \ theory/quantifiers/equality_query.h \ theory/quantifiers/equality_infer.cpp \ theory/quantifiers/equality_infer.h \ + theory/quantifiers/extended_rewrite.cpp \ + theory/quantifiers/extended_rewrite.h \ theory/quantifiers/first_order_model.cpp \ theory/quantifiers/first_order_model.h \ theory/quantifiers/full_model_check.cpp \ @@ -433,6 +435,8 @@ libcvc4_la_SOURCES = \ theory/quantifiers/skolemize.h \ theory/quantifiers/sygus_explain.cpp \ theory/quantifiers/sygus_explain.h \ + theory/quantifiers/sygus_invariance.cpp \ + theory/quantifiers/sygus_invariance.h \ theory/quantifiers/sygus_grammar_cons.cpp \ theory/quantifiers/sygus_grammar_cons.h \ theory/quantifiers/sygus_process_conj.cpp \ diff --git a/src/theory/datatypes/datatypes_sygus.cpp b/src/theory/datatypes/datatypes_sygus.cpp index cca475d57..82766bf49 100644 --- a/src/theory/datatypes/datatypes_sygus.cpp +++ b/src/theory/datatypes/datatypes_sygus.cpp @@ -21,6 +21,7 @@ #include "theory/datatypes/datatypes_rewriter.h" #include "theory/datatypes/theory_datatypes.h" #include "theory/quantifiers/ce_guided_conjecture.h" +#include "theory/quantifiers/sygus_explain.h" #include "theory/quantifiers/term_database_sygus.h" #include "theory/quantifiers/term_util.h" #include "theory/theory_model.h" @@ -556,9 +557,9 @@ Node SygusSymBreakNew::getSimpleSymBreakPred( TypeNode tn, int tindex, unsigned Node req_const; if( nk==GT || nk==LT || nk==XOR || nk==MINUS || nk==BITVECTOR_SUB || nk==BITVECTOR_XOR || nk==BITVECTOR_UREM_TOTAL ){ //must have the zero element - req_const = d_tds->getTypeValue( tnb, 0 ); + req_const = quantifiers::TermUtil::mkTypeValue(tnb, 0); }else if( nk==EQUAL || nk==LEQ || nk==GEQ || nk==BITVECTOR_XNOR ){ - req_const = d_tds->getTypeMaxValue( tnb ); + req_const = quantifiers::TermUtil::mkTypeMaxValue(tnb); } // cannot do division since we have to consider when both are zero if( !req_const.isNull() ){ @@ -728,156 +729,6 @@ void SygusSymBreakNew::registerSearchTerm( TypeNode tn, unsigned d, Node n, bool } } -/** EquivSygusInvarianceTest -* -* This class is used to construct a minimal shape of a term that is equivalent -* up to rewriting to a RHS value, -* given as input bvr. -* -* For example, -* -* ite( t>0, 0, 0 ) + s*0 ----> 0 -* -* can be minimized to: -* -* ite( _, 0, 0 ) + _*0 ----> 0 -* -* It also manages the case where the rewriting is invariant wrt a finite set of -* examples occurring in the conjecture. -* -* It is an instance of quantifiers::SygusInvarianceTest which is the standard -* interface for term generalization via -* the TermRecBuild utility, which traverses the AST of a given term, replaces -* each subterm by a fresh variable and -* check whether the invariant, as specified by this class (equivalent up to -* rewriting to a RHS) holds. -* -* For details, see Reynolds et al SYNT 2017. -*/ -class EquivSygusInvarianceTest : public quantifiers::SygusInvarianceTest { -public: - EquivSygusInvarianceTest() : d_conj(nullptr) {} - ~EquivSygusInvarianceTest() {} - /** initialize this invariance test - * tn is the sygus type for e - * aconj/e are used for conjecture-specific symmetry breaking - * bvr is the builtin version of the right hand side of the rewrite that we are - * checking for invariance - */ - void init(quantifiers::TermDbSygus* tds, TypeNode tn, - quantifiers::CegConjecture* aconj, Node e, Node bvr) { - // compute the current examples - d_bvr = bvr; - if (aconj->getPbe()->hasExamples(e)) { - d_conj = aconj; - d_enum = e; - unsigned nex = aconj->getPbe()->getNumExamples(e); - for (unsigned i = 0; i < nex; i++) { - d_exo.push_back(d_conj->getPbe()->evaluateBuiltin(tn, bvr, e, i)); - } - } - } -protected: - /** does nvn still rewrite to d_bvr? */ - bool invariant(quantifiers::TermDbSygus* tds, Node nvn, Node x) { - TypeNode tn = nvn.getType(); - Node nbv = tds->sygusToBuiltin(nvn, tn); - Node nbvr = tds->extendedRewrite(nbv); - Trace("sygus-sb-mexp-debug") << " min-exp check : " << nbv << " -> " << nbvr - << std::endl; - bool exc_arg = false; - // equivalent / singular up to normalization - if (nbvr == d_bvr) { - // gives the same result : then the explanation for the child is irrelevant - exc_arg = true; - Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) - << " is rewritten to " << nbvr; - Trace("sygus-sb-mexp") << " regardless of the content of " - << tds->sygusToBuiltin(x) << std::endl; - } else { - if (nbvr.isVar()) { - TypeNode xtn = x.getType(); - if (xtn == tn) { - Node bx = tds->sygusToBuiltin(x, xtn); - Assert(bx.getType() == nbvr.getType()); - if (nbvr == bx) { - Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) - << " always rewrites to argument " << nbvr - << std::endl; - // rewrites to the variable : then the explanation of this is - // irrelevant as well - exc_arg = true; - d_bvr = nbvr; - } - } - } - } - // equivalent under examples - if (!exc_arg) { - if (!d_enum.isNull()) { - bool ex_equiv = true; - for (unsigned j = 0; j < d_exo.size(); j++) { - Node nbvr_ex = d_conj->getPbe()->evaluateBuiltin(tn, nbvr, d_enum, j); - if (nbvr_ex != d_exo[j]) { - ex_equiv = false; - break; - } - } - if (ex_equiv) { - Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn); - Trace("sygus-sb-mexp") - << " is the same w.r.t. examples regardless of the content of " - << tds->sygusToBuiltin(x) << std::endl; - exc_arg = true; - } - } - } - return exc_arg; - } - - private: - /** the conjecture associated with the enumerator d_enum */ - quantifiers::CegConjecture* d_conj; - /** the enumerator associated with the term we are doing an invariance test - * for */ - Node d_enum; - /** the RHS of the evaluation */ - Node d_bvr; - /** the result of the examples - * This is a finer-grained version of d_bvr, where for example if our input - * examples are: - * (x,y,z) = (3,2,4), (5,2,6), (3,2,1) - * On these examples, we have: - * - * ite( x>y, z, 0) ---> 4,6,1 - * - * which can be minimized to: - * - * ite( x>y, z, _) ---> 4,6,1 - */ - std::vector d_exo; -}; - - -class DivByZeroSygusInvarianceTest : public quantifiers::SygusInvarianceTest { -public: - DivByZeroSygusInvarianceTest(){} - ~DivByZeroSygusInvarianceTest(){} - -protected: - bool invariant( quantifiers::TermDbSygus * tds, Node nvn, Node x ){ - TypeNode tn = nvn.getType(); - Node nbv = tds->sygusToBuiltin( nvn, tn ); - Node nbvr = tds->extendedRewrite( nbv ); - if( tds->involvesDivByZero( nbvr ) ){ - Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin( nvn ) << " involves div-by-zero regardless of " << tds->sygusToBuiltin( x ) << std::endl; - return true; - }else{ - return false; - } - } -}; - bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, std::vector< Node >& lemmas ) { Assert( n.getType()==nv.getType() ); Assert( nv.getKind()==APPLY_CONSTRUCTOR ); @@ -904,16 +755,17 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, Trace("sygus-sb-debug") << " ...register search value " << nv << ", type=" << tn << std::endl; Node bv = d_tds->sygusToBuiltin( nv, tn ); Trace("sygus-sb-debug") << " ......builtin is " << bv << std::endl; - Node bvr = d_tds->extendedRewrite( bv ); + Node bvr = d_tds->getExtRewriter()->extendedRewrite(bv); Trace("sygus-sb-debug") << " ......rewrites to " << bvr << std::endl; unsigned sz = d_tds->getSygusTermSize( nv ); std::vector< Node > exp; bool do_exclude = false; if( d_tds->involvesDivByZero( bvr ) ){ Node x = getFreeVar( tn ); - DivByZeroSygusInvarianceTest dbzet; + quantifiers::DivByZeroSygusInvarianceTest dbzet; Trace("sygus-sb-mexp-debug") << "Minimize explanation for div-by-zero in " << d_tds->sygusToBuiltin( nv ) << std::endl; - d_tds->getExplanationFor( x, nv, exp, dbzet, Node::null(), sz ); + d_tds->getExplain()->getExplanationFor( + x, nv, exp, dbzet, Node::null(), sz); do_exclude = true; }else{ std::map< Node, Node >::iterator itsv = d_cache[a].d_search_val[tn].find( bvr ); @@ -977,10 +829,11 @@ bool SygusSymBreakNew::registerSearchValue( Node a, Node n, Node nv, unsigned d, Node x = getFreeVar( tn ); // do analysis of the evaluation FIXME: does not work (evaluation is non-constant) - EquivSygusInvarianceTest eset; + quantifiers::EquivSygusInvarianceTest eset; eset.init(d_tds, tn, aconj, a, bvr); Trace("sygus-sb-mexp-debug") << "Minimize explanation for eval[" << d_tds->sygusToBuiltin( bad_val ) << "] = " << bvr << std::endl; - d_tds->getExplanationFor( x, bad_val, exp, eset, bad_val_o, sz ); + d_tds->getExplain()->getExplanationFor( + x, bad_val, exp, eset, bad_val_o, sz); do_exclude = true; } } diff --git a/src/theory/quantifiers/ce_guided_conjecture.cpp b/src/theory/quantifiers/ce_guided_conjecture.cpp index c89b6b2b4..e1bc32761 100644 --- a/src/theory/quantifiers/ce_guided_conjecture.cpp +++ b/src/theory/quantifiers/ce_guided_conjecture.cpp @@ -513,7 +513,9 @@ Node CegConjecture::getNextDecisionRequest( unsigned& priority ) { if( !d_cinfo[cprog].d_inst.empty() ){ sol = d_cinfo[cprog].d_inst.back(); // add to explanation of exclusion - d_qe->getTermDatabaseSygus()->getExplanationForConstantEquality( cprog, sol, exp ); + d_qe->getTermDatabaseSygus() + ->getExplain() + ->getExplanationForConstantEquality(cprog, sol, exp); } Trace("cegqi-debug") << " " << cprog << " -> " << sol << std::endl; } diff --git a/src/theory/quantifiers/ce_guided_instantiation.cpp b/src/theory/quantifiers/ce_guided_instantiation.cpp index a2a6e2bbe..12c3c8464 100644 --- a/src/theory/quantifiers/ce_guided_instantiation.cpp +++ b/src/theory/quantifiers/ce_guided_instantiation.cpp @@ -238,6 +238,8 @@ void CegInstantiation::getCRefEvaluationLemmas( CegConjecture * conj, std::vecto if( conj->getNumRefinementLemmas()>0 ){ Assert( vs.size()==ms.size() ); + TermDbSygus* tds = d_quantEngine->getTermDatabaseSygus(); + Node nfalse = d_quantEngine->getTermUtil()->d_false; Node neg_guard = conj->getGuard().negate(); for( unsigned i=0; igetNumRefinementLemmas(); i++ ){ Node lem; @@ -255,7 +257,6 @@ void CegInstantiation::getCRefEvaluationLemmas( CegConjecture * conj, std::vecto lem_conj.push_back( lem ); } EvalSygusInvarianceTest vsit; - vsit.d_result = d_quantEngine->getTermUtil()->d_false; for( unsigned j=0; jgetTermDatabaseSygus()->evaluateWithUnfolding( lemcs, vsit.d_visited ); + Node lemcsu = vsit.doEvaluateWithUnfolding(tds, lemcs); Trace("sygus-cref-eval2") << "...after unfolding is : " << lemcsu << std::endl; if( lemcsu==d_quantEngine->getTermUtil()->d_false ){ std::vector< Node > msu; std::vector< Node > mexp; msu.insert( msu.end(), ms.begin(), ms.end() ); for( unsigned k=0; k& } } -/** NegContainsSygusInvarianceTest -* -* This class is used to construct a minimal shape of a term that cannot -* be contained in at least one output of an I/O pair. -* -* Say our PBE conjecture is: -* -* exists f. -* f( "abc" ) = "abc abc" ^ -* f( "de" ) = "de de" -* -* Then, this class is used when there is a candidate solution t[x1] such that -* either -* contains( "abc abc", t["abc"] ) ---> false or -* contains( "de de", t["de"] ) ---> false -* -* In particular it is used to determine whether certain generalizations of t[x1] -* are still sufficient to falsify one of the above containments. -* -* For example: -* -* str.++( x1, "d" ) can be minimized to str.++( _, "d" ) -* ...since contains( "abc abc", str.++( y, "d" ) ) ---> false, -* for any y. -* str.replace( "de", x1, "b" ) can be minimized to str.replace( "de", x1, _ ) -* ...since contains( "abc abc", str.replace( "de", "abc", y ) ) ---> false, -* for any y. -* -* It is an instance of quantifiers::SygusInvarianceTest, which -* traverses the AST of a given term, replaces each subterm by a -* fresh variable y and checks whether an invariance test (such as -* the one specified by this class) holds. -* -*/ -class NegContainsSygusInvarianceTest : public quantifiers::SygusInvarianceTest { -public: - NegContainsSygusInvarianceTest() : d_cpbe(nullptr){} - ~NegContainsSygusInvarianceTest(){} - - /** initialize this invariance test - * cpbe is the conjecture utility. - * e is the enumerator which we are reasoning about (associated with a synth - * fun). - * exo is the list of outputs of the PBE conjecture. - * ncind is the set of possible indices of the PBE conjecture to check - * invariance of non-containment. - * For example, in the above example, when t[x1] = "ab", then this - * has the index 1 since contains("de de", "ab") ---> false but not - * the index 0 since contains("abc abc","ab") ---> true. - */ - void init(quantifiers::CegConjecturePbe* cpbe, - Node e, - std::vector& exo, - std::vector& ncind) - { - if (cpbe->hasExamples(e)) - { - Assert(cpbe->getNumExamples(e) == exo.size()); - d_enum = e; - d_exo.insert( d_exo.end(), exo.begin(), exo.end() ); - d_neg_con_indices.insert( d_neg_con_indices.end(), ncind.begin(), ncind.end() ); - d_cpbe = cpbe; - } - } - - protected: - /** checks whether contains( out_i, nvn[in_i] ) --> false for some I/O pair i. - */ - bool invariant( quantifiers::TermDbSygus * tds, Node nvn, Node x ){ - if (!d_enum.isNull()) - { - TypeNode tn = nvn.getType(); - Node nbv = tds->sygusToBuiltin( nvn, tn ); - Node nbvr = tds->extendedRewrite( nbv ); - // if for any of the examples, it is not contained, then we can exclude - for( unsigned i=0; ievaluateBuiltin(tn, nbvr, d_enum, ii); - Node out = d_exo[ii]; - Node cont = NodeManager::currentNM()->mkNode( kind::STRING_STRCTN, out, nbvre ); - Node contr = Rewriter::rewrite( cont ); - if( contr==tds->d_false ){ - if( Trace.isOn("sygus-pbe-cterm") ){ - Trace("sygus-pbe-cterm") << "PBE-cterm : enumerator : do not consider "; - Trace("sygus-pbe-cterm") << nbv << " for any " << tds->sygusToBuiltin( x ) << " since " << std::endl; - Trace("sygus-pbe-cterm") << " PBE-cterm : for input example : "; - std::vector< Node > ex; - d_cpbe->getExample(d_enum, ii, ex); - for( unsigned j=0; j d_exo; - /** The set of I/O pair indices i such that - * contains( out_i, nvn[in_i] ) ---> false - */ - std::vector d_neg_con_indices; - /** reference to the PBE utility */ - quantifiers::CegConjecturePbe* d_cpbe; -}; - - bool CegConjecturePbe::getExplanationForEnumeratorExclude( Node c, Node x, Node v, std::vector< Node >& results, EnumInfo& ei, std::vector< Node >& exp ) { if( ei.d_enum_slave.size()==1 ){ // this check whether the example evaluates to something that is larger than the output @@ -1146,8 +1029,8 @@ bool CegConjecturePbe::getExplanationForEnumeratorExclude( Node c, Node x, Node if( !cmp_indices.empty() ){ //set up the inclusion set NegContainsSygusInvarianceTest ncset; - ncset.init(this, x, itxo->second, cmp_indices); - d_tds->getExplanationFor( x, v, exp, ncset ); + ncset.init(d_parent, x, itxo->second, cmp_indices); + d_tds->getExplain()->getExplanationFor(x, v, exp, ncset); Trace("sygus-pbe-cterm") << "PBE-cterm : enumerator exclude " << d_tds->sygusToBuiltin( v ) << " due to negative containment." << std::endl; return true; } diff --git a/src/theory/quantifiers/ce_guided_pbe.h b/src/theory/quantifiers/ce_guided_pbe.h index d69c94944..b357e4d15 100644 --- a/src/theory/quantifiers/ce_guided_pbe.h +++ b/src/theory/quantifiers/ce_guided_pbe.h @@ -77,7 +77,7 @@ class CegConjecture; * are equivalent up to examples on the above conjecture, since they have the * same value on the points x = 0,5,6. Hence, we need only consider one of * them. The interface for querying this is -* CegConjecturePbe::addSearchVal(...). +* CegConjecturePbe::addSearchVal(...). * For details, see Reynolds et al. SYNT 2017. * * (5) When the extension of quantifier-free datatypes procedure for SyGuS diff --git a/src/theory/quantifiers/extended_rewrite.cpp b/src/theory/quantifiers/extended_rewrite.cpp new file mode 100644 index 000000000..7d3f9afab --- /dev/null +++ b/src/theory/quantifiers/extended_rewrite.cpp @@ -0,0 +1,329 @@ +/********************* */ +/*! \file extended_rewrite.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of extended rewriting techniques + **/ + +#include "theory/quantifiers/extended_rewrite.h" + +#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/quantifiers/quant_util.h" // for QuantArith +#include "theory/rewriter.h" +#include "theory/strings/theory_strings_rewriter.h" + +using namespace CVC4::kind; +using namespace std; + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +ExtendedRewriter::ExtendedRewriter() +{ + d_true = NodeManager::currentNM()->mkConst(true); + d_false = NodeManager::currentNM()->mkConst(false); +} + +Node ExtendedRewriter::extendedRewritePullIte(Node n) +{ + // generalize this? + Assert(n.getNumChildren() == 2); + Assert(n.getType().isBoolean()); + Assert(n.getMetaKind() != kind::metakind::PARAMETERIZED); + std::vector children; + for (unsigned i = 0; i < n.getNumChildren(); i++) + { + children.push_back(n[i]); + } + for (unsigned i = 0; i < 2; i++) + { + if (n[i].getKind() == kind::ITE) + { + for (unsigned j = 0; j < 2; j++) + { + children[i] = n[i][j + 1]; + Node eqr = extendedRewrite( + NodeManager::currentNM()->mkNode(n.getKind(), children)); + children[i] = n[i]; + if (eqr.isConst()) + { + std::vector new_children; + Kind new_k; + if (eqr == d_true) + { + new_k = kind::OR; + new_children.push_back(j == 0 ? n[i][0] : n[i][0].negate()); + } + else + { + Assert(eqr == d_false); + new_k = kind::AND; + new_children.push_back(j == 0 ? n[i][0].negate() : n[i][0]); + } + children[i] = n[i][2 - j]; + Node rem_eq = NodeManager::currentNM()->mkNode(n.getKind(), children); + children[i] = n[i]; + new_children.push_back(rem_eq); + Node nc = NodeManager::currentNM()->mkNode(new_k, new_children); + Trace("q-ext-rewrite") << "sygus-extr : " << n << " rewrites to " + << nc << " by simple ITE pulling." + << std::endl; + // recurse + return extendedRewrite(nc); + } + } + } + } + return Node::null(); +} + +Node ExtendedRewriter::extendedRewrite(Node n) +{ + std::unordered_map::iterator it = + d_ext_rewrite_cache.find(n); + if (it == d_ext_rewrite_cache.end()) + { + Node ret = n; + if (n.getNumChildren() > 0) + { + std::vector children; + if (n.getMetaKind() == kind::metakind::PARAMETERIZED) + { + children.push_back(n.getOperator()); + } + bool childChanged = false; + for (unsigned i = 0; i < n.getNumChildren(); i++) + { + Node nc = extendedRewrite(n[i]); + childChanged = nc != n[i] || childChanged; + children.push_back(nc); + } + if (childChanged) + { + ret = NodeManager::currentNM()->mkNode(n.getKind(), children); + } + } + ret = Rewriter::rewrite(ret); + Trace("q-ext-rewrite-debug") << "Do extended rewrite on : " << ret + << " (from " << n << ")" << std::endl; + + Node new_ret; + if (ret.getKind() == kind::EQUAL) + { + // string equalities with disequal prefix or suffix + if (ret[0].getType().isString()) + { + std::vector c[2]; + for (unsigned i = 0; i < 2; i++) + { + strings::TheoryStringsRewriter::getConcat(ret[i], c[i]); + } + if (c[0].empty() == c[1].empty()) + { + if (!c[0].empty()) + { + for (unsigned i = 0; i < 2; i++) + { + unsigned index1 = i == 0 ? 0 : c[0].size() - 1; + unsigned index2 = i == 0 ? 0 : c[1].size() - 1; + if (c[0][index1].isConst() && c[1][index2].isConst()) + { + CVC4::String s = c[0][index1].getConst(); + CVC4::String t = c[1][index2].getConst(); + unsigned len_short = s.size() <= t.size() ? s.size() : t.size(); + bool isSameFix = + i == 1 ? s.rstrncmp(t, len_short) : s.strncmp(t, len_short); + if (!isSameFix) + { + Trace("q-ext-rewrite") << "sygus-extr : " << ret + << " rewrites to false due to " + "disequal string prefix/suffix." + << std::endl; + new_ret = d_false; + break; + } + } + } + } + } + else + { + new_ret = d_false; + } + } + if (new_ret.isNull()) + { + // simple ITE pulling + new_ret = extendedRewritePullIte(ret); + } + // TODO (as part of #1343) + // ( ~contains( x, y ) --> false ) => ( ~x=y --> false ) + } + else if (ret.getKind() == kind::ITE) + { + Assert(ret[1] != ret[2]); + if (ret[0].getKind() == NOT) + { + ret = NodeManager::currentNM()->mkNode( + kind::ITE, ret[0][0], ret[2], ret[1]); + } + if (ret[0].getKind() == kind::EQUAL) + { + // simple invariant ITE + for (unsigned i = 0; i < 2; i++) + { + if (ret[1] == ret[0][i] && ret[2] == ret[0][1 - i]) + { + Trace("q-ext-rewrite") << "sygus-extr : " << ret << " rewrites to " + << ret[2] << " due to simple invariant ITE." + << std::endl; + new_ret = ret[2]; + break; + } + } + // notice this is strictly more general than the above + if (new_ret.isNull()) + { + // simple substitution + for (unsigned i = 0; i < 2; i++) + { + TNode r1 = ret[0][i]; + TNode r2 = ret[0][1 - i]; + if (r1.isVar() && ((r2.isVar() && r1 < r2) || r2.isConst())) + { + Node retn = ret[1].substitute(r1, r2); + if (retn != ret[1]) + { + new_ret = NodeManager::currentNM()->mkNode( + kind::ITE, ret[0], retn, ret[2]); + Trace("q-ext-rewrite") + << "sygus-extr : " << ret << " rewrites to " << new_ret + << " due to simple ITE substitution." << std::endl; + } + } + } + } + } + } + else if (ret.getKind() == DIVISION || ret.getKind() == INTS_DIVISION + || ret.getKind() == INTS_MODULUS) + { + // rewrite as though total + std::vector children; + bool all_const = true; + for (unsigned i = 0; i < ret.getNumChildren(); i++) + { + if (ret[i].isConst()) + { + children.push_back(ret[i]); + } + else + { + all_const = false; + break; + } + } + if (all_const) + { + Kind new_k = + (ret.getKind() == DIVISION + ? DIVISION_TOTAL + : (ret.getKind() == INTS_DIVISION ? INTS_DIVISION_TOTAL + : INTS_MODULUS_TOTAL)); + new_ret = NodeManager::currentNM()->mkNode(new_k, children); + Trace("q-ext-rewrite") << "sygus-extr : " << ret << " rewrites to " + << new_ret << " due to total interpretation." + << std::endl; + } + } + // more expensive rewrites + if (new_ret.isNull()) + { + Trace("q-ext-rewrite-debug2") << "Do expensive rewrites on " << ret + << std::endl; + bool polarity = ret.getKind() != NOT; + Node ret_atom = ret.getKind() == NOT ? ret[0] : ret; + if ((ret_atom.getKind() == EQUAL && ret_atom[0].getType().isReal()) + || ret_atom.getKind() == GEQ) + { + Trace("q-ext-rewrite-debug2") << "Compute monomial sum " << ret_atom + << std::endl; + // compute monomial sum + std::map msum; + if (QuantArith::getMonomialSumLit(ret_atom, msum)) + { + for (std::map::iterator itm = msum.begin(); + itm != msum.end(); + ++itm) + { + Node v = itm->first; + Trace("q-ext-rewrite-debug2") << itm->first << " * " << itm->second + << std::endl; + if (v.getKind() == ITE) + { + Node veq; + int res = QuantArith::isolate(v, msum, veq, ret_atom.getKind()); + if (res != 0) + { + Trace("q-ext-rewrite-debug") + << " have ITE relation, solved form : " << veq + << std::endl; + // try pulling ITE + new_ret = extendedRewritePullIte(veq); + if (!new_ret.isNull()) + { + if (!polarity) + { + new_ret = new_ret.negate(); + } + break; + } + } + else + { + Trace("q-ext-rewrite-debug") << " failed to isolate " << v + << " in " << ret << std::endl; + } + } + } + } + else + { + Trace("q-ext-rewrite-debug") << " failed to get monomial sum of " + << ret << std::endl; + } + } + else if (ret_atom.getKind() == ITE) + { + // TODO : conditional rewriting + } + else if (ret.getKind() == kind::AND || ret.getKind() == kind::OR) + { + // TODO condition merging + } + } + + if (!new_ret.isNull()) + { + ret = Rewriter::rewrite(new_ret); + } + d_ext_rewrite_cache[n] = ret; + return ret; + } + else + { + return it->second; + } +} + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ diff --git a/src/theory/quantifiers/extended_rewrite.h b/src/theory/quantifiers/extended_rewrite.h new file mode 100644 index 000000000..3a9fdb918 --- /dev/null +++ b/src/theory/quantifiers/extended_rewrite.h @@ -0,0 +1,71 @@ +/********************* */ +/*! \file extended_rewrite.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief extended rewriting class + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__QUANTIFIERS__EXTENDED_REWRITE_H +#define __CVC4__THEORY__QUANTIFIERS__EXTENDED_REWRITE_H + +#include + +#include "expr/node.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +/** Extended rewriter + * + * This class is used for all rewriting that is not necessarily + * helpful for quantifier-free solving, but is helpful + * in other use cases. An example of this is SyGuS, where rewriting + * can be used for generalizing refinement lemmas, and hence + * should be highly aggressive. + * + * This class extended the standard techniques for rewriting + * with techniques, including but not limited to: + * - ITE branch merging, + * - ITE conditional variable elimination, + * - ITE condition subsumption, and + * - Aggressive rewriting for string equalities. + */ +class ExtendedRewriter +{ + public: + ExtendedRewriter(); + ~ExtendedRewriter() {} + /** return the extended rewritten form of n */ + Node extendedRewrite(Node n); + + private: + /** true and false nodes */ + Node d_true; + Node d_false; + /** cache for extendedRewrite */ + std::unordered_map d_ext_rewrite_cache; + /** pull ITE + * Do simple ITE pulling, e.g.: + * C2 --->^E false + * implies: + * ite( C, C1, C2 ) --->^E C ^ C1 + * where ---->^E denotes extended rewriting. + */ + Node extendedRewritePullIte(Node n); +}; + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ + +#endif /* __CVC4__THEORY__QUANTIFIERS__EXTENDED_REWRITE_H */ diff --git a/src/theory/quantifiers/sygus_explain.cpp b/src/theory/quantifiers/sygus_explain.cpp index 558e1b36b..4ae4d4391 100644 --- a/src/theory/quantifiers/sygus_explain.cpp +++ b/src/theory/quantifiers/sygus_explain.cpp @@ -14,6 +14,9 @@ #include "theory/quantifiers/sygus_explain.h" +#include "theory/datatypes/datatypes_rewriter.h" +#include "theory/quantifiers/term_database_sygus.h" + using namespace CVC4::kind; using namespace std; @@ -107,6 +110,192 @@ Node TermRecBuild::build(unsigned d) return NodeManager::currentNM()->mkNode(d_kind[d], children); } +void SygusExplain::getExplanationForConstantEquality(Node n, + Node vn, + std::vector& exp) +{ + std::map cexc; + getExplanationForConstantEquality(n, vn, exp, cexc); +} + +void SygusExplain::getExplanationForConstantEquality( + Node n, Node vn, std::vector& exp, std::map& cexc) +{ + Assert(vn.getKind() == kind::APPLY_CONSTRUCTOR); + Assert(n.getType() == vn.getType()); + TypeNode tn = n.getType(); + Assert(tn.isDatatype()); + const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype(); + int i = Datatype::indexOf(vn.getOperator().toExpr()); + Node tst = datatypes::DatatypesRewriter::mkTester(n, i, dt); + exp.push_back(tst); + for (unsigned j = 0; j < vn.getNumChildren(); j++) + { + if (cexc.find(j) == cexc.end()) + { + Node sel = NodeManager::currentNM()->mkNode( + kind::APPLY_SELECTOR_TOTAL, + Node::fromExpr(dt[i].getSelectorInternal(tn.toType(), j)), + n); + getExplanationForConstantEquality(sel, vn[j], exp); + } + } +} + +Node SygusExplain::getExplanationForConstantEquality(Node n, Node vn) +{ + std::map cexc; + return getExplanationForConstantEquality(n, vn, cexc); +} + +Node SygusExplain::getExplanationForConstantEquality( + Node n, Node vn, std::map& cexc) +{ + std::vector exp; + getExplanationForConstantEquality(n, vn, exp, cexc); + Assert(!exp.empty()); + return exp.size() == 1 ? exp[0] + : NodeManager::currentNM()->mkNode(kind::AND, exp); +} + +// we have ( n = vn => eval( n ) = bvr ) ^ vn != vnr , returns exp such that exp +// => ( eval( n ) = bvr ^ vn != vnr ) +void SygusExplain::getExplanationFor(TermRecBuild& trb, + Node n, + Node vn, + std::vector& exp, + std::map& var_count, + SygusInvarianceTest& et, + Node vnr, + Node& vnr_exp, + int& sz) +{ + Assert(vnr.isNull() || vn != vnr); + Assert(vn.getKind() == APPLY_CONSTRUCTOR); + Assert(vnr.isNull() || vnr.getKind() == APPLY_CONSTRUCTOR); + Assert(n.getType() == vn.getType()); + TypeNode ntn = n.getType(); + std::map cexc; + // for each child, + // check whether replacing that child by a fresh variable + // also satisfies the invariance test. + for (unsigned i = 0; i < vn.getNumChildren(); i++) + { + TypeNode xtn = vn[i].getType(); + Node x = d_tdb->getFreeVarInc(xtn, var_count); + trb.replaceChild(i, x); + Node nvn = trb.build(); + Assert(nvn.getKind() == kind::APPLY_CONSTRUCTOR); + if (et.is_invariant(d_tdb, nvn, x)) + { + cexc[i] = true; + // we are tracking term size if positive + if (sz >= 0) + { + int s = d_tdb->getSygusTermSize(vn[i]); + sz = sz - s; + } + } + else + { + trb.replaceChild(i, vn[i]); + } + } + const Datatype& dt = ((DatatypeType)ntn.toType()).getDatatype(); + int cindex = Datatype::indexOf(vn.getOperator().toExpr()); + Assert(cindex >= 0 && cindex < (int)dt.getNumConstructors()); + Node tst = datatypes::DatatypesRewriter::mkTester(n, cindex, dt); + exp.push_back(tst); + // if the operator of vn is different than vnr, then disunification obligation + // is met + if (!vnr.isNull()) + { + if (vnr.getOperator() != vn.getOperator()) + { + vnr = Node::null(); + vnr_exp = NodeManager::currentNM()->mkConst(true); + } + } + for (unsigned i = 0; i < vn.getNumChildren(); i++) + { + Node sel = NodeManager::currentNM()->mkNode( + kind::APPLY_SELECTOR_TOTAL, + Node::fromExpr(dt[cindex].getSelectorInternal(ntn.toType(), i)), + n); + Node vnr_c = vnr.isNull() ? vnr : (vn[i] == vnr[i] ? Node::null() : vnr[i]); + if (cexc.find(i) == cexc.end()) + { + trb.push(i); + Node vnr_exp_c; + getExplanationFor( + trb, sel, vn[i], exp, var_count, et, vnr_c, vnr_exp_c, sz); + trb.pop(); + if (!vnr_c.isNull()) + { + Assert(!vnr_exp_c.isNull()); + if (vnr_exp_c.isConst() || vnr_exp.isNull()) + { + // recursively satisfied the disunification obligation + if (vnr_exp_c.isConst()) + { + // was successful, don't consider further + vnr = Node::null(); + } + vnr_exp = vnr_exp_c; + } + } + } + else + { + // if excluded, we may need to add the explanation for this + if (vnr_exp.isNull() && !vnr_c.isNull()) + { + vnr_exp = getExplanationForConstantEquality(sel, vnr[i]); + } + } + } +} + +void SygusExplain::getExplanationFor(Node n, + Node vn, + std::vector& exp, + SygusInvarianceTest& et, + Node vnr, + unsigned& sz) +{ + // naive : + // return getExplanationForConstantEquality( n, vn, exp ); + + // set up the recursion object + std::map var_count; + TermRecBuild trb; + trb.init(vn); + Node vnr_exp; + int sz_use = sz; + getExplanationFor(trb, n, vn, exp, var_count, et, vnr, vnr_exp, sz_use); + Assert(sz_use >= 0); + sz = sz_use; + Assert(vnr.isNull() || !vnr_exp.isNull()); + if (!vnr_exp.isNull() && !vnr_exp.isConst()) + { + exp.push_back(vnr_exp.negate()); + } +} + +void SygusExplain::getExplanationFor(Node n, + Node vn, + std::vector& exp, + SygusInvarianceTest& et) +{ + int sz = -1; + std::map var_count; + TermRecBuild trb; + trb.init(vn); + Node vnr; + Node vnr_exp; + getExplanationFor(trb, n, vn, exp, var_count, et, vnr, vnr_exp, sz); +} + } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ diff --git a/src/theory/quantifiers/sygus_explain.h b/src/theory/quantifiers/sygus_explain.h index f47be00b3..aa2ca0dd0 100644 --- a/src/theory/quantifiers/sygus_explain.h +++ b/src/theory/quantifiers/sygus_explain.h @@ -14,12 +14,13 @@ #include "cvc4_private.h" -#ifndef __CVC4__SYGUS_EXPLAIN_H -#define __CVC4__SYGUS_EXPLAIN_H +#ifndef __CVC4__THEORY__QUANTIFIERS__SYGUS_EXPLAIN_H +#define __CVC4__THEORY__QUANTIFIERS__SYGUS_EXPLAIN_H #include #include "expr/node.h" +#include "theory/quantifiers/sygus_invariance.h" namespace CVC4 { namespace theory { @@ -46,6 +47,7 @@ class TermRecBuild * the active term is initially n. */ void init(Node n); + /** push the context * * This updates the context so that the @@ -53,6 +55,7 @@ class TermRecBuild * curr is the previously active term. */ void push(unsigned p); + /** pop the context */ void pop(); /** indicates that the i^th child of the active @@ -83,8 +86,137 @@ class TermRecBuild void addTerm(Node n); }; +/*The SygusExplain utility + * + * This class is used to produce explanations for refinement lemmas + * in the counterexample-guided inductive synthesis (CEGIS) loop. + * + * When given an invariance test T traverses the AST of a given term, + * uses TermRecBuild to replace various subterms by fresh variables and + * recheck whether the invariant, as specified by T still holds. + * If it does, then we may exclude the explanation for that subterm. + * + * For example, say we have that the current value of + * (datatype) sygus term n is: + * (if (gt x 0) 0 0) + * where if, gt, x, 0 are datatype constructors. + * The explanation returned by getExplanationForConstantEquality + * below for n and the above term is: + * { ((_ is if) n), ((_ is geq) n.0), + * ((_ is x) n.0.0), ((_ is 0) n.0.1), + * ((_ is 0) n.1), ((_ is 0) n.2) } + * + * This class can also return more precise + * explanations based on a property that holds for + * variants of n. For instance, + * say we find that n's builtin analog rewrites to 0: + * ite( x>0, 0, 0 ) ----> 0 + * and we would like to find the minimal explanation for + * why the builtin analog of n rewrites to 0. + * We use the invariance test EquivSygusInvarianceTest + * (see sygus_invariance.h) for doing this. + * Using the SygusExplain::getExplanationFor method below, + * this will invoke the invariant test to check, e.g. + * ite( x>0, 0, y1 ) ----> 0 ? fail + * ite( x>0, y2, 0 ) ----> 0 ? fail + * ite( y3, 0, 0 ) ----> 0 ? success + * where y1, y2, y3 are fresh variables. + * Hence the explanation for the condition x>0 is irrelevant. + * This gives us the explanation: + * { ((_ is if) n), ((_ is 0) n.1), ((_ is 0) n.2) } + * indicating that all terms of the form: + * (if _ 0 0) have a builtin equivalent that rewrites to 0. + * + * For details, see Reynolds et al SYNT 2017. + * + * Below, we let [[exp]]_n denote the term induced by + * the explanation exp for n. + * For example: + * exp = { ((_ is plus) n), ((_ is y) n.1) } + * is such that: + * [[exp]]_n = (plus w y) + * where w is a fresh variable. + */ +class SygusExplain +{ + public: + SygusExplain(TermDbSygus* tdb) : d_tdb(tdb) {} + ~SygusExplain() {} + /** get explanation for constant equality + * + * This function constructs an explanation, stored in exp, such that: + * - All formulas in exp are of the form ((_ is C) ns), where ns + * is a chain of selectors applied to n, and + * - exp => ( n = vn ) + */ + void getExplanationForConstantEquality(Node n, + Node vn, + std::vector& exp); + /** returns the conjunction of exp computed in the above function */ + Node getExplanationForConstantEquality(Node n, Node vn); + + /** get explanation for constant equality + * This is identical to the above function except that we + * take an additional argument cexc, which says which + * children of vn should be excluded from the explanation. + * + * For example, if vn = plus( plus( x, x ), y ) and cexc is { 0 -> true }, + * then the following is appended to exp : + * { ((_ is plus) n), ((_ is y) n.1) } + * where notice that the 0^th argument of vn is excluded. + */ + void getExplanationForConstantEquality(Node n, + Node vn, + std::vector& exp, + std::map& cexc); + /** returns the conjunction of exp computed in the above function */ + Node getExplanationForConstantEquality(Node n, + Node vn, + std::map& cexc); + + /** get explanation for + * + * This function constructs an explanation, stored in exp, such that: + * - All formulas in exp are of the form ((_ is C) ns), where ns + * is a chain of selectors applied to n, and + * - The test et holds for [[exp]]_n, and + * - (if applicable) exp => ( n != vnr ). + * + * This function updates sz to be the term size of [[exp]]_n. + */ + void getExplanationFor(Node n, + Node vn, + std::vector& exp, + SygusInvarianceTest& et, + Node vnr, + unsigned& sz); + void getExplanationFor(Node n, + Node vn, + std::vector& exp, + SygusInvarianceTest& et); + + private: + /** sygus term database associated with this utility */ + TermDbSygus* d_tdb; + /** Helper function for getExplanationFor + * var_count is the number of free variables we have introduced, + * per type, for the purposes of generalizing subterms of n. + * vnr_exp stores the explanation, if one exists, for + * n != vnr. It is only non-null if vnr is non-null. + */ + void getExplanationFor(TermRecBuild& trb, + Node n, + Node vn, + std::vector& exp, + std::map& var_count, + SygusInvarianceTest& et, + Node vnr, + Node& vnr_exp, + int& sz); +}; + } /* CVC4::theory::quantifiers namespace */ } /* CVC4::theory namespace */ } /* CVC4 namespace */ -#endif /* __CVC4__SYGUS_INVARIANCE_H */ +#endif /* __CVC4__THEORY__QUANTIFIERS__SYGUS_EXPLAIN_H */ diff --git a/src/theory/quantifiers/sygus_invariance.cpp b/src/theory/quantifiers/sygus_invariance.cpp new file mode 100644 index 000000000..6813f4320 --- /dev/null +++ b/src/theory/quantifiers/sygus_invariance.cpp @@ -0,0 +1,226 @@ +/********************* */ +/*! \file sygus_invariance.cpp + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief Implementation of techniques for sygus invariance tests. + **/ + +#include "theory/quantifiers/sygus_invariance.h" + +#include "theory/quantifiers/ce_guided_conjecture.h" +#include "theory/quantifiers/ce_guided_pbe.h" +#include "theory/quantifiers/term_database_sygus.h" + +using namespace CVC4::kind; +using namespace std; + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +void EvalSygusInvarianceTest::init(Node conj, Node var, Node res) +{ + d_conj = conj; + d_var = var; + d_result = res; +} + +Node EvalSygusInvarianceTest::doEvaluateWithUnfolding(TermDbSygus* tds, Node n) +{ + return tds->evaluateWithUnfolding(n, d_visited); +} + +bool EvalSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) +{ + TNode tnvn = nvn; + Node conj_subs = d_conj.substitute(d_var, tnvn); + Node conj_subs_unfold = doEvaluateWithUnfolding(tds, conj_subs); + Trace("sygus-cref-eval2-debug") + << " ...check unfolding : " << conj_subs_unfold << std::endl; + Trace("sygus-cref-eval2-debug") << " ......from : " << conj_subs + << std::endl; + if (conj_subs_unfold == d_result) + { + Trace("sygus-cref-eval2") << "Evaluation min explain : " << conj_subs + << " still evaluates to " << d_result + << " regardless of "; + Trace("sygus-cref-eval2") << x << std::endl; + return true; + } + return false; +} + +void EquivSygusInvarianceTest::init( + TermDbSygus* tds, TypeNode tn, CegConjecture* aconj, Node e, Node bvr) +{ + // compute the current examples + d_bvr = bvr; + if (aconj->getPbe()->hasExamples(e)) + { + d_conj = aconj; + d_enum = e; + unsigned nex = aconj->getPbe()->getNumExamples(e); + for (unsigned i = 0; i < nex; i++) + { + d_exo.push_back(d_conj->getPbe()->evaluateBuiltin(tn, bvr, e, i)); + } + } +} + +bool EquivSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) +{ + TypeNode tn = nvn.getType(); + Node nbv = tds->sygusToBuiltin(nvn, tn); + Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + Trace("sygus-sb-mexp-debug") << " min-exp check : " << nbv << " -> " << nbvr + << std::endl; + bool exc_arg = false; + // equivalent / singular up to normalization + if (nbvr == d_bvr) + { + // gives the same result : then the explanation for the child is irrelevant + exc_arg = true; + Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) + << " is rewritten to " << nbvr; + Trace("sygus-sb-mexp") << " regardless of the content of " + << tds->sygusToBuiltin(x) << std::endl; + } + else + { + if (nbvr.isVar()) + { + TypeNode xtn = x.getType(); + if (xtn == tn) + { + Node bx = tds->sygusToBuiltin(x, xtn); + Assert(bx.getType() == nbvr.getType()); + if (nbvr == bx) + { + Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) + << " always rewrites to argument " << nbvr + << std::endl; + // rewrites to the variable : then the explanation of this is + // irrelevant as well + exc_arg = true; + d_bvr = nbvr; + } + } + } + } + // equivalent under examples + if (!exc_arg) + { + if (!d_enum.isNull()) + { + bool ex_equiv = true; + for (unsigned j = 0; j < d_exo.size(); j++) + { + Node nbvr_ex = d_conj->getPbe()->evaluateBuiltin(tn, nbvr, d_enum, j); + if (nbvr_ex != d_exo[j]) + { + ex_equiv = false; + break; + } + } + if (ex_equiv) + { + Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn); + Trace("sygus-sb-mexp") + << " is the same w.r.t. examples regardless of the content of " + << tds->sygusToBuiltin(x) << std::endl; + exc_arg = true; + } + } + } + return exc_arg; +} + +bool DivByZeroSygusInvarianceTest::invariant(TermDbSygus* tds, Node nvn, Node x) +{ + TypeNode tn = nvn.getType(); + Node nbv = tds->sygusToBuiltin(nvn, tn); + Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + if (tds->involvesDivByZero(nbvr)) + { + Trace("sygus-sb-mexp") << "sb-min-exp : " << tds->sygusToBuiltin(nvn) + << " involves div-by-zero regardless of " + << tds->sygusToBuiltin(x) << std::endl; + return true; + } + return false; +} + +void NegContainsSygusInvarianceTest::init(CegConjecture* conj, + Node e, + std::vector& exo, + std::vector& ncind) +{ + if (conj->getPbe()->hasExamples(e)) + { + Assert(conj->getPbe()->getNumExamples(e) == exo.size()); + d_enum = e; + d_exo.insert(d_exo.end(), exo.begin(), exo.end()); + d_neg_con_indices.insert( + d_neg_con_indices.end(), ncind.begin(), ncind.end()); + d_conj = conj; + } +} + +bool NegContainsSygusInvarianceTest::invariant(TermDbSygus* tds, + Node nvn, + Node x) +{ + if (!d_enum.isNull()) + { + TypeNode tn = nvn.getType(); + Node nbv = tds->sygusToBuiltin(nvn, tn); + Node nbvr = tds->getExtRewriter()->extendedRewrite(nbv); + // if for any of the examples, it is not contained, then we can exclude + for (unsigned i = 0; i < d_neg_con_indices.size(); i++) + { + unsigned ii = d_neg_con_indices[i]; + Assert(ii < d_exo.size()); + Node nbvre = d_conj->getPbe()->evaluateBuiltin(tn, nbvr, d_enum, ii); + Node out = d_exo[ii]; + Node cont = + NodeManager::currentNM()->mkNode(kind::STRING_STRCTN, out, nbvre); + Node contr = Rewriter::rewrite(cont); + if (contr == tds->d_false) + { + if (Trace.isOn("sygus-pbe-cterm")) + { + Trace("sygus-pbe-cterm") + << "PBE-cterm : enumerator : do not consider "; + Trace("sygus-pbe-cterm") << nbv << " for any " + << tds->sygusToBuiltin(x) << " since " + << std::endl; + Trace("sygus-pbe-cterm") << " PBE-cterm : for input example : "; + std::vector ex; + d_conj->getPbe()->getExample(d_enum, ii, ex); + for (unsigned j = 0; j < ex.size(); j++) + { + Trace("sygus-pbe-cterm") << ex[j] << " "; + } + Trace("sygus-pbe-cterm") << std::endl; + Trace("sygus-pbe-cterm") + << " PBE-cterm : this rewrites to : " << nbvre << std::endl; + Trace("sygus-pbe-cterm") + << " PBE-cterm : and is not in output : " << out << std::endl; + } + return true; + } + } + } + return false; +} + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ diff --git a/src/theory/quantifiers/sygus_invariance.h b/src/theory/quantifiers/sygus_invariance.h new file mode 100644 index 000000000..bf3c56572 --- /dev/null +++ b/src/theory/quantifiers/sygus_invariance.h @@ -0,0 +1,274 @@ +/********************* */ +/*! \file sygus_invariance.h + ** \verbatim + ** Top contributors (to current version): + ** Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2017 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief sygus invariance tests + **/ + +#include "cvc4_private.h" + +#ifndef __CVC4__THEORY__QUANTIFIERS__SYGUS_INVARIANCE_H +#define __CVC4__THEORY__QUANTIFIERS__SYGUS_INVARIANCE_H + +#include +#include + +#include "expr/node.h" + +namespace CVC4 { +namespace theory { +namespace quantifiers { + +class TermDbSygus; +class CegConjecture; + +/* SygusInvarianceTest +* +* This class is the standard interface for term generalization +* in SyGuS. Its interface is a single function is_variant, +* which is a virtual condition for SyGuS terms. +* +* The common use case of invariance tests is when constructing +* minimal explanations for refinement lemmas in the +* counterexample-guided inductive synthesis (CEGIS) loop. +* See sygus_explain.h for more details. +*/ +class SygusInvarianceTest +{ + public: + /** Is nvn invariant with respect to this test ? + * + * - nvn is the term to check whether it is invariant. + * - x is a variable such that the previous call to + * is_invariant (if any) was with term nvn_prev, and + * nvn is equal to nvn_prev with some subterm + * position replaced by x. This is typically used + * for debugging only. + */ + bool is_invariant(TermDbSygus* tds, Node nvn, Node x) + { + if (invariant(tds, nvn, x)) + { + d_update_nvn = nvn; + return true; + } + return false; + } + /** get updated term */ + Node getUpdatedTerm() { return d_update_nvn; } + /** set updated term */ + void setUpdatedTerm(Node n) { d_update_nvn = n; } + protected: + /** result of the node that satisfies this invariant */ + Node d_update_nvn; + /** check whether nvn[ x ] is invariant */ + virtual bool invariant(TermDbSygus* tds, Node nvn, Node x) = 0; +}; + +/** EquivSygusInvarianceTest +* +* This class tests whether a term evaluates via evaluation +* operators in the deep embedding (Section 4 of Reynolds +* et al. CAV 2015) to fixed term d_result. +* +* For example, consider a SyGuS evaluation function eval +* for a synthesis conjecture with arguments x and y. +* Notice that the term t = (mult x y) is such that: +* eval( t, 0, 1 ) ----> 0 +* This test is invariant on the content of the second +* argument of t, noting that: +* eval( (mult x _), 0, 1 ) ----> 0 +* as well, via a call to EvalSygusInvarianceTest::invariant. +* +* Another example, t = ite( gt( x, y ), x, y ) is such that: +* eval( t, 2, 3 ) ----> 3 +* This test is invariant on the second child of t, noting: +* eval( ite( gt( x, y ), _, y ), 2, 3 ) ----> 3 +*/ +class EvalSygusInvarianceTest : public SygusInvarianceTest +{ + public: + EvalSygusInvarianceTest() {} + ~EvalSygusInvarianceTest() {} + /** initialize this invariance test + * This sets d_conj/d_var/d_result, where + * we are checking whether: + * d_conj { d_var -> n } ----> d_result. + * for terms n. + */ + void init(Node conj, Node var, Node res); + + /** do evaluate with unfolding, using the cache of this class */ + Node doEvaluateWithUnfolding(TermDbSygus* tds, Node n); + + protected: + /** does d_conj{ d_var -> nvn } still rewrite to d_result? */ + bool invariant(TermDbSygus* tds, Node nvn, Node x); + + private: + /** the formula we are evaluating */ + Node d_conj; + /** the variable */ + TNode d_var; + /** the result of the evaluation */ + Node d_result; + /** cache of n -> the simplified form of eval( n ) */ + std::unordered_map d_visited; +}; + +/** EquivSygusInvarianceTest +* +* This class tests whether a builtin version of a +* sygus term is equivalent up to rewriting to a RHS value bvr. +* +* For example, +* +* ite( t>0, 0, 0 ) + s*0 ----> 0 +* +* This test is invariant on the condition t>0 and s, since: +* +* ite( _, 0, 0 ) + _*0 ----> 0 +* +* for any values of _. +* +* It also manages the case where the rewriting is invariant +* wrt a finite set of examples occurring in the conjecture. +* (EX1) : For example if our input examples are: +* (x,y,z) = (3,2,4), (5,2,6), (3,2,1) +* On these examples, we have: +* +* ite( x>y, z, 0) ---> 4,6,1 +* +* which is invariant on the second argument: +* +* ite( x>y, z, _) ---> 4,6,1 +* +* For details, see Reynolds et al SYNT 2017. +*/ +class EquivSygusInvarianceTest : public SygusInvarianceTest +{ + public: + EquivSygusInvarianceTest() : d_conj(nullptr) {} + ~EquivSygusInvarianceTest() {} + /** initialize this invariance test + * tn is the sygus type for e + * aconj/e are used for conjecture-specific symmetry breaking + * bvr is the builtin version of the right hand side of the rewrite that we + * are checking for invariance + */ + void init( + TermDbSygus* tds, TypeNode tn, CegConjecture* aconj, Node e, Node bvr); + + protected: + /** checks whether the analog of nvn still rewrites to d_bvr */ + bool invariant(TermDbSygus* tds, Node nvn, Node x); + + private: + /** the conjecture associated with the enumerator d_enum */ + CegConjecture* d_conj; + /** the enumerator associated with the term for which this test is for */ + Node d_enum; + /** the RHS of the evaluation */ + Node d_bvr; + /** the result of the examples + * In (EX1), this is (4,6,1) + */ + std::vector d_exo; +}; + +/** DivByZeroSygusInvarianceTest + * + * This class tests whether a sygus term involves + * division by zero. + * + * For example the test for: + * ( x + ( y/0 )*2 ) + * is invariant on the contents of _ below: + * ( _ + ( _/0 )*_ ) + */ +class DivByZeroSygusInvarianceTest : public SygusInvarianceTest +{ + public: + DivByZeroSygusInvarianceTest() {} + ~DivByZeroSygusInvarianceTest() {} + protected: + /** checks whether nvn involves division by zero. */ + bool invariant(TermDbSygus* tds, Node nvn, Node x); +}; + +/** NegContainsSygusInvarianceTest +* +* This class is used to construct a minimal shape of a term that cannot +* be contained in at least one output of an I/O pair. +* +* Say our PBE conjecture is: +* +* exists f. +* f( "abc" ) = "abc abc" ^ +* f( "de" ) = "de de" +* +* Then, this class is used when there is a candidate solution t[x1] +* such that either: +* contains( "abc abc", t["abc"] ) ---> false or +* contains( "de de", t["de"] ) ---> false +* +* It is used to determine whether certain generalizations of t[x1] +* are still sufficient to falsify one of the above containments. +* +* For example: +* +* The test for str.++( x1, "d" ) is invariant on its first argument +* ...since contains( "abc abc", str.++( _, "d" ) ) ---> false +* The test for str.replace( "de", x1, "b" ) is invariant on its third argument +* ...since contains( "abc abc", str.replace( "de", "abc", _ ) ) ---> false +*/ +class NegContainsSygusInvarianceTest : public SygusInvarianceTest +{ + public: + NegContainsSygusInvarianceTest() : d_conj(nullptr) {} + ~NegContainsSygusInvarianceTest() {} + /** initialize this invariance test + * cpbe is the conjecture utility. + * e is the enumerator which we are reasoning about (associated with a synth + * fun). + * exo is the list of outputs of the PBE conjecture. + * ncind is the set of possible indices of the PBE conjecture to check + * invariance of non-containment. + * For example, in the above example, when t[x1] = "ab", then this + * has the index 1 since contains("de de", "ab") ---> false but not + * the index 0 since contains("abc abc","ab") ---> true. + */ + void init(CegConjecture* conj, + Node e, + std::vector& exo, + std::vector& ncind); + + protected: + /** checks if contains( out_i, nvn[in_i] ) --> false for some I/O pair i. */ + bool invariant(TermDbSygus* tds, Node nvn, Node x); + + private: + /** The enumerator whose value we are considering in this invariance test */ + Node d_enum; + /** The output examples for the enumerator */ + std::vector d_exo; + /** The set of I/O pair indices i such that + * contains( out_i, nvn[in_i] ) ---> false + */ + std::vector d_neg_con_indices; + /** reference to the conjecture associated with this test */ + CegConjecture* d_conj; +}; + +} /* CVC4::theory::quantifiers namespace */ +} /* CVC4::theory namespace */ +} /* CVC4 namespace */ + +#endif /* __CVC4__THEORY__QUANTIFIERS__SYGUS_INVARIANCE_H */ diff --git a/src/theory/quantifiers/term_database_sygus.cpp b/src/theory/quantifiers/term_database_sygus.cpp index d8d120eab..45e3d7593 100644 --- a/src/theory/quantifiers/term_database_sygus.cpp +++ b/src/theory/quantifiers/term_database_sygus.cpp @@ -42,22 +42,11 @@ namespace CVC4 { namespace theory { namespace quantifiers { -bool EvalSygusInvarianceTest::invariant( quantifiers::TermDbSygus * tds, Node nvn, Node x ){ - TNode tnvn = nvn; - Node conj_subs = d_conj.substitute( d_var, tnvn ); - Node conj_subs_unfold = tds->evaluateWithUnfolding( conj_subs, d_visited ); - Trace("sygus-cref-eval2-debug") << " ...check unfolding : " << conj_subs_unfold << std::endl; - Trace("sygus-cref-eval2-debug") << " ......from : " << conj_subs << std::endl; - if( conj_subs_unfold==d_result ){ - Trace("sygus-cref-eval2") << "Evaluation min explain : " << conj_subs << " still evaluates to " << d_result << " regardless of "; - Trace("sygus-cref-eval2") << x << std::endl; - return true; - }else{ - return false; - } -} - -TermDbSygus::TermDbSygus( context::Context* c, QuantifiersEngine* qe ) : d_quantEngine( qe ){ +TermDbSygus::TermDbSygus(context::Context* c, QuantifiersEngine* qe) + : d_quantEngine(qe), + d_syexp(new SygusExplain(this)), + d_ext_rw(new ExtendedRewriter) +{ d_true = NodeManager::currentNM()->mkConst( true ); d_false = NodeManager::currentNM()->mkConst( false ); } @@ -424,59 +413,12 @@ Node TermDbSygus::builtinToSygusConst( Node c, TypeNode tn, int rcons_depth ) { } } -Node TermDbSygus::getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ) { - return n; - /* TODO? - if( n.getKind()==SKOLEM ){ - std::map< Node, Node >::iterator its = subs.find( n ); - if( its!=subs.end() ){ - return its->second; - }else{ - std::map< Node, TypeNode >::iterator it = d_fv_stype.find( n ); - if( it!=d_fv_stype.end() ){ - Node v = getVarInc( it->second, var_count ); - subs[n] = v; - return v; - }else{ - return n; - } - } - }else{ - if( n.getNumChildren()>0 ){ - std::vector< Node > children; - if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){ - children.push_back( n.getOperator() ); - } - bool childChanged = false; - for( unsigned i=0; imkNode( n.getKind(), children ); - } - } - return n; - } - */ -} - -Node TermDbSygus::getNormalized( TypeNode t, Node prog, bool do_pre_norm, bool do_post_norm ) { - if( do_pre_norm ){ - std::map< TypeNode, int > var_count; - std::map< Node, Node > subs; - prog = getSygusNormalized( prog, var_count, subs ); - } +Node TermDbSygus::getNormalized(TypeNode t, Node prog) +{ std::map< Node, Node >::iterator itn = d_normalized[t].find( prog ); if( itn==d_normalized[t].end() ){ Node progr = Node::fromExpr( smt::currentSmtEngine()->expandDefinitions( prog.toExpr() ) ); progr = Rewriter::rewrite( progr ); - if( do_post_norm ){ - std::map< TypeNode, int > var_count; - std::map< Node, Node > subs; - progr = getSygusNormalized( progr, var_count, subs ); - } Trace("sygus-sym-break2") << "...rewrites to " << progr << std::endl; d_normalized[t][prog] = progr; return progr; @@ -513,136 +455,6 @@ unsigned TermDbSygus::getSygusConstructors( Node n, std::vector< Node >& cons ) return 1+sum; } } - -bool TermDbSygus::isAntisymmetric( Kind k, Kind& dk ) { - if( k==GT ){ - dk = LT; - return true; - }else if( k==GEQ ){ - dk = LEQ; - return true; - }else if( k==BITVECTOR_UGT ){ - dk = BITVECTOR_ULT; - return true; - }else if( k==BITVECTOR_UGE ){ - dk = BITVECTOR_ULE; - return true; - }else if( k==BITVECTOR_SGT ){ - dk = BITVECTOR_SLT; - return true; - }else if( k==BITVECTOR_SGE ){ - dk = BITVECTOR_SLE; - return true; - }else{ - return false; - } -} - -bool TermDbSygus::isIdempotentArg( Node n, Kind ik, int arg ) { - // these should all be binary operators - //Assert( ik!=DIVISION && ik!=INTS_DIVISION && ik!=INTS_MODULUS && ik!=BITVECTOR_UDIV ); - TypeNode tn = n.getType(); - if( n==getTypeValue( tn, 0 ) ){ - if( ik==PLUS || ik==OR || ik==XOR || ik==BITVECTOR_PLUS || ik==BITVECTOR_OR || ik==BITVECTOR_XOR || ik==STRING_CONCAT ){ - return true; - }else if( ik==MINUS || ik==BITVECTOR_SHL || ik==BITVECTOR_LSHR || ik==BITVECTOR_ASHR || ik==BITVECTOR_SUB || - ik==BITVECTOR_UREM || ik==BITVECTOR_UREM_TOTAL ){ - return arg==1; - } - }else if( n==getTypeValue( tn, 1 ) ){ - if( ik==MULT || ik==BITVECTOR_MULT ){ - return true; - }else if( ik==DIVISION || ik==DIVISION_TOTAL || ik==INTS_DIVISION || ik==INTS_DIVISION_TOTAL || - ik==INTS_MODULUS || ik==INTS_MODULUS_TOTAL || - ik==BITVECTOR_UDIV_TOTAL || ik==BITVECTOR_UDIV || ik==BITVECTOR_SDIV ){ - return arg==1; - } - }else if( n==getTypeMaxValue( tn ) ){ - if( ik==EQUAL || ik==BITVECTOR_AND || ik==BITVECTOR_XNOR ){ - return true; - } - } - return false; -} - - -Node TermDbSygus::isSingularArg( Node n, Kind ik, int arg ) { - TypeNode tn = n.getType(); - if( n==getTypeValue( tn, 0 ) ){ - if( ik==AND || ik==MULT || ik==BITVECTOR_AND || ik==BITVECTOR_MULT ){ - return n; - }else if( ik==BITVECTOR_SHL || ik==BITVECTOR_LSHR || ik==BITVECTOR_ASHR || - ik==BITVECTOR_UREM || ik==BITVECTOR_UREM_TOTAL ){ - if( arg==0 ){ - return n; - } - }else if( ik==BITVECTOR_UDIV_TOTAL || ik==BITVECTOR_UDIV || ik==BITVECTOR_SDIV ){ - if( arg==0 ){ - return n; - }else if( arg==1 ){ - return getTypeMaxValue( tn ); - } - }else if( ik==DIVISION || ik==DIVISION_TOTAL || ik==INTS_DIVISION || ik==INTS_DIVISION_TOTAL || - ik==INTS_MODULUS || ik==INTS_MODULUS_TOTAL ){ - if( arg==0 ){ - return n; - }else{ - //TODO? - } - }else if( ik==STRING_SUBSTR ){ - if( arg==0 ){ - return n; - }else if( arg==2 ){ - return getTypeValue( NodeManager::currentNM()->stringType(), 0 ); - } - }else if( ik==STRING_STRIDOF ){ - if( arg==0 || arg==1 ){ - return getTypeValue( NodeManager::currentNM()->integerType(), -1 ); - } - } - }else if( n==getTypeValue( tn, 1 ) ){ - if( ik==BITVECTOR_UREM_TOTAL ){ - return getTypeValue( tn, 0 ); - } - }else if( n==getTypeMaxValue( tn ) ){ - if( ik==OR || ik==BITVECTOR_OR ){ - return n; - } - }else{ - if( n.getType().isReal() && n.getConst().sgn()<0 ){ - // negative arguments - if( ik==STRING_SUBSTR || ik==STRING_CHARAT ){ - return getTypeValue( NodeManager::currentNM()->stringType(), 0 ); - }else if( ik==STRING_STRIDOF ){ - Assert( arg==2 ); - return getTypeValue( NodeManager::currentNM()->integerType(), -1 ); - } - } - } - return Node::null(); -} - -bool TermDbSygus::hasOffsetArg( Kind ik, int arg, int& offset, Kind& ok ) { - if( ik==LT ){ - Assert( arg==0 || arg==1 ); - offset = arg==0 ? 1 : -1; - ok = LEQ; - return true; - }else if( ik==BITVECTOR_ULT ){ - Assert( arg==0 || arg==1 ); - offset = arg==0 ? 1 : -1; - ok = BITVECTOR_ULE; - return true; - }else if( ik==BITVECTOR_SLT ){ - Assert( arg==0 || arg==1 ); - offset = arg==0 ? 1 : -1; - ok = BITVECTOR_SLE; - return true; - } - return false; -} - - class ReqTrie { public: @@ -897,7 +709,8 @@ bool TermDbSygus::considerConst( TypeNode tn, TypeNode tnp, Node c, Kind pk, int if( pdt[pc].getNumArgs()==2 ){ Kind ok; int offset; - if( hasOffsetArg( pk, arg, offset, ok ) ){ + if (d_quantEngine->getTermUtil()->hasOffsetArg(pk, arg, offset, ok)) + { Trace("sygus-sb-simple-debug") << pk << " has offset arg " << ok << " " << offset << std::endl; int ok_arg = getKindConsNum( tnp, ok ); if( ok_arg!=-1 ){ @@ -905,7 +718,8 @@ bool TermDbSygus::considerConst( TypeNode tn, TypeNode tnp, Node c, Kind pk, int //other operator be the same type if( isTypeMatch( pdt[ok_arg], pdt[arg] ) ){ int status; - Node co = getTypeValueOffset( c.getType(), c, offset, status ); + Node co = d_quantEngine->getTermUtil()->getTypeValueOffset( + c.getType(), c, offset, status); Trace("sygus-sb-simple-debug") << c << " with offset " << offset << " is " << co << ", status=" << status << std::endl; if( status==0 && !co.isNull() ){ if( hasConst( tn, co ) ){ @@ -926,7 +740,8 @@ bool TermDbSygus::considerConst( const Datatype& pdt, TypeNode tnp, Node c, Kind int pc = getKindConsNum( tnp, pk ); bool ret = true; Trace("sygus-sb-debug") << "Consider sygus const " << c << ", parent = " << pk << ", arg = " << arg << "?" << std::endl; - if( isIdempotentArg( c, pk, arg ) ){ + if (d_quantEngine->getTermUtil()->isIdempotentArg(c, pk, arg)) + { if( pdt[pc].getNumArgs()==2 ){ int oarg = arg==0 ? 1 : 0; TypeNode otn = TypeNode::fromType( ((SelectorType)pdt[pc][oarg].getType()).getRangeType() ); @@ -935,8 +750,8 @@ bool TermDbSygus::considerConst( const Datatype& pdt, TypeNode tnp, Node c, Kind ret = false; } } - }else{ - Node sc = isSingularArg( c, pk, arg ); + }else{ + Node sc = d_quantEngine->getTermUtil()->isSingularArg(c, pk, arg); if( !sc.isNull() ){ if( hasConst( tnp, sc ) ){ Trace("sygus-sb-simple") << " sb-simple : " << c << " is singular arg " << arg << " of " << pk << ", evaluating to " << sc << "..." << std::endl; @@ -947,9 +762,9 @@ bool TermDbSygus::considerConst( const Datatype& pdt, TypeNode tnp, Node c, Kind if( ret ){ ReqTrie rt; Assert( rt.empty() ); - Node max_c = getTypeMaxValue( c.getType() ); - Node zero_c = getTypeValue( c.getType(), 0 ); - Node one_c = getTypeValue( c.getType(), 1 ); + Node max_c = d_quantEngine->getTermUtil()->getTypeMaxValue(c.getType()); + Node zero_c = d_quantEngine->getTermUtil()->getTypeValue(c.getType(), 0); + Node one_c = d_quantEngine->getTermUtil()->getTypeValue(c.getType(), 1); if( pk==XOR || pk==BITVECTOR_XOR ){ if( c==max_c ){ rt.d_req_kind = pk==XOR ? NOT : BITVECTOR_NOT; @@ -1000,7 +815,8 @@ int TermDbSygus::solveForArgument( TypeNode tn, unsigned cindex, unsigned arg ) if( nk==MINUS || nk==BITVECTOR_SUB ){ if( dt[cindex].getNumArgs()==2 && arg==0 ){ TypeNode tnco = getArgType( dt[cindex], 1 ); - Node builtin = getTypeValue( sygusToBuiltinType( tnc ), 0 ); + Node builtin = d_quantEngine->getTermUtil()->getTypeValue( + sygusToBuiltinType(tnc), 0); solve_ret = getConstConsNum( tn, builtin ); if( solve_ret!=-1 ){ // t - s -----> ( 0 - s ) + t @@ -1029,73 +845,6 @@ int TermDbSygus::solveForArgument( TypeNode tn, unsigned cindex, unsigned arg ) return -1; } -Node TermDbSygus::getTypeValue( TypeNode tn, int val ) { - std::map< int, Node >::iterator it = d_type_value[tn].find( val ); - if( it==d_type_value[tn].end() ){ - Node n; - if( tn.isInteger() || tn.isReal() ){ - Rational c(val); - n = NodeManager::currentNM()->mkConst( c ); - }else if( tn.isBitVector() ){ - unsigned int uv = val; - BitVector bval(tn.getConst(), uv); - n = NodeManager::currentNM()->mkConst(bval); - }else if( tn.isBoolean() ){ - if( val==0 ){ - n = d_false; - } - }else if( tn.isString() ){ - if( val==0 ){ - n = NodeManager::currentNM()->mkConst( ::CVC4::String("") ); - } - } - d_type_value[tn][val] = n; - return n; - }else{ - return it->second; - } -} - -Node TermDbSygus::getTypeMaxValue( TypeNode tn ) { - std::map< TypeNode, Node >::iterator it = d_type_max_value.find( tn ); - if( it==d_type_max_value.end() ){ - Node n; - if( tn.isBitVector() ){ - n = bv::utils::mkOnes(tn.getConst()); - }else if( tn.isBoolean() ){ - n = d_true; - } - d_type_max_value[tn] = n; - return n; - }else{ - return it->second; - } -} - -Node TermDbSygus::getTypeValueOffset( TypeNode tn, Node val, int offset, int& status ) { - std::map< int, Node >::iterator it = d_type_value_offset[tn][val].find( offset ); - if( it==d_type_value_offset[tn][val].end() ){ - Node val_o; - Node offset_val = getTypeValue( tn, offset ); - status = -1; - if( !offset_val.isNull() ){ - if( tn.isInteger() || tn.isReal() ){ - val_o = Rewriter::rewrite( NodeManager::currentNM()->mkNode( PLUS, val, offset_val ) ); - status = 0; - }else if( tn.isBitVector() ){ - val_o = Rewriter::rewrite( NodeManager::currentNM()->mkNode( BITVECTOR_PLUS, val, offset_val ) ); - // TODO : enable? watch for overflows - } - } - d_type_value_offset[tn][val][offset] = val_o; - d_type_value_offset_status[tn][val][offset] = status; - return val_o; - }else{ - status = d_type_value_offset_status[tn][val][offset]; - return it->second; - } -} - struct sortConstants { TermDbSygus * d_tds; Kind d_comp_kind; @@ -1175,7 +924,8 @@ void TermDbSygus::registerSygusType( TypeNode tn ) { } //for constant reconstruction Kind ck = getComparisonKind( TypeNode::fromType( dt.getSygusType() ) ); - Node z = getTypeValue( TypeNode::fromType( dt.getSygusType() ), 0 ); + Node z = d_quantEngine->getTermUtil()->getTypeValue( + TypeNode::fromType(dt.getSygusType()), 0); d_const_list_pos[tn] = 0; //iterate over constructors for( unsigned i=0; i reserved; for( unsigned i=0; i<=2; i++ ){ - Node rsv = i==2 ? getTypeMaxValue( btn ) : getTypeValue( btn, i ); + Node rsv = + i == 2 ? d_quantEngine->getTermUtil()->getTypeMaxValue(btn) + : d_quantEngine->getTermUtil()->getTypeValue(btn, i); if( !rsv.isNull() ){ reserved[ rsv ] = true; } @@ -1751,7 +1504,9 @@ bool TermDbSygus::involvesDivByZero( Node n, std::map< Node, bool >& visited ){ if( k==DIVISION || k==DIVISION_TOTAL || k==INTS_DIVISION || k==INTS_DIVISION_TOTAL || k==INTS_MODULUS || k==INTS_MODULUS_TOTAL ){ if( n[1].isConst() ){ - if( n[1]==getTypeValue( n[1].getType(), 0 ) ){ + if (n[1] + == d_quantEngine->getTermUtil()->getTypeValue(n[1].getType(), 0)) + { return true; } }else{ @@ -1967,7 +1722,7 @@ void TermDbSygus::registerModelValue( Node a, Node v, std::vector< Node >& terms unsigned start = d_node_mv_args_proc[n][vn]; // get explanation in terms of testers std::vector< Node > antec_exp; - getExplanationForConstantEquality( n, vn, antec_exp ); + d_syexp->getExplanationForConstantEquality(n, vn, antec_exp); Node antec = antec_exp.size()==1 ? antec_exp[0] : NodeManager::currentNM()->mkNode( kind::AND, antec_exp ); //Node antec = n.eqNode( vn ); TypeNode tn = n.getType(); @@ -2012,18 +1767,18 @@ void TermDbSygus::registerModelValue( Node a, Node v, std::vector< Node >& terms EvalSygusInvarianceTest esit; eval_children.insert( eval_children.end(), it->second[i].begin(), it->second[i].end() ); - esit.d_conj = NodeManager::currentNM()->mkNode( kind::APPLY_UF, eval_children ); - esit.d_var = n; + Node conj = + NodeManager::currentNM()->mkNode(kind::APPLY_UF, eval_children); eval_children[1] = vn; Node eval_fun = NodeManager::currentNM()->mkNode( kind::APPLY_UF, eval_children ); - esit.d_result = evaluateWithUnfolding( eval_fun ); - res = esit.d_result; + res = evaluateWithUnfolding(eval_fun); + esit.init(conj, n, res); eval_children.resize( 2 ); eval_children[1] = n; //evaluate with minimal explanation std::vector< Node > mexp; - getExplanationFor( n, vn, mexp, esit ); + d_syexp->getExplanationFor(n, vn, mexp, esit); Assert( !mexp.empty() ); expn = mexp.size()==1 ? mexp[0] : NodeManager::currentNM()->mkNode( kind::AND, mexp ); @@ -2062,136 +1817,6 @@ void TermDbSygus::registerModelValue( Node a, Node v, std::vector< Node >& terms } } -void TermDbSygus::getExplanationForConstantEquality( Node n, Node vn, std::vector< Node >& exp ) { - std::map< unsigned, bool > cexc; - getExplanationForConstantEquality( n, vn, exp, cexc ); -} - -void TermDbSygus::getExplanationForConstantEquality( Node n, Node vn, std::vector< Node >& exp, std::map< unsigned, bool >& cexc ) { - Assert( vn.getKind()==kind::APPLY_CONSTRUCTOR ); - Assert( n.getType()==vn.getType() ); - TypeNode tn = n.getType(); - Assert( tn.isDatatype() ); - const Datatype& dt = ((DatatypeType)tn.toType()).getDatatype(); - int i = Datatype::indexOf( vn.getOperator().toExpr() ); - Node tst = datatypes::DatatypesRewriter::mkTester( n, i, dt ); - exp.push_back( tst ); - for( unsigned j=0; jmkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[i].getSelectorInternal( tn.toType(), j ) ), n ); - getExplanationForConstantEquality( sel, vn[j], exp ); - } - } -} - -Node TermDbSygus::getExplanationForConstantEquality( Node n, Node vn ) { - std::map< unsigned, bool > cexc; - return getExplanationForConstantEquality( n, vn, cexc ); -} - -Node TermDbSygus::getExplanationForConstantEquality( Node n, Node vn, std::map< unsigned, bool >& cexc ) { - std::vector< Node > exp; - getExplanationForConstantEquality( n, vn, exp, cexc ); - Assert( !exp.empty() ); - return exp.size()==1 ? exp[0] : NodeManager::currentNM()->mkNode( kind::AND, exp ); -} - -// we have ( n = vn => eval( n ) = bvr ) ^ vn != vnr , returns exp such that exp => ( eval( n ) = bvr ^ vn != vnr ) -void TermDbSygus::getExplanationFor( TermRecBuild& trb, Node n, Node vn, std::vector< Node >& exp, std::map< TypeNode, int >& var_count, - SygusInvarianceTest& et, Node vnr, Node& vnr_exp, int& sz ) { - Assert( vnr.isNull() || vn!=vnr ); - Assert( vn.getKind()==APPLY_CONSTRUCTOR ); - Assert( vnr.isNull() || vnr.getKind()==APPLY_CONSTRUCTOR ); - Assert( n.getType()==vn.getType() ); - TypeNode ntn = n.getType(); - std::map< unsigned, bool > cexc; - // for each child, check whether replacing by a fresh variable and rewriting again - for( unsigned i=0; i=0 ){ - int s = getSygusTermSize( vn[i] ); - sz = sz - s; - } - }else{ - trb.replaceChild( i, vn[i] ); - } - } - const Datatype& dt = ((DatatypeType)ntn.toType()).getDatatype(); - int cindex = Datatype::indexOf( vn.getOperator().toExpr() ); - Assert( cindex>=0 && cindex<(int)dt.getNumConstructors() ); - Node tst = datatypes::DatatypesRewriter::mkTester( n, cindex, dt ); - exp.push_back( tst ); - // if the operator of vn is different than vnr, then disunification obligation is met - if( !vnr.isNull() ){ - if( vnr.getOperator()!=vn.getOperator() ){ - vnr = Node::null(); - vnr_exp = d_true; - } - } - for( unsigned i=0; imkNode( kind::APPLY_SELECTOR_TOTAL, Node::fromExpr( dt[cindex].getSelectorInternal( ntn.toType(), i ) ), n ); - Node vnr_c = vnr.isNull() ? vnr : ( vn[i]==vnr[i] ? Node::null() : vnr[i] ); - if( cexc.find( i )==cexc.end() ){ - trb.push( i ); - Node vnr_exp_c; - getExplanationFor( trb, sel, vn[i], exp, var_count, et, vnr_c, vnr_exp_c, sz ); - trb.pop(); - if( !vnr_c.isNull() ){ - Assert( !vnr_exp_c.isNull() ); - if( vnr_exp_c.isConst() || vnr_exp.isNull() ){ - // recursively satisfied the disunification obligation - if( vnr_exp_c.isConst() ){ - // was successful, don't consider further - vnr = Node::null(); - } - vnr_exp = vnr_exp_c; - } - } - }else{ - // if excluded, we may need to add the explanation for this - if( vnr_exp.isNull() && !vnr_c.isNull() ){ - vnr_exp = getExplanationForConstantEquality( sel, vnr[i] ); - } - } - } -} - -void TermDbSygus::getExplanationFor( Node n, Node vn, std::vector< Node >& exp, SygusInvarianceTest& et, Node vnr, unsigned& sz ) { - // naive : - //return getExplanationForConstantEquality( n, vn, exp ); - - // set up the recursion object - std::map< TypeNode, int > var_count; - TermRecBuild trb; - trb.init( vn ); - Node vnr_exp; - int sz_use = sz; - getExplanationFor( trb, n, vn, exp, var_count, et, vnr, vnr_exp, sz_use ); - Assert( sz_use>=0 ); - sz = sz_use; - Assert( vnr.isNull() || !vnr_exp.isNull() ); - if( !vnr_exp.isNull() && !vnr_exp.isConst() ){ - exp.push_back( vnr_exp.negate() ); - } -} - -void TermDbSygus::getExplanationFor( Node n, Node vn, std::vector< Node >& exp, SygusInvarianceTest& et ) { - int sz = -1; - std::map< TypeNode, int > var_count; - TermRecBuild trb; - trb.init( vn ); - Node vnr; - Node vnr_exp; - getExplanationFor( trb, n, vn, exp, var_count, et, vnr, vnr_exp, sz ); -} - Node TermDbSygus::unfold( Node en, std::map< Node, Node >& vtm, std::vector< Node >& exp, bool track_exp ) { if( en.getKind()==kind::APPLY_UF ){ Trace("sygus-db-debug") << "Unfold : " << en << std::endl; @@ -2336,8 +1961,11 @@ Node TermDbSygus::evaluateBuiltin( TypeNode tn, Node bn, std::vector< Node >& ar } } -Node TermDbSygus::evaluateWithUnfolding( Node n, std::map< Node, Node >& visited ) { - std::map< Node, Node >::iterator it = visited.find( n ); +Node TermDbSygus::evaluateWithUnfolding( + Node n, std::unordered_map& visited) +{ + std::unordered_map::iterator it = + visited.find(n); if( it==visited.end() ){ Node ret = n; while( ret.getKind()==APPLY_UF && ret[0].getKind()==APPLY_CONSTRUCTOR ){ @@ -2357,8 +1985,7 @@ Node TermDbSygus::evaluateWithUnfolding( Node n, std::map< Node, Node >& visited if( childChanged ){ ret = NodeManager::currentNM()->mkNode( ret.getKind(), children ); } - // TODO : extended rewrite? - ret = extendedRewrite( ret ); + ret = getExtRewriter()->extendedRewrite(ret); } visited[n] = ret; return ret; @@ -2368,7 +1995,7 @@ Node TermDbSygus::evaluateWithUnfolding( Node n, std::map< Node, Node >& visited } Node TermDbSygus::evaluateWithUnfolding( Node n ) { - std::map< Node, Node > visited; + std::unordered_map visited; return evaluateWithUnfolding( n, visited ); } @@ -2377,7 +2004,7 @@ bool TermDbSygus::computeGenericRedundant( TypeNode tn, Node g ) { std::map< Node, bool >::iterator it = d_gen_redundant[tn].find( g ); if( it==d_gen_redundant[tn].end() ){ Trace("sygus-gnf") << "Register generic for " << tn << " : " << g << std::endl; - Node gr = getNormalized( tn, g, false ); + Node gr = getNormalized(tn, g); Trace("sygus-gnf-debug") << "Generic " << g << " rewrites to " << gr << std::endl; std::map< Node, Node >::iterator itg = d_gen_terms[tn].find( gr ); bool red = true; @@ -2406,205 +2033,6 @@ bool TermDbSygus::isGenericRedundant( TypeNode tn, unsigned i ) { } } -Node TermDbSygus::extendedRewritePullIte( Node n ) { - // generalize this? - Assert( n.getNumChildren()==2 ); - Assert( n.getType().isBoolean() ); - Assert( n.getMetaKind() != kind::metakind::PARAMETERIZED ); - std::vector< Node > children; - for( unsigned i=0; imkNode( n.getKind(), children ) ); - children[i] = n[i]; - if( eqr.isConst() ){ - std::vector< Node > new_children; - Kind new_k; - if( eqr==d_true ){ - new_k = kind::OR; - new_children.push_back( j==0 ? n[i][0] : n[i][0].negate() ); - }else{ - Assert( eqr==d_false ); - new_k = kind::AND; - new_children.push_back( j==0 ? n[i][0].negate() : n[i][0] ); - } - children[i] = n[i][2-j]; - Node rem_eq = NodeManager::currentNM()->mkNode( n.getKind(), children ); - children[i] = n[i]; - new_children.push_back( rem_eq ); - Node nc = NodeManager::currentNM()->mkNode( new_k, new_children ); - Trace("sygus-ext-rewrite") << "sygus-extr : " << n << " rewrites to " << nc << " by simple ITE pulling." << std::endl; - //recurse - return extendedRewrite( nc ); - } - } - } - } - return Node::null(); -} - -Node TermDbSygus::extendedRewrite( Node n ) { - std::map< Node, Node >::iterator it = d_ext_rewrite_cache.find( n ); - if( it == d_ext_rewrite_cache.end() ){ - Node ret = n; - if( n.getNumChildren()>0 ){ - std::vector< Node > children; - if( n.getMetaKind() == kind::metakind::PARAMETERIZED ){ - children.push_back( n.getOperator() ); - } - bool childChanged = false; - for( unsigned i=0; imkNode( n.getKind(), children ); - } - } - ret = Rewriter::rewrite( n ); - Trace("sygus-ext-rewrite-debug") << "Do extended rewrite on : " << ret << " (from " << n << ")" << std::endl; - - Node new_ret; - if( ret.getKind()==kind::EQUAL ){ - // string equalities with disequal prefix or suffix - if( ret[0].getType().isString() ){ - std::vector< Node > c[2]; - for( unsigned i=0; i<2; i++ ){ - strings::TheoryStringsRewriter::getConcat( ret[i], c[i] ); - } - if( c[0].empty()==c[1].empty() ){ - if( !c[0].empty() ){ - for( unsigned i=0; i<2; i++ ){ - unsigned index1 = i==0 ? 0 : c[0].size()-1; - unsigned index2 = i==0 ? 0 : c[1].size()-1; - if( c[0][index1].isConst() && c[1][index2].isConst() ){ - CVC4::String s = c[0][index1].getConst(); - CVC4::String t = c[1][index2].getConst(); - unsigned len_short = s.size() <= t.size() ? s.size() : t.size(); - bool isSameFix = i==1 ? s.rstrncmp(t, len_short): s.strncmp(t, len_short); - if( !isSameFix ){ - Trace("sygus-ext-rewrite") << "sygus-extr : " << ret << " rewrites to false due to disequal string prefix/suffix." << std::endl; - new_ret = d_false; - break; - } - } - } - } - }else{ - new_ret = d_false; - } - } - if( new_ret.isNull() ){ - // simple ITE pulling - new_ret = extendedRewritePullIte( ret ); - } - // TODO : ( ~contains( x, y ) --> false ) => ( ~x=y --> false ) - }else if( ret.getKind()==kind::ITE ){ - Assert( ret[1]!=ret[2] ); - if( ret[0].getKind()==NOT ){ - ret = NodeManager::currentNM()->mkNode( kind::ITE, ret[0][0], ret[2], ret[1] ); - } - if( ret[0].getKind()==kind::EQUAL ){ - // simple invariant ITE - for( unsigned i=0; i<2; i++ ){ - if( ret[1]==ret[0][i] && ret[2]==ret[0][1-i] ){ - Trace("sygus-ext-rewrite") << "sygus-extr : " << ret << " rewrites to " << ret[2] << " due to simple invariant ITE." << std::endl; - new_ret = ret[2]; - break; - } - } - // notice this is strictly more general that the above - if( new_ret.isNull() ){ - // simple substitution - for( unsigned i=0; i<2; i++ ){ - if( ret[0][i].isVar() && ( ( ret[0][1-i].isVar() && ret[0][i]mkNode( kind::ITE, ret[0], retn, ret[2] ); - Trace("sygus-ext-rewrite") << "sygus-extr : " << ret << " rewrites to " << new_ret << " due to simple ITE substitution." << std::endl; - } - } - } - } - } - }else if( ret.getKind()==DIVISION || ret.getKind()==INTS_DIVISION || ret.getKind()==INTS_MODULUS ){ - // rewrite as though total - std::vector< Node > children; - bool all_const = true; - for( unsigned i=0; imkNode( new_k, children ); - Trace("sygus-ext-rewrite") << "sygus-extr : " << ret << " rewrites to " << new_ret << " due to total interpretation." << std::endl; - } - } - // more expensive rewrites - if( new_ret.isNull() ){ - Trace("sygus-ext-rewrite-debug2") << "Do expensive rewrites on " << ret << std::endl; - bool polarity = ret.getKind()!=NOT; - Node ret_atom = ret.getKind()==NOT ? ret[0] : ret; - if( ( ret_atom.getKind()==EQUAL && ret_atom[0].getType().isReal() ) || ret_atom.getKind()==GEQ ){ - Trace("sygus-ext-rewrite-debug2") << "Compute monomial sum " << ret_atom << std::endl; - //compute monomial sum - std::map< Node, Node > msum; - if( QuantArith::getMonomialSumLit( ret_atom, msum ) ){ - for( std::map< Node, Node >::iterator itm = msum.begin(); itm != msum.end(); ++itm ){ - Node v = itm->first; - Trace("sygus-ext-rewrite-debug2") << itm->first << " * " << itm->second << std::endl; - if( v.getKind()==ITE ){ - Node veq; - int res = QuantArith::isolate( v, msum, veq, ret_atom.getKind() ); - if( res!=0 ){ - Trace("sygus-ext-rewrite-debug") << " have ITE relation, solved form : " << veq << std::endl; - // try pulling ITE - new_ret = extendedRewritePullIte( veq ); - if( !new_ret.isNull() ){ - if( !polarity ){ - new_ret = new_ret.negate(); - } - break; - } - }else{ - Trace("sygus-ext-rewrite-debug") << " failed to isolate " << v << " in " << ret << std::endl; - } - } - } - }else{ - Trace("sygus-ext-rewrite-debug") << " failed to get monomial sum of " << ret << std::endl; - } - }else if( ret_atom.getKind()==ITE ){ - // TODO : conditional rewriting - }else if( ret.getKind()==kind::AND || ret.getKind()==kind::OR ){ - // TODO condition merging - } - } - - if( !new_ret.isNull() ){ - ret = Rewriter::rewrite( new_ret ); - } - d_ext_rewrite_cache[n] = ret; - return ret; - }else{ - return it->second; - } -} - - }/* CVC4::theory::quantifiers namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/quantifiers/term_database_sygus.h b/src/theory/quantifiers/term_database_sygus.h index 4786f053b..5ff1612c9 100644 --- a/src/theory/quantifiers/term_database_sygus.h +++ b/src/theory/quantifiers/term_database_sygus.h @@ -19,6 +19,7 @@ #include +#include "theory/quantifiers/extended_rewrite.h" #include "theory/quantifiers/sygus_explain.h" #include "theory/quantifiers/term_database.h" @@ -28,46 +29,76 @@ namespace quantifiers { class CegConjecture; -// TODO (as part of #1235) move to sygus_invariance.h -class SygusInvarianceTest { -protected: - // check whether nvn[ x ] should be excluded - virtual bool invariant( TermDbSygus * tds, Node nvn, Node x ) = 0; -public: - bool is_invariant( TermDbSygus * tds, Node nvn, Node x ){ - if( invariant( tds, nvn, x ) ){ - d_update_nvn = nvn; - return true; - }else{ - return false; - } - } - // result of the node after invariant replacements - Node d_update_nvn; -}; - -class EvalSygusInvarianceTest : public SygusInvarianceTest { -public: - Node d_conj; - TNode d_var; - std::map< Node, Node > d_visited; - Node d_result; -protected: - bool invariant( quantifiers::TermDbSygus * tds, Node nvn, Node x ); -}; - // TODO :issue #1235 split and document this class class TermDbSygus { -private: + public: + TermDbSygus(context::Context* c, QuantifiersEngine* qe); + ~TermDbSygus() {} + /** Reset this utility */ + bool reset(Theory::Effort e); + /** Identify this utility */ + std::string identify() const { return "TermDbSygus"; } + /** register the sygus type */ + void registerSygusType(TypeNode tn); + /** register a variable e that we will do enumerative search on + * conj is the conjecture that the enumeration of e is for. + * f is the synth-fun that the enumeration of e is for. + * mkActiveGuard is whether we want to make an active guard for e + * (see d_enum_to_active_guard). + * + * Notice that enumerator e may not be equivalent + * to f in synthesis-through-unification approaches + * (e.g. decision tree construction for PBE synthesis). + */ + void registerEnumerator(Node e, + Node f, + CegConjecture* conj, + bool mkActiveGuard = false); + /** is e an enumerator? */ + bool isEnumerator(Node e) const; + /** return the conjecture e is associated with */ + CegConjecture* getConjectureForEnumerator(Node e); + /** return the function-to-synthesize e is associated with */ + Node getSynthFunForEnumerator(Node e); + /** get active guard for e */ + Node getActiveGuardForEnumerator(Node e); + /** get all registered enumerators */ + void getEnumerators(std::vector& mts); + /** get the explanation utility */ + SygusExplain* getExplain() { return d_syexp.get(); } + /** get the extended rewrite utility */ + ExtendedRewriter* getExtRewriter() { return d_ext_rw.get(); } + private: /** reference to the quantifiers engine */ QuantifiersEngine* d_quantEngine; + /** sygus explanation */ + std::unique_ptr d_syexp; + /** sygus explanation */ + std::unique_ptr d_ext_rw; + /** mapping from enumerator terms to the conjecture they are associated with + */ + std::map d_enum_to_conjecture; + /** mapping from enumerator terms to the function-to-synthesize they are + * associated with + */ + std::map d_enum_to_synth_fun; + /** mapping from enumerator terms to the guard they are associated with + * The guard G for an enumerator e has the semantics + * if G is true, then there are more values of e to enumerate". + */ + std::map d_enum_to_active_guard; + + // TODO :issue #1235 : below here needs refactor + + public: + Node d_true; + Node d_false; + + private: std::map< TypeNode, std::vector< Node > > d_fv[2]; std::map< Node, TypeNode > d_fv_stype; std::map< Node, int > d_fv_num; bool hasFreeVar( Node n, std::map< Node, bool >& visited ); -public: - Node d_true; - Node d_false; public: TNode getFreeVar( TypeNode tn, int i, bool useSygusType = false ); TNode getFreeVarInc( TypeNode tn, std::map< TypeNode, int >& var_count, bool useSygusType = false ); @@ -85,16 +116,6 @@ private: void computeMinTypeDepthInternal( TypeNode root_tn, TypeNode tn, unsigned type_depth ); bool involvesDivByZero( Node n, std::map< Node, bool >& visited ); private: - /** mapping from enumerator terms to the conjecture they are associated with */ - std::map d_enum_to_conjecture; - /** mapping from enumerator terms to the function-to-synthesize they are - * associated with */ - std::map d_enum_to_synth_fun; - /** mapping from enumerator terms to the guard they are associated with - * The guard G for an enumerator e has the semantics - * "if G is true, then there are more values of e to enumerate". - */ - std::map d_enum_to_active_guard; // information for sygus types std::map d_register; // stores sygus -> builtin type std::map > d_var_list; @@ -109,12 +130,6 @@ private: d_const_list; // sorted list of constants for type std::map d_const_list_pos; std::map > d_semantic_skolem; - // information for builtin types - std::map > d_type_value; - std::map d_type_max_value; - std::map > > d_type_value_offset; - std::map > > - d_type_value_offset_status; // normalized map std::map > d_normalized; std::map > d_sygus_to_builtin; @@ -127,39 +142,6 @@ private: // type -> cons -> _ std::map d_min_term_size; std::map > d_min_cons_term_size; -public: - TermDbSygus( context::Context* c, QuantifiersEngine* qe ); - ~TermDbSygus(){} - bool reset( Theory::Effort e ); - std::string identify() const { return "TermDbSygus"; } -public: - /** register the sygus type */ - void registerSygusType( TypeNode tn ); - /** register a variable e that we will do enumerative search on - * conj is the conjecture that the enumeration of e is for. - * f is the synth-fun that the enumeration of e is for. - * mkActiveGuard is whether we want to make a active guard for e (see - * d_enum_to_active_guard) - * - * Notice that enumerator e may not be equivalent - * to f in synthesis-through-unification approaches - * (e.g. decision tree construction for PBE synthesis). - */ - void registerEnumerator(Node e, - Node f, - CegConjecture* conj, - bool mkActiveGuard = false); - /** is e an enumerator? */ - bool isEnumerator(Node e) const; - /** return the conjecture e is associated with */ - CegConjecture* getConjectureForEnumerator(Node e); - /** return the function-to-synthesize e is associated with */ - Node getSynthFunForEnumerator(Node e); - /** get active guard for e */ - Node getActiveGuardForEnumerator(Node e); - /** get all registered enumerators */ - void getEnumerators(std::vector& mts); - public: // general sygus utilities bool isRegistered( TypeNode tn ); // get the minimum depth of type in its parent grammar @@ -188,20 +170,7 @@ public: int getFirstArgOccurrence( const DatatypeConstructor& c, TypeNode tn ); /** is type match */ bool isTypeMatch( const DatatypeConstructor& c1, const DatatypeConstructor& c2 ); - /** isAntisymmetric */ - bool isAntisymmetric( Kind k, Kind& dk ); - /** is idempotent arg */ - bool isIdempotentArg( Node n, Kind ik, int arg ); - /** is singular arg */ - Node isSingularArg( Node n, Kind ik, int arg ); - /** get offset arg */ - bool hasOffsetArg( Kind ik, int arg, int& offset, Kind& ok ); - /** get value */ - Node getTypeValue( TypeNode tn, int val ); - /** get value */ - Node getTypeValueOffset( TypeNode tn, Node val, int offset, int& status ); - /** get value */ - Node getTypeMaxValue( TypeNode tn ); + TypeNode getSygusTypeForVar( Node v ); Node getGenericBase( TypeNode tn, const Datatype& dt, int c ); Node mkGeneric( const Datatype& dt, int c, std::map< TypeNode, int >& var_count, std::map< int, Node >& pre ); @@ -210,7 +179,7 @@ public: Node sygusSubstituted( TypeNode tn, Node n, std::vector< Node >& args ); Node builtinToSygusConst( Node c, TypeNode tn, int rcons_depth = 0 ); Node getSygusNormalized( Node n, std::map< TypeNode, int >& var_count, std::map< Node, Node >& subs ); - Node getNormalized( TypeNode t, Node prog, bool do_pre_norm = false, bool do_post_norm = true ); + Node getNormalized(TypeNode t, Node prog); unsigned getSygusTermSize( Node n ); // returns size unsigned getSygusConstructors( Node n, std::vector< Node >& cons ); @@ -253,8 +222,6 @@ public: // for symmetry breaking std::map< Node, std::vector< bool > > d_eval_args_const; std::map< Node, std::map< Node, unsigned > > d_node_mv_args_proc; - void getExplanationFor( TermRecBuild& trb, Node n, Node vn, std::vector< Node >& exp, std::map< TypeNode, int >& var_count, - SygusInvarianceTest& et, Node vnr, Node& vnr_exp, int& sz ); public: void registerEvalTerm( Node n ); void registerModelValue( Node n, Node v, std::vector< Node >& exps, std::vector< Node >& terms, std::vector< Node >& vals ); @@ -265,19 +232,12 @@ public: return unfold( en, vtm, exp, false ); } Node getEagerUnfold( Node n, std::map< Node, Node >& visited ); - // returns straightforward exp => n = vn - void getExplanationForConstantEquality( Node n, Node vn, std::vector< Node >& exp ); - void getExplanationForConstantEquality( Node n, Node vn, std::vector< Node >& exp, std::map< unsigned, bool >& cexc ); - Node getExplanationForConstantEquality( Node n, Node vn ); - Node getExplanationForConstantEquality( Node n, Node vn, std::map< unsigned, bool >& cexc ); - // we have n = vn => eval( n ) = bvr, returns exp => eval( n ) = bvr - // ensures the explanation still allows for vnr - void getExplanationFor( Node n, Node vn, std::vector< Node >& exp, SygusInvarianceTest& et, Node vnr, unsigned& sz ); - void getExplanationFor( Node n, Node vn, std::vector< Node >& exp, SygusInvarianceTest& et ); + // builtin evaluation, returns rewrite( bn [ args / vars(tn) ] ) Node evaluateBuiltin( TypeNode tn, Node bn, std::vector< Node >& args ); // evaluate with unfolding - Node evaluateWithUnfolding( Node n, std::map< Node, Node >& visited ); + Node evaluateWithUnfolding( + Node n, std::unordered_map& visited); Node evaluateWithUnfolding( Node n ); //for calculating redundant operators private: @@ -291,13 +251,6 @@ private: bool computeGenericRedundant( TypeNode tn, Node g ); public: bool isGenericRedundant( TypeNode tn, unsigned i ); - -// extended rewriting -private: - std::map< Node, Node > d_ext_rewrite_cache; - Node extendedRewritePullIte( Node n ); -public: - Node extendedRewrite( Node n ); }; }/* CVC4::theory::quantifiers namespace */ diff --git a/src/theory/quantifiers/term_util.cpp b/src/theory/quantifiers/term_util.cpp index 471670515..2183db5f1 100644 --- a/src/theory/quantifiers/term_util.cpp +++ b/src/theory/quantifiers/term_util.cpp @@ -917,6 +917,319 @@ bool TermUtil::isBoolConnectiveTerm( TNode n ) { ( n.getKind()!=ITE || n.getType().isBoolean() ); } +Node TermUtil::getTypeValue(TypeNode tn, int val) +{ + std::unordered_map::iterator it = d_type_value[tn].find(val); + if (it == d_type_value[tn].end()) + { + Node n = mkTypeValue(tn, val); + d_type_value[tn][val] = n; + return n; + } + return it->second; +} + +Node TermUtil::mkTypeValue(TypeNode tn, int val) +{ + Node n; + if (tn.isInteger() || tn.isReal()) + { + Rational c(val); + n = NodeManager::currentNM()->mkConst(c); + } + else if (tn.isBitVector()) + { + unsigned int uv = val; + BitVector bval(tn.getConst(), uv); + n = NodeManager::currentNM()->mkConst(bval); + } + else if (tn.isBoolean()) + { + if (val == 0) + { + n = NodeManager::currentNM()->mkConst(false); + } + } + else if (tn.isString()) + { + if (val == 0) + { + n = NodeManager::currentNM()->mkConst(::CVC4::String("")); + } + } + return n; +} + +Node TermUtil::getTypeMaxValue(TypeNode tn) +{ + std::unordered_map::iterator it = + d_type_max_value.find(tn); + if (it == d_type_max_value.end()) + { + Node n = mkTypeMaxValue(tn); + d_type_max_value[tn] = n; + return n; + } + return it->second; +} + +Node TermUtil::mkTypeMaxValue(TypeNode tn) +{ + Node n; + if (tn.isBitVector()) + { + n = bv::utils::mkOnes(tn.getConst()); + } + else if (tn.isBoolean()) + { + n = NodeManager::currentNM()->mkConst(true); + } + return n; +} + +Node TermUtil::getTypeValueOffset(TypeNode tn, + Node val, + int offset, + int& status) +{ + std::unordered_map::iterator it = + d_type_value_offset[tn][val].find(offset); + if (it == d_type_value_offset[tn][val].end()) + { + Node val_o; + Node offset_val = getTypeValue(tn, offset); + status = -1; + if (!offset_val.isNull()) + { + if (tn.isInteger() || tn.isReal()) + { + val_o = Rewriter::rewrite( + NodeManager::currentNM()->mkNode(PLUS, val, offset_val)); + status = 0; + } + else if (tn.isBitVector()) + { + val_o = Rewriter::rewrite( + NodeManager::currentNM()->mkNode(BITVECTOR_PLUS, val, offset_val)); + // TODO : enable? watch for overflows + } + } + d_type_value_offset[tn][val][offset] = val_o; + d_type_value_offset_status[tn][val][offset] = status; + return val_o; + } + status = d_type_value_offset_status[tn][val][offset]; + return it->second; +} + +bool TermUtil::isAntisymmetric(Kind k, Kind& dk) +{ + if (k == GT) + { + dk = LT; + return true; + } + else if (k == GEQ) + { + dk = LEQ; + return true; + } + else if (k == BITVECTOR_UGT) + { + dk = BITVECTOR_ULT; + return true; + } + else if (k == BITVECTOR_UGE) + { + dk = BITVECTOR_ULE; + return true; + } + else if (k == BITVECTOR_SGT) + { + dk = BITVECTOR_SLT; + return true; + } + else if (k == BITVECTOR_SGE) + { + dk = BITVECTOR_SLE; + return true; + } + return false; +} + +bool TermUtil::isIdempotentArg(Node n, Kind ik, int arg) +{ + // these should all be binary operators + // Assert( ik!=DIVISION && ik!=INTS_DIVISION && ik!=INTS_MODULUS && + // ik!=BITVECTOR_UDIV ); + TypeNode tn = n.getType(); + if (n == getTypeValue(tn, 0)) + { + if (ik == PLUS || ik == OR || ik == XOR || ik == BITVECTOR_PLUS + || ik == BITVECTOR_OR + || ik == BITVECTOR_XOR + || ik == STRING_CONCAT) + { + return true; + } + else if (ik == MINUS || ik == BITVECTOR_SHL || ik == BITVECTOR_LSHR + || ik == BITVECTOR_ASHR + || ik == BITVECTOR_SUB + || ik == BITVECTOR_UREM + || ik == BITVECTOR_UREM_TOTAL) + { + return arg == 1; + } + } + else if (n == getTypeValue(tn, 1)) + { + if (ik == MULT || ik == BITVECTOR_MULT) + { + return true; + } + else if (ik == DIVISION || ik == DIVISION_TOTAL || ik == INTS_DIVISION + || ik == INTS_DIVISION_TOTAL + || ik == INTS_MODULUS + || ik == INTS_MODULUS_TOTAL + || ik == BITVECTOR_UDIV_TOTAL + || ik == BITVECTOR_UDIV + || ik == BITVECTOR_SDIV) + { + return arg == 1; + } + } + else if (n == getTypeMaxValue(tn)) + { + if (ik == EQUAL || ik == BITVECTOR_AND || ik == BITVECTOR_XNOR) + { + return true; + } + } + return false; +} + +Node TermUtil::isSingularArg(Node n, Kind ik, int arg) +{ + TypeNode tn = n.getType(); + if (n == getTypeValue(tn, 0)) + { + if (ik == AND || ik == MULT || ik == BITVECTOR_AND || ik == BITVECTOR_MULT) + { + return n; + } + else if (ik == BITVECTOR_SHL || ik == BITVECTOR_LSHR || ik == BITVECTOR_ASHR + || ik == BITVECTOR_UREM + || ik == BITVECTOR_UREM_TOTAL) + { + if (arg == 0) + { + return n; + } + } + else if (ik == BITVECTOR_UDIV_TOTAL || ik == BITVECTOR_UDIV + || ik == BITVECTOR_SDIV) + { + if (arg == 0) + { + return n; + } + else if (arg == 1) + { + return getTypeMaxValue(tn); + } + } + else if (ik == DIVISION || ik == DIVISION_TOTAL || ik == INTS_DIVISION + || ik == INTS_DIVISION_TOTAL + || ik == INTS_MODULUS + || ik == INTS_MODULUS_TOTAL) + { + if (arg == 0) + { + return n; + } + else + { + // TODO? + } + } + else if (ik == STRING_SUBSTR) + { + if (arg == 0) + { + return n; + } + else if (arg == 2) + { + return getTypeValue(NodeManager::currentNM()->stringType(), 0); + } + } + else if (ik == STRING_STRIDOF) + { + if (arg == 0 || arg == 1) + { + return getTypeValue(NodeManager::currentNM()->integerType(), -1); + } + } + } + else if (n == getTypeValue(tn, 1)) + { + if (ik == BITVECTOR_UREM_TOTAL) + { + return getTypeValue(tn, 0); + } + } + else if (n == getTypeMaxValue(tn)) + { + if (ik == OR || ik == BITVECTOR_OR) + { + return n; + } + } + else + { + if (n.getType().isReal() && n.getConst().sgn() < 0) + { + // negative arguments + if (ik == STRING_SUBSTR || ik == STRING_CHARAT) + { + return getTypeValue(NodeManager::currentNM()->stringType(), 0); + } + else if (ik == STRING_STRIDOF) + { + Assert(arg == 2); + return getTypeValue(NodeManager::currentNM()->integerType(), -1); + } + } + } + return Node::null(); +} + +bool TermUtil::hasOffsetArg(Kind ik, int arg, int& offset, Kind& ok) +{ + if (ik == LT) + { + Assert(arg == 0 || arg == 1); + offset = arg == 0 ? 1 : -1; + ok = LEQ; + return true; + } + else if (ik == BITVECTOR_ULT) + { + Assert(arg == 0 || arg == 1); + offset = arg == 0 ? 1 : -1; + ok = BITVECTOR_ULE; + return true; + } + else if (ik == BITVECTOR_SLT) + { + Assert(arg == 0 || arg == 1); + offset = arg == 0 ? 1 : -1; + ok = BITVECTOR_SLE; + return true; + } + return false; +} + Node TermUtil::getHoTypeMatchPredicate( TypeNode tn ) { std::map< TypeNode, Node >::iterator ithp = d_ho_type_match_pred.find( tn ); if( ithp==d_ho_type_match_pred.end() ){ diff --git a/src/theory/quantifiers/term_util.h b/src/theory/quantifiers/term_util.h index bcdf8a2ff..d2a8a14f0 100644 --- a/src/theory/quantifiers/term_util.h +++ b/src/theory/quantifiers/term_util.h @@ -265,11 +265,34 @@ public: static void getRelevancyCondition( Node n, std::vector< Node >& cond ); //general utilities -private: + // TODO #1216 : promote these? + private: //helper for contains term static bool containsTerm2( Node n, Node t, std::map< Node, bool >& visited ); static bool containsTerms2( Node n, std::vector< Node >& t, std::map< Node, bool >& visited ); -public: + /** cache for getTypeValue */ + std::unordered_map, + TypeNodeHashFunction> + d_type_value; + /** cache for getTypeMaxValue */ + std::unordered_map d_type_max_value; + /** cache for getTypeValueOffset */ + std::unordered_map, + NodeHashFunction>, + TypeNodeHashFunction> + d_type_value_offset; + /** cache for status of getTypeValueOffset*/ + std::unordered_map, + NodeHashFunction>, + TypeNodeHashFunction> + d_type_value_offset_status; + + public: /** simple check for whether n contains t as subterm */ static bool containsTerm( Node n, Node t ); /** simple check for contains term, true if contains at least one term in t */ @@ -282,17 +305,82 @@ public: static Node simpleNegate( Node n ); /** is assoc */ static bool isAssoc( Kind k ); - /** is comm */ + /** is k commutative? */ static bool isComm( Kind k ); - /** ( x k ... ) k x = ( x k ... ) */ + + /** is k non-additive? + * Returns true if + * ( ( T1, x, T2 ), x ) = + * ( T1, x, T2 ) + * always holds, where T1 and T2 are vectors. + */ static bool isNonAdditive( Kind k ); - /** is bool connective */ + /** is k a bool connective? */ static bool isBoolConnective( Kind k ); - /** is bool connective term */ + /** is n a bool connective term? */ static bool isBoolConnectiveTerm( TNode n ); -//for higher-order -private: + /** is the kind k antisymmetric? + * If so, return true and store its inverse kind in dk. + */ + static bool isAntisymmetric(Kind k, Kind& dk); + + /** has offset arg + * Returns true if there is a Kind ok and offset + * such that + * ( ... t_{arg-1}, n, t_{arg+1}... ) = + * ( ... t_{arg-1}, n+offset, t_{arg+1}...) + * always holds. + * If so, this function returns true and stores + * offset and ok in the respective fields. + */ + static bool hasOffsetArg(Kind ik, int arg, int& offset, Kind& ok); + + /** is idempotent arg + * Returns true if + * ( ... t_{arg-1}, n, t_{arg+1}...) = + * ( ... t_{arg-1}, t_{arg+1}...) + * always holds. + */ + bool isIdempotentArg(Node n, Kind ik, int arg); + + /** is singular arg + * Returns true if + * ( ... t_{arg-1}, n, t_{arg+1}...) = n + * always holds. + */ + Node isSingularArg(Node n, Kind ik, int arg); + + /** get type value + * This gets the Node that represents value val for Type tn + * This is used to get simple values, e.g. -1,0,1, + * in a uniform way per type. + */ + Node getTypeValue(TypeNode tn, int val); + + /** get type value offset + * Returns the value of ( val + getTypeValue( tn, offset ) ), + * where + is the additive operator for the type. + * Stores the status (0: success, -1: failure) in status. + */ + Node getTypeValueOffset(TypeNode tn, Node val, int offset, int& status); + + /** get the "max" value for type tn + * For example, + * the max value for Bool is true, + * the max value for BitVector is 1..1. + */ + Node getTypeMaxValue(TypeNode tn); + + /** make value, static version of get value */ + static Node mkTypeValue(TypeNode tn, int val); + /** make value offset, static version of get value offset */ + static Node mkTypeValueOffset(TypeNode tn, Node val, int offset, int& status); + /** make max value, static version of get max value */ + static Node mkTypeMaxValue(TypeNode tn); + + // for higher-order + private: /** dummy predicate that states terms should be considered first-class members of equality engine */ std::map< TypeNode, Node > d_ho_type_match_pred; public: