From 6fe7877d82721e453d5d928a8fe9dbad2099dac1 Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Mon, 18 Nov 2019 13:52:18 -0600 Subject: [PATCH] Use standard sygus interface for abduction and rewrite rule synthesis (#3471) --- src/preprocessing/passes/synth_rew_rules.cpp | 77 ++++++++++--------- src/theory/quantifiers/sygus/sygus_abduct.cpp | 29 ++++--- 2 files changed, 58 insertions(+), 48 deletions(-) diff --git a/src/preprocessing/passes/synth_rew_rules.cpp b/src/preprocessing/passes/synth_rew_rules.cpp index 6d6e8fb27..47e64b2e4 100644 --- a/src/preprocessing/passes/synth_rew_rules.cpp +++ b/src/preprocessing/passes/synth_rew_rules.cpp @@ -15,6 +15,7 @@ #include "preprocessing/passes/synth_rew_rules.h" +#include "expr/sygus_datatype.h" #include "expr/term_canonize.h" #include "options/base_options.h" #include "options/quantifiers_options.h" @@ -236,14 +237,13 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( Trace("srs-input") << "...finished." << std::endl; // the sygus variable list Node sygusVarList = nm->mkNode(BOUND_VAR_LIST, allVars); - Expr sygusVarListE = sygusVarList.toExpr(); Trace("srs-input") << "Have " << cterms.size() << " canonical subterms." << std::endl; Trace("srs-input") << "Construct unresolved types..." << std::endl; // each canonical subterm corresponds to a grammar type std::set unres; - std::vector datatypes; + std::vector sdts; // make unresolved types for each canonical term std::map cterm_to_utype; for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) @@ -255,11 +255,11 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( TypeNode tnu = nm->mkSort(tname, ExprManager::SORT_FLAG_PLACEHOLDER); cterm_to_utype[ct] = tnu; unres.insert(tnu.toType()); - datatypes.push_back(Datatype(tname)); + sdts.push_back(SygusDatatype(tname)); } Trace("srs-input") << "...finished." << std::endl; - Trace("srs-input") << "Construct datatypes..." << std::endl; + Trace("srs-input") << "Construct sygus datatypes..." << std::endl; for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) { Node ct = cterms[i]; @@ -268,7 +268,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( // add the variables for the type TypeNode ctt = ct.getType(); Assert(tvars.find(ctt) != tvars.end()); - std::vector argList; + std::vector argList; // we add variable constructors if we are not Boolean, we are interested // in purely propositional rewrites (via the option), or this term is // a Boolean variable. @@ -279,7 +279,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( { std::stringstream ssc; ssc << "C_" << i << "_" << v; - datatypes[i].addSygusConstructor(v.toExpr(), ssc.str(), argList); + sdts[i].addConstructor(v, ssc.str(), argList); } } // add the constructor for the operator if it is not a variable @@ -295,7 +295,7 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( Node ctc = term_to_cterm[tc]; Assert(cterm_to_utype.find(ctc) != cterm_to_utype.end()); // get the type - argList.push_back(cterm_to_utype[ctc].toType()); + argList.push_back(cterm_to_utype[ctc]); } // check if we should chain bool do_chain = false; @@ -305,12 +305,12 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( do_chain = theory::quantifiers::TermUtil::isAssoc(k) && theory::quantifiers::TermUtil::isComm(k); // eliminate duplicate child types - std::vector argListTmp = argList; + std::vector argListTmp = argList; argList.clear(); - std::map hasArgType; + std::map hasArgType; for (unsigned j = 0, size = argListTmp.size(); j < size; j++) { - Type t = argListTmp[j]; + TypeNode t = argListTmp[j]; if (hasArgType.find(t) == hasArgType.end()) { hasArgType[t] = true; @@ -323,9 +323,9 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( // we make one type per child // the operator of each constructor is a no-op Node tbv = nm->mkBoundVar(ctt); - Expr lambdaOp = - nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr(); - std::vector argListc; + Node lambdaOp = + nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv); + std::vector argListc; // the following construction admits any number of repeated factors, // so for instance, t1+t2+t3, we generate the grammar: // T_{t1+t2+t3} -> @@ -341,44 +341,49 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( std::stringstream sscs; sscs << "C_factor_" << i << "_" << j; // ID function is not printed and does not count towards weight - datatypes[i].addSygusConstructor( - lambdaOp, - sscs.str(), - argListc, - printer::SygusEmptyPrintCallback::getEmptyPC(), - 0); + sdts[i].addConstructor(lambdaOp, + sscs.str(), + argListc, + printer::SygusEmptyPrintCallback::getEmptyPC(), + 0); } // recursive apply - Type recType = cterm_to_utype[ct].toType(); + TypeNode recType = cterm_to_utype[ct]; argListc.clear(); argListc.push_back(recType); argListc.push_back(recType); std::stringstream ssc; ssc << "C_" << i << "_rec_" << op; - datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argListc); + sdts[i].addConstructor(op, ssc.str(), argListc); } else { std::stringstream ssc; ssc << "C_" << i << "_" << op; - datatypes[i].addSygusConstructor(op.toExpr(), ssc.str(), argList); + sdts[i].addConstructor(op, ssc.str(), argList); } } - Assert(datatypes[i].getNumConstructors() > 0); - datatypes[i].setSygus(ctt.toType(), sygusVarListE, false, false); + Assert(sdts[i].getNumConstructors() > 0); + sdts[i].initializeDatatype(ctt, sygusVarList, false, false); } Trace("srs-input") << "...finished." << std::endl; Trace("srs-input") << "Make mutual datatype types for subterms..." << std::endl; + // extract the datatypes + std::vector datatypes; + for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++) + { + datatypes.push_back(sdts[i].getDatatype()); + } std::vector types = nm->toExprManager()->mkMutualDatatypeTypes( datatypes, unres, ExprManager::DATATYPE_FLAG_PLACEHOLDER); Trace("srs-input") << "...finished." << std::endl; Assert(types.size() == unres.size()); - std::map subtermTypes; + std::map subtermTypes; for (unsigned i = 0, ncterms = cterms.size(); i < ncterms; i++) { - subtermTypes[cterms[i]] = types[i]; + subtermTypes[cterms[i]] = TypeNode::fromType(types[i]); } Trace("srs-input") << "Construct the top-level types..." << std::endl; @@ -389,34 +394,34 @@ PreprocessingPassResult SynthRewRulesPass::applyInternal( TypeNode t = tcp.first; std::stringstream ss; ss << "T_" << t; - Datatype dttl(ss.str()); + SygusDatatype sdttl(ss.str()); Node tbv = nm->mkBoundVar(t); // the operator of each constructor is a no-op - Expr lambdaOp = - nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv).toExpr(); + Node lambdaOp = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, tbv), tbv); Trace("srs-input") << " We have " << tcp.second.size() << " subterms of type " << t << std::endl; for (unsigned i = 0, size = tcp.second.size(); i < size; i++) { Node n = tcp.second[i]; // add constructor that encodes abstractions of this subterm - std::vector argList; + std::vector argList; Assert(subtermTypes.find(n) != subtermTypes.end()); argList.push_back(subtermTypes[n]); std::stringstream ssc; ssc << "Ctl_" << i; // the no-op should not be printed, hence we pass an empty callback - dttl.addSygusConstructor(lambdaOp, - ssc.str(), - argList, - printer::SygusEmptyPrintCallback::getEmptyPC(), - 0); + sdttl.addConstructor(lambdaOp, + ssc.str(), + argList, + printer::SygusEmptyPrintCallback::getEmptyPC(), + 0); Trace("srs-input-debug") << "Grammar for subterm " << n << " is: " << std::endl; Trace("srs-input-debug") << subtermTypes[n].getDatatype() << std::endl; } // set that this is a sygus datatype - dttl.setSygus(t.toType(), sygusVarListE, false, false); + sdttl.initializeDatatype(t, sygusVarList, false, false); + Datatype dttl = sdttl.getDatatype(); DatatypeType tlt = nm->toExprManager()->mkDatatypeType( dttl, ExprManager::DATATYPE_FLAG_PLACEHOLDER); tlGrammarTypes[t] = TypeNode::fromType(tlt); diff --git a/src/theory/quantifiers/sygus/sygus_abduct.cpp b/src/theory/quantifiers/sygus/sygus_abduct.cpp index 529ef037f..0396aba86 100644 --- a/src/theory/quantifiers/sygus/sygus_abduct.cpp +++ b/src/theory/quantifiers/sygus/sygus_abduct.cpp @@ -17,6 +17,7 @@ #include "expr/datatype.h" #include "expr/node_algorithm.h" +#include "expr/sygus_datatype.h" #include "printer/sygus_print_callback.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/quantifiers/quantifiers_rewriter.h" @@ -86,7 +87,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, { Assert(abdGType.isDatatype() && abdGType.getDatatype().isSygus()); // must convert all constructors to version with bound variables in "vars" - std::vector datatypes; + std::vector sdts; std::set unres; Trace("sygus-abduct-debug") << "Process abduction type:" << std::endl; @@ -129,9 +130,9 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, const Datatype& dtc = curr.getDatatype(); std::stringstream ssdtn; ssdtn << dtc.getName() << "_s"; - datatypes.push_back(Datatype(ssdtn.str())); + sdts.push_back(SygusDatatype(ssdtn.str())); Trace("sygus-abduct-debug") - << "Process datatype " << datatypes.back().getName() << "..." + << "Process datatype " << sdts.back().getName() << "..." << std::endl; for (unsigned j = 0, ncons = dtc.getNumConstructors(); j < ncons; j++) { @@ -141,7 +142,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, syms.begin(), syms.end(), varlist.begin(), varlist.end()); Trace("sygus-abduct-debug") << " Process constructor " << op << " / " << ops << "..." << std::endl; - std::vector cargs; + std::vector cargs; for (unsigned k = 0, nargs = dtc[j].getNumArgs(); k < nargs; k++) { TypeNode argt = TypeNode::fromType(dtc[j].getArgType(k)); @@ -167,7 +168,7 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, } Trace("sygus-abduct-debug") << " Arg #" << k << ": " << argtNew << std::endl; - cargs.push_back(argtNew.toType()); + cargs.push_back(argtNew); } // callback prints as the expression std::shared_ptr spc; @@ -191,22 +192,26 @@ Node SygusAbduct::mkAbductionConjecture(const std::string& name, ss << ops.getKind(); Trace("sygus-abduct-debug") << "Add constructor : " << ops << std::endl; - datatypes.back().addSygusConstructor( - ops.toExpr(), ss.str(), cargs, spc); + sdts.back().addConstructor(ops, ss.str(), cargs, spc); } Trace("sygus-abduct-debug") << "Set sygus : " << dtc.getSygusType() << " " << abvl << std::endl; - datatypes.back().setSygus(dtc.getSygusType(), - abvl.toExpr(), - dtc.getSygusAllowConst(), - dtc.getSygusAllowAll()); + TypeNode stn = TypeNode::fromType(dtc.getSygusType()); + sdts.back().initializeDatatype( + stn, abvl, dtc.getSygusAllowConst(), dtc.getSygusAllowAll()); } dtToProcess.clear(); dtToProcess.insert( dtToProcess.end(), dtNextToProcess.begin(), dtNextToProcess.end()); } Trace("sygus-abduct-debug") - << "Make " << datatypes.size() << " datatype types..." << std::endl; + << "Make " << sdts.size() << " datatype types..." << std::endl; + // extract the datatypes + std::vector datatypes; + for (unsigned i = 0, ndts = sdts.size(); i < ndts; i++) + { + datatypes.push_back(sdts[i].getDatatype()); + } // make the datatype types std::vector datatypeTypes = nm->toExprManager()->mkMutualDatatypeTypes( -- 2.30.2