From: Andrew Reynolds Date: Thu, 26 May 2022 20:16:15 +0000 (-0500) Subject: Use function array constants in HO solver (#8818) X-Git-Tag: cvc5-1.0.1~92 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=a318303646773b10d0f7ef1387d78e0aa29b6ade;p=cvc5.git Use function array constants in HO solver (#8818) This makes lambdas rewrite to function array constants when possible. This extends our HO solver and utilities to be robust to check whether a node represents a lambda (uf::FunctionConst::toLambda). This furthermore removes the isConst rule for LAMBDA; lambdas are never constant. The PR also improves our check-model so that warnings are not thrown if rewriting can show that the model value of a term is equivalent modulo rewriting to its representative in the model equality engine. This eliminates the last remaining static calls to rewrite. This is work towards eliminating SmtEngineScope. --- diff --git a/src/expr/array_store_all.h b/src/expr/array_store_all.h index ddcdaa170..24e1257b1 100644 --- a/src/expr/array_store_all.h +++ b/src/expr/array_store_all.h @@ -16,8 +16,8 @@ #include "cvc5_public.h" -#ifndef CVC5__ARRAY_STORE_ALL_H -#define CVC5__ARRAY_STORE_ALL_H +#ifndef CVC5__EXPR__ARRAY_STORE_ALL_H +#define CVC5__EXPR__ARRAY_STORE_ALL_H #include #include diff --git a/src/preprocessing/passes/ho_elim.cpp b/src/preprocessing/passes/ho_elim.cpp index 232e0a2c7..e83315d81 100644 --- a/src/preprocessing/passes/ho_elim.cpp +++ b/src/preprocessing/passes/ho_elim.cpp @@ -24,6 +24,7 @@ #include "options/quantifiers_options.h" #include "preprocessing/assertion_pipeline.h" #include "theory/rewriter.h" +#include "theory/uf/function_const.h" #include "theory/uf/theory_uf_rewriter.h" using namespace cvc5::internal::kind; @@ -51,17 +52,18 @@ Node HoElim::eliminateLambdaComplete(Node n, std::map& newLambda) if (it == d_visited.end()) { - if (cur.getKind() == LAMBDA) + Node lam = theory::uf::FunctionConst::toLambda(cur); + if (!lam.isNull()) { - Trace("ho-elim-ll") << "Lambda lift: " << cur << std::endl; + Trace("ho-elim-ll") << "Lambda lift: " << lam << std::endl; // must also get free variables in lambda std::vector lvars; std::vector ftypes; std::unordered_set fvs; - expr::getFreeVariables(cur, fvs); + expr::getFreeVariables(lam, fvs); std::vector nvars; std::vector vars; - Node sbd = cur[1]; + Node sbd = lam[1]; if (!fvs.empty()) { Trace("ho-elim-ll") @@ -78,20 +80,20 @@ Node HoElim::eliminateLambdaComplete(Node n, std::map& newLambda) sbd = sbd.substitute( vars.begin(), vars.end(), nvars.begin(), nvars.end()); } - for (const Node& bv : cur[0]) + for (const Node& bv : lam[0]) { TypeNode bvt = bv.getType(); ftypes.push_back(bvt); lvars.push_back(bv); } - Node nlambda = cur; + Node nlambda = lam; if (!fvs.empty()) { nlambda = nm->mkNode(LAMBDA, nm->mkNode(BOUND_VAR_LIST, lvars), sbd); Trace("ho-elim-ll") << "...new lambda definition: " << nlambda << std::endl; } - TypeNode rangeType = cur.getType().getRangeType(); + TypeNode rangeType = lam.getType().getRangeType(); TypeNode nft = nm->mkFunctionType(ftypes, rangeType); Node nf = sm->mkDummySkolem("ll", nft); Trace("ho-elim-ll") diff --git a/src/printer/smt2/smt2_printer.cpp b/src/printer/smt2/smt2_printer.cpp index 279046300..ff13b0600 100644 --- a/src/printer/smt2/smt2_printer.cpp +++ b/src/printer/smt2/smt2_printer.cpp @@ -29,6 +29,7 @@ #include "expr/dtype_cons.h" #include "expr/emptybag.h" #include "expr/emptyset.h" +#include "expr/function_array_const.h" #include "expr/node_manager_attributes.h" #include "expr/node_visitor.h" #include "expr/sequence.h" @@ -40,12 +41,13 @@ #include "printer/let_binding.h" #include "proof/unsat_core.h" #include "smt/command.h" -#include "theory/bags/table_project_op.h" #include "theory/arrays/theory_arrays_rewriter.h" +#include "theory/bags/table_project_op.h" #include "theory/datatypes/sygus_datatype_utils.h" #include "theory/datatypes/tuple_project_op.h" #include "theory/quantifiers/quantifiers_attributes.h" #include "theory/theory_model.h" +#include "theory/uf/function_const.h" #include "util/bitvector.h" #include "util/divisible.h" #include "util/floatingpoint.h" @@ -309,6 +311,13 @@ void Smt2Printer::toStream(std::ostream& out, out << ")"; break; } + case kind::FUNCTION_ARRAY_CONST: + { + // prints as the equivalent lambda + Node lam = theory::uf::FunctionConst::toLambda(n); + toStream(out, lam, toDepth); + break; + } case kind::UNINTERPRETED_SORT_VALUE: { diff --git a/src/proof/lfsc/lfsc_node_converter.cpp b/src/proof/lfsc/lfsc_node_converter.cpp index ee644e2d4..738dac6b8 100644 --- a/src/proof/lfsc/lfsc_node_converter.cpp +++ b/src/proof/lfsc/lfsc_node_converter.cpp @@ -31,6 +31,7 @@ #include "theory/bv/theory_bv_utils.h" #include "theory/datatypes/datatypes_rewriter.h" #include "theory/strings/word.h" +#include "theory/uf/function_const.h" #include "theory/uf/theory_uf_rewriter.h" #include "util/bitvector.h" #include "util/floatingpoint.h" @@ -369,6 +370,13 @@ Node LfscNodeConverter::postConvert(Node n) // notice that intentionally we drop annotations here return ret; } + else if (k == FUNCTION_ARRAY_CONST) + { + // must convert to lambda and then run the conversion + Node lam = theory::uf::FunctionConst::toLambda(n); + Assert(!lam.isNull()); + return convert(lam); + } else if (k == REGEXP_LOOP) { // ((_ re.loop n1 n2) t) is ((re.loop n1 n2) t) diff --git a/src/theory/evaluator.cpp b/src/theory/evaluator.cpp index 7caca8427..efe985d78 100644 --- a/src/theory/evaluator.cpp +++ b/src/theory/evaluator.cpp @@ -19,6 +19,7 @@ #include "theory/rewriter.h" #include "theory/strings/theory_strings_utils.h" #include "theory/theory.h" +#include "theory/uf/function_const.h" #include "util/integer.h" using namespace cvc5::internal::kind; @@ -330,9 +331,18 @@ EvalResult Evaluator::evalInternal( { Trace("evaluator") << "Evaluate " << currNode << std::endl; TNode op = currNode.getOperator(); - Assert(evalAsNode.find(op) != evalAsNode.end()); - // no function can be a valid EvalResult - op = evalAsNode[op]; + if (op.getKind() == kind::FUNCTION_ARRAY_CONST) + { + // If we have a function constant as the operator, it was not + // processed. We require converting to a lambda now. + op = uf::FunctionConst::toLambda(op); + } + else + { + Assert(evalAsNode.find(op) != evalAsNode.end()); + // no function can be a valid EvalResult + op = evalAsNode[op]; + } Trace("evaluator") << "Operator evaluated to " << op << std::endl; if (op.getKind() != kind::LAMBDA) { @@ -362,9 +372,10 @@ EvalResult Evaluator::evalInternal( // Lambdas are evaluated in a recursive fashion because each // evaluation requires different substitutions. We use a fresh cache - // since the evaluation of op[1] is under a new substitution and thus - // should not be cached. We could alternatively copy evalAsNode to - // evalAsNodeC but favor avoiding this copy for performance reasons. + // since the evaluation of op[1] is under a new substitution and + // thus should not be cached. We could alternatively copy evalAsNode + // to evalAsNodeC but favor avoiding this copy for performance + // reasons. std::unordered_map evalAsNodeC; std::unordered_map resultsC; results[currNode] = evalInternal( diff --git a/src/theory/quantifiers/oracle_checker.cpp b/src/theory/quantifiers/oracle_checker.cpp index 13a2e7630..b1fb3aec5 100644 --- a/src/theory/quantifiers/oracle_checker.cpp +++ b/src/theory/quantifiers/oracle_checker.cpp @@ -109,7 +109,7 @@ Node OracleChecker::postConvert(Node n) } } // otherwise, always rewrite - return Rewriter::rewrite(n); + return rewrite(n); } bool OracleChecker::hasOracles() const { return !d_callers.empty(); } bool OracleChecker::hasOracleCalls(Node f) const diff --git a/src/theory/strings/term_registry.cpp b/src/theory/strings/term_registry.cpp index 728f4b047..1e7da4c70 100644 --- a/src/theory/strings/term_registry.cpp +++ b/src/theory/strings/term_registry.cpp @@ -15,7 +15,6 @@ #include "theory/strings/term_registry.h" -#include "expr/attribute.h" #include "options/smt_options.h" #include "options/strings_options.h" #include "printer/smt2/smt2_printer.h" diff --git a/src/theory/theory_model.cpp b/src/theory/theory_model.cpp index 52c78151c..d33f81fe7 100644 --- a/src/theory/theory_model.cpp +++ b/src/theory/theory_model.cpp @@ -24,6 +24,7 @@ #include "smt/env.h" #include "smt/solver_engine.h" #include "theory/trust_substitutions.h" +#include "theory/uf/function_const.h" #include "util/rational.h" using namespace std; @@ -138,7 +139,12 @@ Node TheoryModel::getValue(TNode n) const { return nn; } - else if (nn.getKind() == kind::LAMBDA) + if (nn.getKind() == kind::FUNCTION_ARRAY_CONST) + { + // return the lambda instead + nn = uf::FunctionConst::toLambda(nn); + } + if (nn.getKind() == kind::LAMBDA) { if (options().theory.condenseFunctionValues) { diff --git a/src/theory/theory_model_builder.cpp b/src/theory/theory_model_builder.cpp index c29d89fde..c83ce2b63 100644 --- a/src/theory/theory_model_builder.cpp +++ b/src/theory/theory_model_builder.cpp @@ -23,6 +23,7 @@ #include "options/uf_options.h" #include "smt/env.h" #include "theory/rewriter.h" +#include "theory/uf/function_const.h" #include "theory/uf/theory_uf_model.h" #include "util/uninterpreted_sort_value.h" @@ -1171,23 +1172,24 @@ void TheoryEngineModelBuilder::debugCheckModel(TheoryModel* tm) << "Representative " << rep << " of " << n << " violates type constraints (" << rep.getType() << " and " << n.getType() << ")"; - Node val = tm->getValue(*eqc_i); + Node val = tm->getValue(n); if (val != rep) { std::stringstream err; err << "Failed representative check:" << std::endl << "( " << repCheckInstance << ") " - << "n: " << n << endl - << "getValue(n): " << tm->getValue(n) << std::endl + << "n: " << n << std::endl + << "getValue(n): " << val << std::endl << "rep: " << rep << std::endl; if (val.isConst() && rep.isConst()) { AlwaysAssert(val == rep) << err.str(); } - else + else if (rewrite(val) != rewrite(rep)) { // if it does not evaluate, it is just a warning, which may be the - // case for non-constant values, e.g. lambdas. + // case for non-constant values, e.g. lambdas. Furthermore we only + // throw this warning if rewriting cannot show they are equal. warning() << err.str(); } } @@ -1359,7 +1361,10 @@ void TheoryEngineModelBuilder::assignHoFunction(TheoryModel* m, Node f) Assert(hnv.isConst()); if (!apply_args.empty()) { - Assert(hnv.getKind() == kind::LAMBDA + // Convert to lambda, which is necessary if hnv is a function array + // constant. + hnv = uf::FunctionConst::toLambda(hnv); + Assert(!hnv.isNull() && hnv.getKind() == kind::LAMBDA && hnv[0].getNumChildren() + 1 == args.size()); std::vector largs; for (unsigned j = 0; j < hnv[0].getNumChildren(); j++) diff --git a/src/theory/uf/function_const.cpp b/src/theory/uf/function_const.cpp index 27b39018a..b2bde1306 100644 --- a/src/theory/uf/function_const.cpp +++ b/src/theory/uf/function_const.cpp @@ -16,13 +16,73 @@ #include "theory/uf/function_const.h" #include "expr/array_store_all.h" +#include "expr/attribute.h" +#include "expr/bound_var_manager.h" +#include "expr/function_array_const.h" #include "theory/arrays/theory_arrays_rewriter.h" #include "theory/rewriter.h" +#include "util/rational.h" namespace cvc5::internal { namespace theory { namespace uf { +/** + * Attribute for constructing a unique bound variable list for the lambda + * corresponding to an array constant. + */ +struct FunctionBoundVarListTag +{ +}; +using FunctionBoundVarListAttribute = + expr::Attribute; +/** + * An attribute to cache the conversion between array constants and lambdas. + */ +struct ArrayToLambdaTag +{ +}; +using ArrayToLambdaAttribute = expr::Attribute; + +Node FunctionConst::toLambda(TNode n) +{ + Kind nk = n.getKind(); + if (nk == kind::LAMBDA) + { + return n; + } + else if (nk == kind::FUNCTION_ARRAY_CONST) + { + ArrayToLambdaAttribute atla; + if (n.hasAttribute(atla)) + { + return n.getAttribute(atla); + } + const FunctionArrayConst& fc = n.getConst(); + Node avalue = fc.getArrayValue(); + TypeNode tn = fc.getType(); + Assert(tn.isFunction()); + std::vector argTypes = tn.getArgTypes(); + std::vector bvs; + NodeManager* nm = NodeManager::currentNM(); + BoundVarManager* bvm = nm->getBoundVarManager(); + // associate a unique bound variable list with the value + for (size_t i = 0, nargs = argTypes.size(); i < nargs; i++) + { + Node cacheVal = + BoundVarManager::getCacheValue(n, nm->mkConstInt(Rational(i))); + Node v = + bvm->mkBoundVar(cacheVal, argTypes[i]); + bvs.push_back(v); + } + Node bvl = nm->mkNode(kind::BOUND_VAR_LIST, bvs); + Node lam = getLambdaForArrayRepresentation(avalue, bvl); + n.setAttribute(atla, lam); + return lam; + } + return Node::null(); +} + TypeNode FunctionConst::getFunctionTypeForArrayType(TypeNode atn, Node bvl) { std::vector children; @@ -112,7 +172,6 @@ Node FunctionConst::getLambdaForArrayRepresentation(TNode a, TNode bvl) Node body = getLambdaForArrayRepresentationRec(a, bvl, 0, visited); if (!body.isNull()) { - body = Rewriter::rewrite(body); Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl; return NodeManager::currentNM()->mkNode(kind::LAMBDA, bvl, body); @@ -269,13 +328,6 @@ Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n, return Node::null(); } } - else if (Rewriter::rewrite(index_eq) != index_eq) - { - // equality must be oriented correctly based on rewriter - Trace("builtin-rewrite-debug2") - << " ...equality not oriented properly." << std::endl; - return Node::null(); - } // [3] We ensure that "index_eq" is an equality that is equivalent to // "first_arg" = "curr_index", where curr_index is a constant, and @@ -395,13 +447,22 @@ Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n, return Node::null(); } -Node FunctionConst::getArrayRepresentationForLambda(TNode n) +Node FunctionConst::toArrayConst(TNode n) { - Assert(n.getKind() == kind::LAMBDA); - // must carry the overall return type to deal with cases like (lambda ((x Int) - // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else - // case above should be (arraystoreall (Array Int Real) 0.0) - return getArrayRepresentationForLambdaRec(n, n[1].getType()); + Kind nk = n.getKind(); + if (nk == kind::FUNCTION_ARRAY_CONST) + { + const FunctionArrayConst& fc = n.getConst(); + return fc.getArrayValue(); + } + else if (nk == kind::LAMBDA) + { + // must carry the overall return type to deal with cases like (lambda ((x + // Int) (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for + // the else case above should be (arraystoreall (Array Int Real) 0.0) + return getArrayRepresentationForLambdaRec(n, n[1].getType()); + } + return Node::null(); } } // namespace uf diff --git a/src/theory/uf/function_const.h b/src/theory/uf/function_const.h index b74503caf..02324194e 100644 --- a/src/theory/uf/function_const.h +++ b/src/theory/uf/function_const.h @@ -53,6 +53,35 @@ class FunctionConst * getArrayRepresentationForLambda( t ), where t.getType()=ftn. */ static TypeNode getArrayTypeForFunctionType(TypeNode ftn); + /** + * Returns a node of kind LAMBDA that is equivalent to n, or null otherwise. + * + * This is the identity function for lambda terms and runs the conversion + * for constant array functions, and null for all other nodes. For details, + * see the method getLambdaForArrayRepresentation. + */ + static Node toLambda(TNode n); + /** + * Extracts the array constant from the payload of a a function array constant + * + * + * Given a lambda expression n, returns an array term that corresponds to n. + * This does the opposite direction of the examples described above the + * method getLambdaForArrayRepresentation. + * + * We limit the return values of this method to be almost constant functions, + * that is, arrays of the form: + * (store ... (store (storeall _ b) i1 e1) ... in en) + * where b, i1, e1, ..., in, en are constants. + * Notice however that the return value of this form need not be an + * array such that isConst is true. + * + * If it is not possible to construct an array of this form that corresponds + * to n, this method returns null. + */ + static Node toArrayConst(TNode n); + + private: /** * Given an array constant a, returns a lambda expression that it corresponds * to, with bound variable list bvl. @@ -76,30 +105,13 @@ class FunctionConst * (lambda x. (ite (= x 1) true (= x 2))) */ static Node getLambdaForArrayRepresentation(TNode a, TNode bvl); - /** - * Given a lambda expression n, returns an array term that corresponds to n. - * This does the opposite direction of the examples described above. - * - * We limit the return values of this method to be almost constant functions, - * that is, arrays of the form: - * (store ... (store (storeall _ b) i1 e1) ... in en) - * where b, i1, e1, ..., in, en are constants. - * Notice however that the return value of this form need not be a (canonical) - * array constant. - * - * If it is not possible to construct an array of this form that corresponds - * to n, this method returns null. - */ - static Node getArrayRepresentationForLambda(TNode n); - - private: /** recursive helper for getLambdaForArrayRepresentation */ static Node getLambdaForArrayRepresentationRec( TNode a, TNode bvl, unsigned bvlIndex, std::unordered_map& visited); - /** recursive helper for getArrayRepresentationForLambda */ + /** recursive helper for toArrayConst */ static Node getArrayRepresentationForLambdaRec(TNode n, TypeNode retType); }; diff --git a/src/theory/uf/ho_extension.cpp b/src/theory/uf/ho_extension.cpp index 521ab4f9c..b0359fa48 100644 --- a/src/theory/uf/ho_extension.cpp +++ b/src/theory/uf/ho_extension.cpp @@ -19,6 +19,7 @@ #include "expr/skolem_manager.h" #include "options/uf_options.h" #include "theory/theory_model.h" +#include "theory/uf/function_const.h" #include "theory/uf/lambda_lift.h" #include "theory/uf/theory_uf_rewriter.h" @@ -100,7 +101,7 @@ TrustNode HoExtension::ppRewrite(Node node, std::vector& lems) } } } - else if (k == kind::LAMBDA) + else if (k == kind::LAMBDA || k == kind::FUNCTION_ARRAY_CONST) { Trace("uf-lazy-ll") << "Preprocess lambda: " << node << std::endl; TrustNode skTrn = d_ll.ppRewrite(node, lems); diff --git a/src/theory/uf/kinds b/src/theory/uf/kinds index 304679df2..0837f2902 100644 --- a/src/theory/uf/kinds +++ b/src/theory/uf/kinds @@ -33,9 +33,6 @@ typerule LAMBDA ::cvc5::internal::theory::uf::LambdaTypeRule variable BOOLEAN_TERM_VARIABLE "Boolean term variable" -# lambda expressions that are isomorphic to array constants can be considered constants -construle LAMBDA ::cvc5::internal::theory::uf::LambdaTypeRule - operator HO_APPLY 2 "higher-order (partial) function application" typerule HO_APPLY ::cvc5::internal::theory::uf::HoApplyTypeRule diff --git a/src/theory/uf/lambda_lift.cpp b/src/theory/uf/lambda_lift.cpp index e9313278c..7e1823dfc 100644 --- a/src/theory/uf/lambda_lift.cpp +++ b/src/theory/uf/lambda_lift.cpp @@ -19,6 +19,7 @@ #include "expr/skolem_manager.h" #include "options/uf_options.h" #include "smt/env.h" +#include "theory/uf/function_const.h" using namespace cvc5::internal::kind; @@ -59,15 +60,16 @@ TrustNode LambdaLift::lift(Node node) TrustNode LambdaLift::ppRewrite(Node node, std::vector& lems) { - TNode skolem = getSkolemFor(node); + Node lam = FunctionConst::toLambda(node); + TNode skolem = getSkolemFor(lam); if (skolem.isNull()) { return TrustNode::null(); } - d_lambdaMap[skolem] = node; + d_lambdaMap[skolem] = lam; if (!options().uf.ufHoLazyLambdaLift) { - TrustNode trn = lift(node); + TrustNode trn = lift(lam); lems.push_back(SkolemLemma(trn, skolem)); } // if no proofs, return lemma with no generator @@ -102,21 +104,21 @@ Node LambdaLift::getAssertionFor(TNode node) { return Node::null(); } - Kind k = node.getKind(); Node assertion; - if (k == LAMBDA) + Node lambda = FunctionConst::toLambda(node); + if (!lambda.isNull()) { NodeManager* nm = NodeManager::currentNM(); // The new assertion std::vector children; // bound variable list - children.push_back(node[0]); + children.push_back(lambda[0]); // body std::vector skolem_app_c; skolem_app_c.push_back(skolem); - skolem_app_c.insert(skolem_app_c.end(), node[0].begin(), node[0].end()); + skolem_app_c.insert(skolem_app_c.end(), lambda[0].begin(), lambda[0].end()); Node skolem_app = nm->mkNode(APPLY_UF, skolem_app_c); - skolem_app_c[0] = node; + skolem_app_c[0] = lambda; Node rhs = nm->mkNode(APPLY_UF, skolem_app_c); // For the sake of proofs, we use // (= (k t1 ... tn) ((lambda (x1 ... xn) s) t1 ... tn)) here. This is instead of diff --git a/src/theory/uf/theory_uf_rewriter.cpp b/src/theory/uf/theory_uf_rewriter.cpp index 0f326bdd0..73c36ade7 100644 --- a/src/theory/uf/theory_uf_rewriter.cpp +++ b/src/theory/uf/theory_uf_rewriter.cpp @@ -15,6 +15,7 @@ #include "theory/uf/theory_uf_rewriter.h" +#include "expr/function_array_const.h" #include "expr/node_algorithm.h" #include "theory/rewriter.h" #include "theory/substitutions.h" @@ -53,11 +54,11 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node) } if (node.getKind() == kind::APPLY_UF) { - if (node.getOperator().getKind() == kind::LAMBDA) + Node lambda = FunctionConst::toLambda(node.getOperator()); + if (!lambda.isNull()) { - Trace("uf-ho-beta") << "uf-ho-beta : beta-reducing all args of : " << node - << "\n"; - TNode lambda = node.getOperator(); + Trace("uf-ho-beta") << "uf-ho-beta : beta-reducing all args of : " + << lambda << " for " << node << "\n"; Node ret; // build capture-avoiding substitution since in HOL shadowing may have // been introduced @@ -102,17 +103,18 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node) } else if (node.getKind() == kind::HO_APPLY) { - if (node[0].getKind() == kind::LAMBDA) + Node lambda = FunctionConst::toLambda(node[0]); + if (!lambda.isNull()) { // resolve one argument of the lambda Trace("uf-ho-beta") << "uf-ho-beta : beta-reducing one argument of : " - << node[0] << " with " << node[1] << "\n"; + << lambda << " with " << node[1] << "\n"; // reconstruct the lambda first to avoid variable shadowing - Node new_body = node[0][1]; - if (node[0][0].getNumChildren() > 1) + Node new_body = lambda[1]; + if (lambda[0].getNumChildren() > 1) { - std::vector new_vars(node[0][0].begin() + 1, node[0][0].end()); + std::vector new_vars(lambda[0].begin() + 1, lambda[0].end()); std::vector largs; largs.push_back( NodeManager::currentNM()->mkNode(kind::BOUND_VAR_LIST, new_vars)); @@ -127,13 +129,13 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node) if (d_isHigherOrder) { Node arg = node[1]; - Node var = node[0][0][0]; + Node var = lambda[0][0]; new_body = expr::substituteCaptureAvoiding(new_body, var, arg); } else { TNode arg = node[1]; - TNode var = node[0][0][0]; + TNode var = lambda[0][0]; new_body = new_body.substitute(var, arg); } Trace("uf-ho-beta") << "uf-ho-beta : ..new body : " << new_body << "\n"; @@ -221,7 +223,7 @@ Node TheoryUfRewriter::rewriteLambda(Node node) // normalization on array constants, and then converting the array constant // back to a lambda. Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl; - Node anode = FunctionConst::getArrayRepresentationForLambda(node); + Node anode = FunctionConst::toArrayConst(node); // Only rewrite constant array nodes, since these are the only cases // where we require canonicalization of lambdas. Moreover, applying the // below code is not correct if the arguments to the lambda occur @@ -231,26 +233,12 @@ Node TheoryUfRewriter::rewriteLambda(Node node) if (!anode.isNull() && anode.isConst()) { Assert(anode.getType().isArray()); - // must get the standard bound variable list - Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( - node.getType()); - Node retNode = - FunctionConst::getLambdaForArrayRepresentation(anode, varList); - if (!retNode.isNull() && retNode != node) - { - Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl; - Trace("builtin-rewrite") << " input : " << node << std::endl; - Trace("builtin-rewrite") - << " output : " << retNode << ", constant = " << retNode.isConst() - << std::endl; - Trace("builtin-rewrite") - << " array rep : " << anode << ", constant = " << anode.isConst() - << std::endl; - Assert(anode.isConst() == retNode.isConst()); - Assert(retNode.getType() == node.getType()); - Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode)); - return retNode; - } + Node retNode = NodeManager::currentNM()->mkConst( + FunctionArrayConst(node.getType(), anode)); + Assert(anode.isConst() == retNode.isConst()); + Assert(retNode.getType() == node.getType()); + Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode)); + return retNode; } else { diff --git a/src/theory/uf/theory_uf_type_rules.cpp b/src/theory/uf/theory_uf_type_rules.cpp index 180504da2..1f9bc7b14 100644 --- a/src/theory/uf/theory_uf_type_rules.cpp +++ b/src/theory/uf/theory_uf_type_rules.cpp @@ -176,53 +176,6 @@ TypeNode LambdaTypeRule::computeType(NodeManager* nodeManager, return nodeManager->mkFunctionType(argTypes, rangeType); } -bool LambdaTypeRule::computeIsConst(NodeManager* nodeManager, TNode n) -{ - Assert(n.getKind() == kind::LAMBDA); - // get array representation of this function, if possible - Node na = FunctionConst::getArrayRepresentationForLambda(n); - if (!na.isNull()) - { - Assert(na.getType().isArray()); - Trace("lambda-const") << "Array representation for " << n << " is " << na - << " " << na.getType() << std::endl; - // must have the standard bound variable list - Node bvl = - NodeManager::currentNM()->getBoundVarListForFunctionType(n.getType()); - if (bvl == n[0]) - { - // array must be constant - if (na.isConst()) - { - Trace("lambda-const") << "*** Constant lambda : " << n; - Trace("lambda-const") << " since its array representation : " << na - << " is constant." << std::endl; - return true; - } - else - { - Trace("lambda-const") << "Non-constant lambda : " << n - << " since array is not constant." << std::endl; - } - } - else - { - Trace("lambda-const") - << "Non-constant lambda : " << n - << " since its varlist is not standard." << std::endl; - Trace("lambda-const") << " standard : " << bvl << std::endl; - Trace("lambda-const") << " current : " << n[0] << std::endl; - } - } - else - { - Trace("lambda-const") << "Non-constant lambda : " << n - << " since it has no array representation." - << std::endl; - } - return false; -} - TypeNode FunctionArrayConstTypeRule::computeType(NodeManager* nodeManager, TNode n, bool check) diff --git a/src/theory/uf/theory_uf_type_rules.h b/src/theory/uf/theory_uf_type_rules.h index 12fc2d679..c75d8c169 100644 --- a/src/theory/uf/theory_uf_type_rules.h +++ b/src/theory/uf/theory_uf_type_rules.h @@ -98,9 +98,6 @@ class LambdaTypeRule { public: static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); - // computes whether a lambda is a constant value, via conversion to array - // representation - static bool computeIsConst(NodeManager* nodeManager, TNode n); }; /* class LambdaTypeRule */ /** @@ -111,7 +108,7 @@ class FunctionArrayConstTypeRule { public: static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); -}; /* class LambdaTypeRule */ +}; class FunctionProperties { diff --git a/src/theory/uf/type_enumerator.cpp b/src/theory/uf/type_enumerator.cpp index 84fafa6b8..a0151e777 100644 --- a/src/theory/uf/type_enumerator.cpp +++ b/src/theory/uf/type_enumerator.cpp @@ -15,6 +15,7 @@ #include "theory/uf/type_enumerator.h" +#include "expr/function_array_const.h" #include "theory/uf/function_const.h" namespace cvc5::internal { @@ -27,7 +28,6 @@ FunctionEnumerator::FunctionEnumerator(TypeNode type, d_arrayEnum(FunctionConst::getArrayTypeForFunctionType(type), tep) { Assert(type.getKind() == kind::FUNCTION_TYPE); - d_bvl = NodeManager::currentNM()->getBoundVarListForFunctionType(type); } Node FunctionEnumerator::operator*() @@ -37,7 +37,7 @@ Node FunctionEnumerator::operator*() throw NoMoreValuesException(getType()); } Node a = *d_arrayEnum; - return FunctionConst::getLambdaForArrayRepresentation(a, d_bvl); + return NodeManager::currentNM()->mkConst(FunctionArrayConst(getType(), a)); } FunctionEnumerator& FunctionEnumerator::operator++() diff --git a/src/theory/uf/type_enumerator.h b/src/theory/uf/type_enumerator.h index 75ea631de..66f4ba0b8 100644 --- a/src/theory/uf/type_enumerator.h +++ b/src/theory/uf/type_enumerator.h @@ -45,11 +45,6 @@ class FunctionEnumerator : public TypeEnumeratorBase private: /** Enumerates arrays, which we convert to functions. */ TypeEnumerator d_arrayEnum; - /** The bound variable list for the function type we are enumerating. - * All terms output by this enumerator are of the form (LAMBDA d_bvl t) for - * some term t. - */ - Node d_bvl; }; /* class FunctionEnumerator */ } // namespace uf