From 2c78f3bf696e7eb4b04f687e15b9569b9e1b8f23 Mon Sep 17 00:00:00 2001 From: Abdalrhman Mohamed <32971963+abdoo8080@users.noreply.github.com> Date: Thu, 21 May 2020 13:01:14 -0500 Subject: [PATCH] Make Grammar reusable. (#4506) This PR modifies the Grammar implementation to make it reusable (i.e., can be copied or passed multiple times to synthFun/synthInv) with the catch that it becomes read-only after the first use. --- examples/api/CMakeLists.txt | 1 + examples/api/sygus-fun.cpp | 2 +- examples/api/sygus-grammar.cpp | 121 +++++++++++++++++++++++++++++++++ src/api/cvc4cpp.cpp | 120 ++++++++++++++++++++------------ src/api/cvc4cpp.h | 43 ++++++------ test/unit/api/grammar_black.h | 28 ++++++-- 6 files changed, 245 insertions(+), 70 deletions(-) create mode 100644 examples/api/sygus-grammar.cpp diff --git a/examples/api/CMakeLists.txt b/examples/api/CMakeLists.txt index e4ef4ee78..3ced5681c 100644 --- a/examples/api/CMakeLists.txt +++ b/examples/api/CMakeLists.txt @@ -18,6 +18,7 @@ set(CVC4_EXAMPLES_API strings strings-new sygus-fun + sygus-grammar sygus-inv ) diff --git a/examples/api/sygus-fun.cpp b/examples/api/sygus-fun.cpp index d6437afa3..6c47ec715 100644 --- a/examples/api/sygus-fun.cpp +++ b/examples/api/sygus-fun.cpp @@ -78,7 +78,7 @@ int main() Term one = slv.mkReal(1); Term plus = slv.mkTerm(PLUS, start, start); - Term minus = slv.mkTerm(PLUS, start, start); + Term minus = slv.mkTerm(MINUS, start, start); Term ite = slv.mkTerm(ITE, start_bool, start, start); Term And = slv.mkTerm(AND, start_bool, start_bool); diff --git a/examples/api/sygus-grammar.cpp b/examples/api/sygus-grammar.cpp new file mode 100644 index 000000000..c2e624c1f --- /dev/null +++ b/examples/api/sygus-grammar.cpp @@ -0,0 +1,121 @@ +/********************* */ +/*! \file sygus-grammar.cpp + ** \verbatim + ** Top contributors (to current version): + ** Abdalrhman Mohamed, Andrew Reynolds + ** This file is part of the CVC4 project. + ** Copyright (c) 2009-2019 by the authors listed in the file AUTHORS + ** in the top-level source directory) and their institutional affiliations. + ** All rights reserved. See the file COPYING in the top-level source + ** directory for licensing information.\endverbatim + ** + ** \brief A simple demonstration of the Sygus API. + ** + ** A simple demonstration of how to use the Sygus API to synthesize max and min + ** functions. Here is the same problem written in Sygus V2 format: + ** + ** (set-logic LIA) + ** + ** (synth-fun id1 ((x Int)) Int + ** ((Start Int)) ((Start Int ((- x) (+ x Start))))) + ** + ** (synth-fun id2 ((x Int)) Int + ** ((Start Int)) ((Start Int ((Variable Int) (- x) (+ x Start))))) + ** + ** (synth-fun id3 ((x Int)) Int + ** ((Start Int)) ((Start Int (0 (- x) (+ x Start))))) + ** + ** (synth-fun id4 ((x Int)) Int + ** ((Start Int)) ((Start Int ((- x) (+ x Start))))) + ** + ** (declare-var x Int) + ** + ** (constraint (= (id1 x) (id2 x) (id3 x) (id4 x) x)) + ** + ** (check-synth) + ** + ** The printed output to this example should be equivalent to: + ** (define-fun max ((x Int) (y Int)) Int (ite (<= x y) y x)) + ** (define-fun min ((x Int) (y Int)) Int (ite (<= x y) x y)) + **/ + +#include + +#include + +using namespace CVC4::api; + +int main() +{ + Solver slv; + + // required options + slv.setOption("lang", "sygus2"); + slv.setOption("incremental", "false"); + + // set the logic + slv.setLogic("LIA"); + + Sort integer = slv.getIntegerSort(); + Sort boolean = slv.getBooleanSort(); + + // declare input variables for the function-to-synthesize + Term x = slv.mkVar(integer, "x"); + + // declare the grammar non-terminals + Term start = slv.mkVar(integer, "Start"); + + // define the rules + Term zero = slv.mkReal(0); + Term neg_x = slv.mkTerm(UMINUS, x); + Term plus = slv.mkTerm(PLUS, x, start); + + // create the grammar object + Grammar g1 = slv.mkSygusGrammar({x}, {start}); + + // bind each non-terminal to its rules + g1.addRules(start, {neg_x, plus}); + + // copy the first grammar with all of its non-termainals and their rules + Grammar g2 = g1; + Grammar g3 = g1; + + // add parameters as rules to the start symbol. Similar to "(Variable Int)" + g2.addAnyVariable(start); + + // declare the function-to-synthesizse + Term id1 = slv.synthFun("id1", {x}, integer, g1); + Term id2 = slv.synthFun("id2", {x}, integer, g2); + + g3.addRule(start, zero); + + Term id3 = slv.synthFun("id3", {x}, integer, g3); + + // g1 is reusable as long as it remains unmodified after first use + Term id4 = slv.synthFun("id4", {x}, integer, g1); + + // declare universal variables. + Term varX = slv.mkSygusVar(integer, "x"); + + Term id1_x = slv.mkTerm(APPLY_UF, id1, varX); + Term id2_x = slv.mkTerm(APPLY_UF, id2, varX); + Term id3_x = slv.mkTerm(APPLY_UF, id3, varX); + Term id4_x = slv.mkTerm(APPLY_UF, id4, varX); + + // add logical constraints + // (constraint (= (id1 x) (id2 x) x)) + slv.addSygusConstraint(slv.mkTerm(EQUAL, {id1_x, id2_x, id3_x, id4_x, varX})); + + // print solutions if available + if (slv.checkSynth().isUnsat()) + { + // Output should be equivalent to: + // (define-fun id1 ((x Int)) Int (+ x (+ x (- x)))) + // (define-fun id2 ((x Int)) Int x) + // (define-fun id3 ((x Int)) Int (+ x 0)) + // (define-fun id4 ((x Int)) Int (+ x (+ x (- x)))) + slv.printSynthSolution(std::cout); + } + + return 0; +} diff --git a/src/api/cvc4cpp.cpp b/src/api/cvc4cpp.cpp index 5c0c4a750..d990b3e22 100644 --- a/src/api/cvc4cpp.cpp +++ b/src/api/cvc4cpp.cpp @@ -2241,39 +2241,40 @@ Grammar::Grammar(const Solver* s, : d_s(s), d_sygusVars(sygusVars), d_ntSyms(ntSymbols), - d_ntsToUnres(), - d_dtDecls(), - d_allowConst() + d_ntsToTerms(ntSymbols.size()), + d_allowConst(), + d_allowVars(), + d_isResolved(false) { for (Term ntsymbol : d_ntSyms) { - // make the datatype, which encodes terms generated by this non-terminal - d_dtDecls.emplace(ntsymbol, DatatypeDecl(d_s, ntsymbol.toString())); - // make its unresolved type, used for referencing the final version of - // the datatype - d_ntsToUnres[ntsymbol] = d_s->getExprManager()->mkSort(ntsymbol.toString()); + d_ntsToTerms.emplace(ntsymbol, std::vector()); } } void Grammar::addRule(Term ntSymbol, Term rule) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); CVC4_API_ARG_CHECK_NOT_NULL(rule); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; CVC4_API_CHECK(ntSymbol.d_expr->getType() == rule.d_expr->getType()) << "Expected ntSymbol and rule to have the same sort"; - addSygusConstructorTerm(d_dtDecls[ntSymbol], rule); + d_ntsToTerms[ntSymbol].push_back(rule); } void Grammar::addRules(Term ntSymbol, std::vector rules) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; @@ -2285,16 +2286,19 @@ void Grammar::addRules(Term ntSymbol, std::vector rules) CVC4_API_CHECK(ntSymbol.d_expr->getType() == rules[i].d_expr->getType()) << "Expected ntSymbol and rule at index " << i << " to have the same sort"; - - addSygusConstructorTerm(d_dtDecls[ntSymbol], rules[i]); } + + d_ntsToTerms[ntSymbol].insert( + d_ntsToTerms[ntSymbol].cend(), rules.cbegin(), rules.cend()); } void Grammar::addAnyConstant(Term ntSymbol) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; @@ -2303,17 +2307,21 @@ void Grammar::addAnyConstant(Term ntSymbol) void Grammar::addAnyVariable(Term ntSymbol) { + CVC4_API_CHECK(!d_isResolved) << "Grammar cannot be modified after passing " + "it as an argument to synthFun/synthInv"; CVC4_API_ARG_CHECK_NOT_NULL(ntSymbol); - CVC4_API_ARG_CHECK_EXPECTED(d_dtDecls.find(ntSymbol) != d_dtDecls.end(), - ntSymbol) + CVC4_API_ARG_CHECK_EXPECTED( + d_ntsToTerms.find(ntSymbol) != d_ntsToTerms.cend(), ntSymbol) << "ntSymbol to be one of the non-terminal symbols given in the " "predeclaration"; - addSygusConstructorVariables(d_dtDecls[ntSymbol], ntSymbol.d_expr->getType()); + d_allowVars.insert(ntSymbol); } Sort Grammar::resolve() { + d_isResolved = true; + Term bvl; if (!d_sygusVars.empty()) @@ -2322,29 +2330,48 @@ Sort Grammar::resolve() termVectorToExprs(d_sygusVars)); } - for (const Term& i : d_ntSyms) + std::unordered_map ntsToUnres(d_ntSyms.size()); + + for (Term ntsymbol : d_ntSyms) { - bool aci = d_allowConst.find(i) != d_allowConst.end(); - Type btt = i.d_expr->getType(); - d_dtDecls[i].d_dtype->setSygus(btt, *bvl.d_expr, aci, false); - // We can be in a case where the only rule specified was (Variable T) - // and there are no variables of type T, in which case this is a bogus - // grammar. This results in the error below. - CVC4_API_CHECK(d_dtDecls[i].d_dtype->getNumConstructors() != 0) - << "Grouped rule listing for " << d_dtDecls[i] - << " produced an empty rule list"; + // make the unresolved type, used for referencing the final version of + // the ntsymbol's datatype + ntsToUnres[ntsymbol] = d_s->getExprManager()->mkSort(ntsymbol.toString()); } - // now, make the sygus datatype std::vector datatypes; std::set unresTypes; datatypes.reserve(d_ntSyms.size()); - for (const Term& i : d_ntSyms) + for (const Term& ntSym : d_ntSyms) { - datatypes.push_back(*d_dtDecls[i].d_dtype); - unresTypes.insert(*d_ntsToUnres[i].d_type); + // make the datatype, which encodes terms generated by this non-terminal + DatatypeDecl dtDecl(d_s, ntSym.toString()); + + for (const Term& consTerm : d_ntsToTerms[ntSym]) + { + addSygusConstructorTerm(dtDecl, consTerm, ntsToUnres); + } + + if (d_allowVars.find(ntSym) != d_allowConst.cend()) + { + addSygusConstructorVariables(dtDecl, ntSym.d_expr->getType()); + } + + bool aci = d_allowConst.find(ntSym) != d_allowConst.end(); + Type btt = ntSym.d_expr->getType(); + dtDecl.d_dtype->setSygus(btt, *bvl.d_expr, aci, false); + + // We can be in a case where the only rule specified was (Variable T) + // and there are no variables of type T, in which case this is a bogus + // grammar. This results in the error below. + CVC4_API_CHECK(dtDecl.d_dtype->getNumConstructors() != 0) + << "Grouped rule listing for " << *dtDecl.d_dtype + << " produced an empty rule list"; + + datatypes.push_back(*dtDecl.d_dtype); + unresTypes.insert(*ntsToUnres[ntSym].d_type); } std::vector datatypeTypes = @@ -2355,7 +2382,10 @@ Sort Grammar::resolve() return datatypeTypes[0]; } -void Grammar::addSygusConstructorTerm(DatatypeDecl& dt, Term term) const +void Grammar::addSygusConstructorTerm( + DatatypeDecl& dt, + Term term, + const std::unordered_map& ntsToUnres) const { // At this point, we should know that dt is well founded, and that its // builtin sygus operators are well-typed. @@ -2367,7 +2397,7 @@ void Grammar::addSygusConstructorTerm(DatatypeDecl& dt, Term term) const // this does not lead to exponential behavior with respect to input size. std::vector args; std::vector cargs; - Term op = purifySygusGTerm(term, args, cargs); + Term op = purifySygusGTerm(term, args, cargs, ntsToUnres); std::stringstream ssCName; ssCName << op.getKind(); std::shared_ptr spc; @@ -2386,13 +2416,15 @@ void Grammar::addSygusConstructorTerm(DatatypeDecl& dt, Term term) const *op.d_expr, ssCName.str(), sortVectorToTypes(cargs), spc); } -Term Grammar::purifySygusGTerm(Term term, - std::vector& args, - std::vector& cargs) const +Term Grammar::purifySygusGTerm( + Term term, + std::vector& args, + std::vector& cargs, + const std::unordered_map& ntsToUnres) const { std::unordered_map::const_iterator itn = - d_ntsToUnres.find(term); - if (itn != d_ntsToUnres.cend()) + ntsToUnres.find(term); + if (itn != ntsToUnres.cend()) { Term ret = d_s->getExprManager()->mkBoundVar(term.d_expr->getType()); args.push_back(ret); @@ -2403,7 +2435,7 @@ Term Grammar::purifySygusGTerm(Term term, bool childChanged = false; for (unsigned i = 0, nchild = term.d_expr->getNumChildren(); i < nchild; i++) { - Term ptermc = purifySygusGTerm((*term.d_expr)[i], args, cargs); + Term ptermc = purifySygusGTerm((*term.d_expr)[i], args, cargs, ntsToUnres); pchildren.push_back(ptermc); childChanged = childChanged || *ptermc.d_expr != (*term.d_expr)[i]; } @@ -4495,7 +4527,7 @@ Term Solver::synthFun(const std::string& symbol, Term Solver::synthFun(const std::string& symbol, const std::vector& boundVars, Sort sort, - Grammar g) const + Grammar& g) const { return synthFunHelper(symbol, boundVars, sort, false, &g); } @@ -4508,7 +4540,7 @@ Term Solver::synthInv(const std::string& symbol, Term Solver::synthInv(const std::string& symbol, const std::vector& boundVars, - Grammar g) const + Grammar& g) const { return synthFunHelper(symbol, boundVars, d_exprMgr->booleanType(), true, &g); } diff --git a/src/api/cvc4cpp.h b/src/api/cvc4cpp.h index 5cfc61bbb..279453747 100644 --- a/src/api/cvc4cpp.h +++ b/src/api/cvc4cpp.h @@ -1810,25 +1810,29 @@ class CVC4_PUBLIC Grammar /** * Adds a constructor to sygus datatype
whose sygus operator is . * - * contains a mapping from non-terminal symbols to the + * contains a mapping from non-terminal symbols to the * unresolved sorts they correspond to. This map indicates how the argument * should be interpreted (instances of symbols from the domain of - * correspond to constructor arguments). + * correspond to constructor arguments). * * The sygus operator that is actually added to
corresponds to replacing - * each occurrence of non-terminal symbols from the domain of + * each occurrence of non-terminal symbols from the domain of * with bound variables via purifySygusGTerm, and binding these variables * via a lambda. * * @param dt the non-terminal's datatype to which a constructor is added * @param term the sygus operator of the constructor + * @param ntsToUnres mapping from non-terminals to their unresolved sorts */ - void addSygusConstructorTerm(DatatypeDecl& dt, Term term) const; + void addSygusConstructorTerm( + DatatypeDecl& dt, + Term term, + const std::unordered_map& ntsToUnres) const; /** Purify sygus grammar term * * This returns a term where all occurrences of non-terminal symbols (those - * in the domain of ) are replaced by fresh variables. For + * in the domain of ) are replaced by fresh variables. For * each variable replaced in this way, we add the fresh variable it is * replaced with to , and the unresolved sorts corresponding to the * non-terminal symbol to (constructor args). In other words, @@ -1839,11 +1843,14 @@ class CVC4_PUBLIC Grammar * @param term the term to purify * @param args the free variables in the term returned by this method * @param cargs the sorts of the arguments of the sygus constructor + * @param ntsToUnres mapping from non-terminals to their unresolved sorts * @return the purfied term */ - Term purifySygusGTerm(Term term, - std::vector& args, - std::vector& cargs) const; + Term purifySygusGTerm( + Term term, + std::vector& args, + std::vector& cargs, + const std::unordered_map& ntsToUnres) const; /** * This adds constructors to
for sygus variables in whose @@ -1861,18 +1868,14 @@ class CVC4_PUBLIC Grammar std::vector d_sygusVars; /** The non-terminal symbols of this grammar. */ std::vector d_ntSyms; - /** - * The mapping from non-terminal symbols to the unresolved sorts they - * correspond to. - */ - std::unordered_map d_ntsToUnres; - /** - * The mapping from non-terminal symbols to the datatype declarations they - * correspond to. - */ - std::unordered_map d_dtDecls; + /** The mapping from non-terminal symbols to their production terms. */ + std::unordered_map, TermHashFunction> d_ntsToTerms; /** The set of non-terminals that can be arbitrary constants. */ std::unordered_set d_allowConst; + /** The set of non-terminals that can be sygus variables. */ + std::unordered_set d_allowVars; + /** Did we call resolve() before? */ + bool d_isResolved; }; /* -------------------------------------------------------------------------- */ @@ -2972,7 +2975,7 @@ class CVC4_PUBLIC Solver Term synthFun(const std::string& symbol, const std::vector& boundVars, Sort sort, - Grammar g) const; + Grammar& g) const; /** * Synthesize invariant. @@ -2996,7 +2999,7 @@ class CVC4_PUBLIC Solver */ Term synthInv(const std::string& symbol, const std::vector& boundVars, - Grammar g) const; + Grammar& g) const; /** * Add a forumla to the set of Sygus constraints. diff --git a/test/unit/api/grammar_black.h b/test/unit/api/grammar_black.h index abf37f210..03525f12f 100644 --- a/test/unit/api/grammar_black.h +++ b/test/unit/api/grammar_black.h @@ -58,6 +58,11 @@ void GrammarBlack::testAddRule() TS_ASSERT_THROWS(g.addRule(nts, d_solver->mkBoolean(false)), CVC4ApiException&); TS_ASSERT_THROWS(g.addRule(start, d_solver->mkReal(0)), CVC4ApiException&); + + d_solver->synthFun("f", {}, boolean, g); + + TS_ASSERT_THROWS(g.addRule(start, d_solver->mkBoolean(false)), + CVC4ApiException&); } void GrammarBlack::testAddRules() @@ -75,10 +80,15 @@ void GrammarBlack::testAddRules() TS_ASSERT_THROWS(g.addRules(nullTerm, {d_solver->mkBoolean(false)}), CVC4ApiException&); - TS_ASSERT_THROWS(g.addRule(start, {nullTerm}), CVC4ApiException&); - TS_ASSERT_THROWS(g.addRule(nts, {d_solver->mkBoolean(false)}), + TS_ASSERT_THROWS(g.addRules(start, {nullTerm}), CVC4ApiException&); + TS_ASSERT_THROWS(g.addRules(nts, {d_solver->mkBoolean(false)}), + CVC4ApiException&); + TS_ASSERT_THROWS(g.addRules(start, {d_solver->mkReal(0)}), CVC4ApiException&); + + d_solver->synthFun("f", {}, boolean, g); + + TS_ASSERT_THROWS(g.addRules(start, {d_solver->mkBoolean(false)}), CVC4ApiException&); - TS_ASSERT_THROWS(g.addRule(start, {d_solver->mkReal(0)}), CVC4ApiException&); } void GrammarBlack::testAddAnyConstant() @@ -96,6 +106,10 @@ void GrammarBlack::testAddAnyConstant() TS_ASSERT_THROWS(g.addAnyConstant(nullTerm), CVC4ApiException&); TS_ASSERT_THROWS(g.addAnyConstant(nts), CVC4ApiException&); + + d_solver->synthFun("f", {}, boolean, g); + + TS_ASSERT_THROWS(g.addAnyConstant(start), CVC4ApiException&); } void GrammarBlack::testAddAnyVariable() @@ -114,6 +128,10 @@ void GrammarBlack::testAddAnyVariable() TS_ASSERT_THROWS_NOTHING(g1.addAnyVariable(start)); TS_ASSERT_THROWS_NOTHING(g2.addAnyVariable(start)); - TS_ASSERT_THROWS(g1.addAnyConstant(nullTerm), CVC4ApiException&); - TS_ASSERT_THROWS(g1.addAnyConstant(nts), CVC4ApiException&); + TS_ASSERT_THROWS(g1.addAnyVariable(nullTerm), CVC4ApiException&); + TS_ASSERT_THROWS(g1.addAnyVariable(nts), CVC4ApiException&); + + d_solver->synthFun("f", {}, boolean, g1); + + TS_ASSERT_THROWS(g1.addAnyVariable(start), CVC4ApiException&); } -- 2.30.2