From: Andrew Reynolds Date: Fri, 5 Nov 2021 20:12:22 +0000 (-0500) Subject: Move functions and lambdas from builtin to uf (#7570) X-Git-Tag: cvc5-1.0.0~878 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=bc6f79ab1ca703b3507fc43e438f17b4422360b8;p=cvc5.git Move functions and lambdas from builtin to uf (#7570) This is in preparation for adding better native support for handling lambdas in the higher-order extension of the UF theory. We require that LAMBDA and function types belong to theory UF so that the theory solver is properly notified. This also splits the utility methods for computing whether a function is "constant" to its own file. This PR is code move only. --- diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index 1adf40695..c526bd13b 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -1117,6 +1117,8 @@ libcvc5_add_sources( theory/uf/equality_engine_types.h theory/uf/eq_proof.cpp theory/uf/eq_proof.h + theory/uf/function_const.cpp + theory/uf/function_const.h theory/uf/proof_checker.cpp theory/uf/proof_checker.h theory/uf/proof_equality_engine.cpp @@ -1133,6 +1135,8 @@ libcvc5_add_sources( theory/uf/theory_uf_rewriter.h theory/uf/theory_uf_type_rules.cpp theory/uf/theory_uf_type_rules.h + theory/uf/type_enumerator.cpp + theory/uf/type_enumerator.h theory/valuation.cpp theory/valuation.h ) diff --git a/src/theory/builtin/kinds b/src/theory/builtin/kinds index d4a8782b5..381573a12 100644 --- a/src/theory/builtin/kinds +++ b/src/theory/builtin/kinds @@ -262,7 +262,8 @@ parameterized SORT_TYPE SORT_TAG 0: "specifies types of user-declared 'uninterpr cardinality SORT_TYPE "Cardinality(Cardinality::INTEGERS)" well-founded SORT_TYPE \ "::cvc5::theory::builtin::SortProperties::isWellFounded(%TYPE%)" \ - "::cvc5::theory::builtin::SortProperties::mkGroundTerm(%TYPE%)" + "::cvc5::theory::builtin::SortProperties::mkGroundTerm(%TYPE%)" \ + "theory/builtin/theory_builtin_type_rules.h" constant UNINTERPRETED_CONSTANT \ class \ @@ -301,8 +302,6 @@ variable BOUND_VARIABLE "a bound variable (permitted in bindings and the associa variable SKOLEM "a Skolem variable (internal only)" operator SEXPR 0: "a symbolic expression (any arity)" -operator LAMBDA 2 "a lambda expression; first parameter is a BOUND_VAR_LIST, second is lambda body" - operator WITNESS 2 "a witness expression; first parameter is a BOUND_VAR_LIST, second is the witness body" constant TYPE_CONSTANT \ @@ -311,17 +310,6 @@ constant TYPE_CONSTANT \ ::cvc5::TypeConstantHashFunction \ "expr/kind.h" \ "a representation for basic types" -operator FUNCTION_TYPE 2: "a function type" -cardinality FUNCTION_TYPE \ - "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \ - "theory/builtin/theory_builtin_type_rules.h" -well-founded FUNCTION_TYPE \ - "::cvc5::theory::builtin::FunctionProperties::isWellFounded(%TYPE%)" \ - "::cvc5::theory::builtin::FunctionProperties::mkGroundTerm(%TYPE%)" \ - "theory/builtin/theory_builtin_type_rules.h" -enumerator FUNCTION_TYPE \ - ::cvc5::theory::builtin::FunctionEnumerator \ - "theory/builtin/type_enumerator.h" sort SEXPR_TYPE \ Cardinality::INTEGERS \ not-well-founded \ @@ -330,10 +318,6 @@ sort SEXPR_TYPE \ typerule EQUAL ::cvc5::theory::builtin::EqualityTypeRule typerule DISTINCT ::cvc5::theory::builtin::DistinctTypeRule typerule SEXPR ::cvc5::theory::builtin::SExprTypeRule -typerule LAMBDA ::cvc5::theory::builtin::LambdaTypeRule typerule WITNESS ::cvc5::theory::builtin::WitnessTypeRule -# lambda expressions that are isomorphic to array constants can be considered constants -construle LAMBDA ::cvc5::theory::builtin::LambdaTypeRule - endtheory diff --git a/src/theory/builtin/theory_builtin_rewriter.cpp b/src/theory/builtin/theory_builtin_rewriter.cpp index 0ee72fc5f..b57f2bf42 100644 --- a/src/theory/builtin/theory_builtin_rewriter.cpp +++ b/src/theory/builtin/theory_builtin_rewriter.cpp @@ -18,7 +18,6 @@ #include "theory/builtin/theory_builtin_rewriter.h" -#include "expr/array_store_all.h" #include "expr/attribute.h" #include "expr/node_algorithm.h" #include "theory/rewriter.h" @@ -55,45 +54,6 @@ Node TheoryBuiltinRewriter::blastDistinct(TNode in) { } RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) { - if( node.getKind()==kind::LAMBDA ){ - // The following code ensures that if node is equivalent to a constant - // lambda, then we return the canonical representation for the lambda, which - // in turn ensures that two constant lambdas are equivalent if and only - // if they are the same node. - // We canonicalize lambdas by turning them into array constants, applying - // normalization on array constants, and then converting the array constant - // back to a lambda. - Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl; - Node anode = getArrayRepresentationForLambda( node ); - // Only rewrite constant array nodes, since these are the only cases - // where we require canonicalization of lambdas. Moreover, applying the - // below code is not correct if the arguments to the lambda occur - // in return values. For example, lambda x. ite( x=1, f(x), c ) would - // be converted to (store (storeall ... c) 1 f(x)), and then converted - // to lambda y. ite( y=1, f(x), c), losing the relation between x and y. - if (!anode.isNull() && anode.isConst()) - { - Assert(anode.getType().isArray()); - //must get the standard bound variable list - Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() ); - Node retNode = getLambdaForArrayRepresentation( anode, varList ); - if( !retNode.isNull() && retNode!=node ){ - Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl; - Trace("builtin-rewrite") << " input : " << node << std::endl; - Trace("builtin-rewrite") << " output : " << retNode << ", constant = " << retNode.isConst() << std::endl; - Trace("builtin-rewrite") << " array rep : " << anode << ", constant = " << anode.isConst() << std::endl; - Assert(anode.isConst() == retNode.isConst()); - Assert(retNode.getType() == node.getType()); - Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode)); - return RewriteResponse(REWRITE_DONE, retNode); - } - } - else - { - Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl; - } - return RewriteResponse(REWRITE_DONE, node); - } // otherwise, do the default call return doRewrite(node); } @@ -117,346 +77,6 @@ RewriteResponse TheoryBuiltinRewriter::doRewrite(TNode node) } } -TypeNode TheoryBuiltinRewriter::getFunctionTypeForArrayType(TypeNode atn, - Node bvl) -{ - std::vector children; - for (unsigned i = 0; i < bvl.getNumChildren(); i++) - { - Assert(atn.isArray()); - Assert(bvl[i].getType() == atn.getArrayIndexType()); - children.push_back(atn.getArrayIndexType()); - atn = atn.getArrayConstituentType(); - } - children.push_back(atn); - return NodeManager::currentNM()->mkFunctionType(children); -} - -TypeNode TheoryBuiltinRewriter::getArrayTypeForFunctionType(TypeNode ftn) -{ - Assert(ftn.isFunction()); - // construct the curried array type - unsigned nchildren = ftn.getNumChildren(); - TypeNode ret = ftn[nchildren - 1]; - for (int i = (static_cast(nchildren) - 2); i >= 0; i--) - { - ret = NodeManager::currentNM()->mkArrayType(ftn[i], ret); - } - return ret; -} - -Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec( - TNode a, - TNode bvl, - unsigned bvlIndex, - std::unordered_map& visited) -{ - std::unordered_map::iterator it = visited.find(a); - if( it==visited.end() ){ - Node ret; - if( bvlIndexmkNode( kind::ITE, cond, val, body ); - } - } - }else if( a.getKind()==kind::STORE_ALL ){ - ArrayStoreAll storeAll = a.getConst(); - Node sa = storeAll.getValue(); - // convert the default value recursively (bounded by the number of arguments in bvl) - ret = getLambdaForArrayRepresentationRec( sa, bvl, bvlIndex+1, visited ); - } - }else{ - ret = a; - } - visited[a] = ret; - return ret; - }else{ - return it->second; - } -} - -Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl ){ - Assert(a.getType().isArray()); - std::unordered_map visited; - Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl; - Node body = getLambdaForArrayRepresentationRec( a, bvl, 0, visited ); - if( !body.isNull() ){ - body = Rewriter::rewrite( body ); - Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl; - return NodeManager::currentNM()->mkNode( kind::LAMBDA, bvl, body ); - }else{ - Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl; - return Node::null(); - } -} - -Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n, - TypeNode retType) -{ - Assert(n.getKind() == kind::LAMBDA); - NodeManager* nm = NodeManager::currentNM(); - Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl; - - Node first_arg = n[0][0]; - Node rec_bvl; - unsigned size = n[0].getNumChildren(); - if (size > 1) - { - std::vector< Node > args; - for (unsigned i = 1; i < size; i++) - { - args.push_back( n[0][i] ); - } - rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args); - } - - Trace("builtin-rewrite-debug2") << " process body..." << std::endl; - std::vector< Node > conds; - std::vector< Node > vals; - Node curr = n[1]; - Kind ck = curr.getKind(); - while (ck == kind::ITE || ck == kind::OR || ck == kind::AND - || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE) - { - Node index_eq; - Node curr_val; - Node next; - // Each iteration of this loop infers an entry in the function, e.g. it - // has a value under some condition. - - // [1] We infer that the entry has value "curr_val" under condition - // "index_eq". We set "next" to the node that is the remainder of the - // function to process. - if (ck == kind::ITE) - { - Trace("builtin-rewrite-debug2") - << " process condition : " << curr[0] << std::endl; - index_eq = curr[0]; - curr_val = curr[1]; - next = curr[2]; - } - else if (ck == kind::OR || ck == kind::AND) - { - Trace("builtin-rewrite-debug2") - << " process base : " << curr << std::endl; - // curr = Rewriter::rewrite(curr); - // Trace("builtin-rewrite-debug2") - // << " rewriten base : " << curr << std::endl; - // Complex Boolean return cases, in which - // (1) lambda x. (= x v1) v ... becomes - // lambda x. (ite (= x v1) true [...]) - // - // (2) lambda x. (not (= x v1)) ^ ... becomes - // lambda x. (ite (= x v1) false [...]) - // - // Note the negated cases of the lhs of the OR/AND operators above are - // handled by pushing the recursion to the then-branch, with the - // else-branch being the constant value. For example, the negated (1) - // would be - // (1') lambda x. (not (= x v1)) v ... becomes - // lambda x. (ite (= x v1) [...] true) - // thus requiring the rest of the disjunction to be further processed in - // the then-branch as the current value. - bool pol = curr[0].getKind() != kind::NOT; - bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR); - index_eq = pol ? curr[0] : curr[0][0]; - // processed : the value that is determined by the first child of curr - // remainder : the remaining children of curr - Node processed, remainder; - // the value is the polarity of the first child or its inverse if we are - // in the inverted case - processed = nm->mkConst(!inverted? pol : !pol); - // build an OR/AND with the remaining components - if (curr.getNumChildren() == 2) - { - remainder = curr[1]; - } - else - { - std::vector remainderNodes{curr.begin() + 1, curr.end()}; - remainder = nm->mkNode(ck, remainderNodes); - } - if (inverted) - { - curr_val = remainder; - next = processed; - // If the lambda contains more variables than the one being currently - // processed, the current value can be non-constant, since it'll be - // processed recursively below. Otherwise we fail. - if (rec_bvl.isNull() && !curr_val.isConst()) - { - Trace("builtin-rewrite-debug2") - << "...non-const curr_val " << curr_val << "\n"; - return Node::null(); - } - } - else - { - curr_val = processed; - next = remainder; - } - Trace("builtin-rewrite-debug2") << " index_eq : " << index_eq << "\n"; - Trace("builtin-rewrite-debug2") << " curr_val : " << curr_val << "\n"; - Trace("builtin-rewrite-debug2") << " next : " << next << std::endl; - } - else - { - Trace("builtin-rewrite-debug2") - << " process base : " << curr << std::endl; - // Simple Boolean return cases, in which - // (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false) - // (2) lambda x. v becomes lambda x. (ite (= x v) true false) - // Note the negateg cases of the bodies above are also handled. - bool pol = ck != kind::NOT; - index_eq = pol ? curr : curr[0]; - curr_val = nm->mkConst(pol); - next = nm->mkConst(!pol); - } - - // [2] We ensure that "index_eq" is an equality, if possible. - if (index_eq.getKind() != kind::EQUAL) - { - bool pol = index_eq.getKind() != kind::NOT; - Node indexEqAtom = pol ? index_eq : index_eq[0]; - if (indexEqAtom.getKind() == kind::BOUND_VARIABLE) - { - if (!indexEqAtom.getType().isBoolean()) - { - // Catches default case of non-Boolean variable, e.g. - // lambda x : Int. x. In this case, it is not canonical and we fail. - Trace("builtin-rewrite-debug2") - << " ...non-Boolean variable." << std::endl; - return Node::null(); - } - // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as - // lambda x. (ite (= x true) t s) - index_eq = indexEqAtom.eqNode(nm->mkConst(pol)); - } - else - { - // non-equality condition - Trace("builtin-rewrite-debug2") - << " ...non-equality condition." << std::endl; - return Node::null(); - } - } - else if (Rewriter::rewrite(index_eq) != index_eq) - { - // equality must be oriented correctly based on rewriter - Trace("builtin-rewrite-debug2") << " ...equality not oriented properly." << std::endl; - return Node::null(); - } - - // [3] We ensure that "index_eq" is an equality that is equivalent to - // "first_arg" = "curr_index", where curr_index is a constant, and - // "first_arg" is the current argument we are processing, if possible. - Node curr_index; - for( unsigned r=0; r<2; r++ ){ - Node arg = index_eq[r]; - Node val = index_eq[1-r]; - if( arg==first_arg ){ - if (!val.isConst()) - { - // non-constant value - Trace("builtin-rewrite-debug2") - << " ...non-constant value for argument\n."; - return Node::null(); - }else{ - curr_index = val; - Trace("builtin-rewrite-debug2") - << " arg " << arg << " -> " << val << std::endl; - break; - } - } - } - if (curr_index.isNull()) - { - Trace("builtin-rewrite-debug2") - << " ...could not infer index value." << std::endl; - return Node::null(); - } - - // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the - // remaining arguments (rec_bvl). - if (!rec_bvl.isNull()) - { - curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val); - Trace("builtin-rewrite-debug") << push; - Trace("builtin-rewrite-debug2") << push; - curr_val = getArrayRepresentationForLambdaRec(curr_val, retType); - Trace("builtin-rewrite-debug") << pop; - Trace("builtin-rewrite-debug2") << pop; - if (curr_val.isNull()) - { - Trace("builtin-rewrite-debug2") - << " ...failed to recursively find value." << std::endl; - return Node::null(); - } - } - Trace("builtin-rewrite-debug2") - << " ...condition is index " << curr_val << std::endl; - - // [5] Add the entry - conds.push_back( curr_index ); - vals.push_back( curr_val ); - - // we will now process the remainder - curr = next; - ck = curr.getKind(); - Trace("builtin-rewrite-debug2") - << " process remainder : " << curr << std::endl; - } - if( !rec_bvl.isNull() ){ - curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr); - Trace("builtin-rewrite-debug") << push; - Trace("builtin-rewrite-debug2") << push; - curr = getArrayRepresentationForLambdaRec(curr, retType); - Trace("builtin-rewrite-debug") << pop; - Trace("builtin-rewrite-debug2") << pop; - } - if( !curr.isNull() && curr.isConst() ){ - // compute the return type - TypeNode array_type = retType; - for (unsigned i = 0; i < size; i++) - { - unsigned index = (size - 1) - i; - array_type = nm->mkArrayType(n[0][index].getType(), array_type); - } - Trace("builtin-rewrite-debug2") << " make array store all " << curr.getType() << " annotated : " << array_type << std::endl; - Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType())); - curr = nm->mkConst(ArrayStoreAll(array_type, curr)); - Trace("builtin-rewrite-debug2") << " build array..." << std::endl; - // can only build if default value is constant (since array store all must be constant) - Trace("builtin-rewrite-debug2") << " got constant base " << curr << std::endl; - Trace("builtin-rewrite-debug2") << " conditions " << conds << std::endl; - Trace("builtin-rewrite-debug2") << " values " << vals << std::endl; - // construct store chain - for (int i = static_cast(conds.size()) - 1; i >= 0; i--) - { - Assert(conds[i].getType().isSubtypeOf(first_arg.getType())); - curr = nm->mkNode(kind::STORE, curr, conds[i], vals[i]); - } - Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl; - return curr; - }else{ - Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl; - return Node::null(); - } -} - Node TheoryBuiltinRewriter::rewriteWitness(TNode node) { Assert(node.getKind() == kind::WITNESS); @@ -493,21 +113,6 @@ Node TheoryBuiltinRewriter::rewriteWitness(TNode node) return node; } -Node TheoryBuiltinRewriter::getArrayRepresentationForLambda(TNode n) -{ - Assert(n.getKind() == kind::LAMBDA); - // must carry the overall return type to deal with cases like (lambda ((x Int) - // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else - // case above should be (arraystoreall (Array Int Real) 0.0) - Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType()); - if (anode.isNull()) - { - return anode; - } - // must rewrite it to make canonical - return Rewriter::rewrite(anode); -} - } // namespace builtin } // namespace theory } // namespace cvc5 diff --git a/src/theory/builtin/theory_builtin_rewriter.h b/src/theory/builtin/theory_builtin_rewriter.h index f528ed43c..0f903bc44 100644 --- a/src/theory/builtin/theory_builtin_rewriter.h +++ b/src/theory/builtin/theory_builtin_rewriter.h @@ -37,17 +37,6 @@ class TheoryBuiltinRewriter : public TheoryRewriter RewriteResponse preRewrite(TNode node) override { return doRewrite(node); } - // conversions between lambdas and arrays - private: - /** recursive helper for getLambdaForArrayRepresentation */ - static Node getLambdaForArrayRepresentationRec( - TNode a, - TNode bvl, - unsigned bvlIndex, - std::unordered_map& visited); - /** recursive helper for getArrayRepresentationForLambda */ - static Node getArrayRepresentationForLambdaRec(TNode n, TypeNode retType); - public: /** * The default rewriter for rewrites that occur at both pre and post rewrite. @@ -58,67 +47,6 @@ class TheoryBuiltinRewriter : public TheoryRewriter * Returns the rewritten form of node. */ static Node rewriteWitness(TNode node); - /** Get function type for array type - * - * This returns the function type of terms returned by the function - * getLambdaForArrayRepresentation( t, bvl ), - * where t.getType()=atn. - * - * bvl should be a bound variable list whose variables correspond in-order - * to the index types of the (curried) Array type. For example, a bound - * variable list bvl whose variables have types (Int, Real) can be given as - * input when paired with atn = (Array Int (Array Real Bool)), or (Array Int - * (Array Real (Array Bool Bool))). This function returns (-> Int Real Bool) - * and (-> Int Real (Array Bool Bool)) respectively in these cases. - * On the other hand, the above bvl is not a proper input for - * atn = (Array Int (Array Bool Bool)) or (Array Int Int). - * If the types of bvl and atn do not match, we throw an assertion failure. - */ - static TypeNode getFunctionTypeForArrayType(TypeNode atn, Node bvl); - /** Get array type for function type - * - * This returns the array type of terms returned by - * getArrayRepresentationForLambda( t ), where t.getType()=ftn. - */ - static TypeNode getArrayTypeForFunctionType(TypeNode ftn); - /** - * Given an array constant a, returns a lambda expression that it corresponds - * to, with bound variable list bvl. - * Examples: - * - * (store (storeall (Array Int Int) 2) 0 1) - * becomes - * ((lambda x. (ite (= x 0) 1 2)) - * - * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4)) - * 0 (store (storeall (Array Int Int) 3) 1 2)) becomes (lambda xy. (ite (= x - * 0) (ite (= x 1) 2 3) 4)) - * - * (store (store (storeall (Array Int Bool) false) 2 true) 1 true) - * becomes - * (lambda x. (ite (= x 1) true (ite (= x 2) true false))) - * - * Notice that the return body of the lambda is rewritten to ensure that the - * representation is canonical. Hence the last - * example will in fact be returned as: - * (lambda x. (ite (= x 1) true (= x 2))) - */ - static Node getLambdaForArrayRepresentation(TNode a, TNode bvl); - /** - * Given a lambda expression n, returns an array term that corresponds to n. - * This does the opposite direction of the examples described above. - * - * We limit the return values of this method to be almost constant functions, - * that is, arrays of the form: - * (store ... (store (storeall _ b) i1 e1) ... in en) - * where b, i1, e1, ..., in, en are constants. - * Notice however that the return value of this form need not be a (canonical) - * array constant. - * - * If it is not possible to construct an array of this form that corresponds - * to n, this method returns null. - */ - static Node getArrayRepresentationForLambda(TNode n); }; /* class TheoryBuiltinRewriter */ } // namespace builtin diff --git a/src/theory/builtin/theory_builtin_type_rules.cpp b/src/theory/builtin/theory_builtin_type_rules.cpp index 1888069bc..636952be5 100644 --- a/src/theory/builtin/theory_builtin_type_rules.cpp +++ b/src/theory/builtin/theory_builtin_type_rules.cpp @@ -18,7 +18,6 @@ #include "expr/attribute.h" #include "expr/skolem_manager.h" #include "expr/uninterpreted_constant.h" -#include "util/cardinality.h" namespace cvc5 { namespace theory { @@ -56,25 +55,6 @@ Node SortProperties::mkGroundTerm(TypeNode type) return k; } -Cardinality FunctionProperties::computeCardinality(TypeNode type) -{ - // Don't assert this; allow other theories to use this cardinality - // computation. - // - // Assert(type.getKind() == kind::FUNCTION_TYPE); - - Cardinality argsCard(1); - // get the largest cardinality of function arguments/return type - for (size_t i = 0, i_end = type.getNumChildren() - 1; i < i_end; ++i) - { - argsCard *= type[i].getCardinality(); - } - - Cardinality valueCard = type[type.getNumChildren() - 1].getCardinality(); - - return valueCard ^ argsCard; -} - } // namespace builtin } // namespace theory } // namespace cvc5 diff --git a/src/theory/builtin/theory_builtin_type_rules.h b/src/theory/builtin/theory_builtin_type_rules.h index 54139c433..2117249c2 100644 --- a/src/theory/builtin/theory_builtin_type_rules.h +++ b/src/theory/builtin/theory_builtin_type_rules.h @@ -20,7 +20,6 @@ #include "expr/node.h" #include "expr/type_node.h" -#include "theory/builtin/theory_builtin_rewriter.h" // for array and lambda representation #include @@ -107,54 +106,6 @@ class AbstractValueTypeRule { } };/* class AbstractValueTypeRule */ -class LambdaTypeRule { - public: - inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) { - if( n[0].getType(check) != nodeManager->boundVarListType() ) { - std::stringstream ss; - ss << "expected a bound var list for LAMBDA expression, got `" - << n[0].getType().toString() << "'"; - throw TypeCheckingExceptionPrivate(n, ss.str()); - } - std::vector argTypes; - for(TNode::iterator i = n[0].begin(); i != n[0].end(); ++i) { - argTypes.push_back((*i).getType()); - } - TypeNode rangeType = n[1].getType(check); - return nodeManager->mkFunctionType(argTypes, rangeType); - } - // computes whether a lambda is a constant value, via conversion to array representation - inline static bool computeIsConst(NodeManager* nodeManager, TNode n) - { - Assert(n.getKind() == kind::LAMBDA); - //get array representation of this function, if possible - Node na = TheoryBuiltinRewriter::getArrayRepresentationForLambda(n); - if( !na.isNull() ){ - Assert(na.getType().isArray()); - Trace("lambda-const") << "Array representation for " << n << " is " << na << " " << na.getType() << std::endl; - // must have the standard bound variable list - Node bvl = NodeManager::currentNM()->getBoundVarListForFunctionType( n.getType() ); - if( bvl==n[0] ){ - //array must be constant - if( na.isConst() ){ - Trace("lambda-const") << "*** Constant lambda : " << n; - Trace("lambda-const") << " since its array representation : " << na << " is constant." << std::endl; - return true; - }else{ - Trace("lambda-const") << "Non-constant lambda : " << n << " since array is not constant." << std::endl; - } - }else{ - Trace("lambda-const") << "Non-constant lambda : " << n << " since its varlist is not standard." << std::endl; - Trace("lambda-const") << " standard : " << bvl << std::endl; - Trace("lambda-const") << " current : " << n[0] << std::endl; - } - }else{ - Trace("lambda-const") << "Non-constant lambda : " << n << " since it has no array representation." << std::endl; - } - return false; - } -};/* class LambdaTypeRule */ - class WitnessTypeRule { public: @@ -198,37 +149,6 @@ class SortProperties { static Node mkGroundTerm(TypeNode type); };/* class SortProperties */ -class FunctionProperties { - public: - static Cardinality computeCardinality(TypeNode type); - - /** Function type is well-founded if its component sorts are */ - static bool isWellFounded(TypeNode type) - { - for (TypeNode::iterator i = type.begin(), i_end = type.end(); i != i_end; - ++i) - { - if (!(*i).isWellFounded()) - { - return false; - } - } - return true; - } - /** - * Ground term for function sorts is (lambda x. t) where x is the - * canonical variable list for its type and t is the canonical ground term of - * its range. - */ - static Node mkGroundTerm(TypeNode type) - { - NodeManager* nm = NodeManager::currentNM(); - Node bvl = nm->getBoundVarListForFunctionType(type); - Node ret = type.getRangeType().mkGroundTerm(); - return nm->mkNode(kind::LAMBDA, bvl, ret); - } -};/* class FuctionProperties */ - } // namespace builtin } // namespace theory } // namespace cvc5 diff --git a/src/theory/builtin/type_enumerator.cpp b/src/theory/builtin/type_enumerator.cpp index 0ef1d3ec7..2e919810b 100644 --- a/src/theory/builtin/type_enumerator.cpp +++ b/src/theory/builtin/type_enumerator.cpp @@ -21,31 +21,54 @@ namespace cvc5 { namespace theory { namespace builtin { -FunctionEnumerator::FunctionEnumerator(TypeNode type, - TypeEnumeratorProperties* tep) - : TypeEnumeratorBase(type), - d_arrayEnum(TheoryBuiltinRewriter::getArrayTypeForFunctionType(type), tep) +UninterpretedSortEnumerator::UninterpretedSortEnumerator( + TypeNode type, TypeEnumeratorProperties* tep) + : TypeEnumeratorBase(type), d_count(0) { - Assert(type.getKind() == kind::FUNCTION_TYPE); - d_bvl = NodeManager::currentNM()->getBoundVarListForFunctionType(type); + Assert(type.getKind() == kind::SORT_TYPE); + d_has_fixed_bound = false; + Trace("uf-type-enum") << "UF enum " << type << ", tep = " << tep << std::endl; + if (tep && tep->d_fixed_usort_card) + { + d_has_fixed_bound = true; + std::map::iterator it = tep->d_fixed_card.find(type); + if (it != tep->d_fixed_card.end()) + { + d_fixed_bound = it->second; + } + else + { + d_fixed_bound = Integer(1); + } + Trace("uf-type-enum") << "...fixed bound : " << d_fixed_bound << std::endl; + } } -Node FunctionEnumerator::operator*() +Node UninterpretedSortEnumerator::operator*() { if (isFinished()) { throw NoMoreValuesException(getType()); } - Node a = *d_arrayEnum; - return TheoryBuiltinRewriter::getLambdaForArrayRepresentation(a, d_bvl); + return NodeManager::currentNM()->mkConst( + UninterpretedConstant(getType(), d_count)); } -FunctionEnumerator& FunctionEnumerator::operator++() +UninterpretedSortEnumerator& UninterpretedSortEnumerator::operator++() { - ++d_arrayEnum; + d_count += 1; return *this; } +bool UninterpretedSortEnumerator::isFinished() +{ + if (d_has_fixed_bound) + { + return d_count >= d_fixed_bound; + } + return false; +} + } // namespace builtin } // namespace theory } // namespace cvc5 diff --git a/src/theory/builtin/type_enumerator.h b/src/theory/builtin/type_enumerator.h index 980792f94..711752e23 100644 --- a/src/theory/builtin/type_enumerator.h +++ b/src/theory/builtin/type_enumerator.h @@ -35,73 +35,14 @@ class UninterpretedSortEnumerator : public TypeEnumeratorBase(type), d_count(0) - { - Assert(type.getKind() == kind::SORT_TYPE); - d_has_fixed_bound = false; - Trace("uf-type-enum") << "UF enum " << type << ", tep = " << tep << std::endl; - if( tep && tep->d_fixed_usort_card ){ - d_has_fixed_bound = true; - std::map< TypeNode, Integer >::iterator it = tep->d_fixed_card.find( type ); - if( it!=tep->d_fixed_card.end() ){ - d_fixed_bound = it->second; - }else{ - d_fixed_bound = Integer(1); - } - Trace("uf-type-enum") << "...fixed bound : " << d_fixed_bound << std::endl; - } - } + TypeEnumeratorProperties* tep = nullptr); - Node operator*() override - { - if(isFinished()) { - throw NoMoreValuesException(getType()); - } - return NodeManager::currentNM()->mkConst( - UninterpretedConstant(getType(), d_count)); - } - - UninterpretedSortEnumerator& operator++() override - { - d_count += 1; - return *this; - } - - bool isFinished() override - { - if( d_has_fixed_bound ){ - return d_count>=d_fixed_bound; - }else{ - return false; - } - } + Node operator*() override; -};/* class UninterpretedSortEnumerator */ + UninterpretedSortEnumerator& operator++() override; -/** FunctionEnumerator -* This enumerates function values, based on the enumerator for the -* array type corresponding to the given function type. -*/ -class FunctionEnumerator : public TypeEnumeratorBase -{ - public: - FunctionEnumerator(TypeNode type, TypeEnumeratorProperties* tep = nullptr); - /** Get the current term of the enumerator. */ - Node operator*() override; - /** Increment the enumerator. */ - FunctionEnumerator& operator++() override; - /** is the enumerator finished? */ - bool isFinished() override { return d_arrayEnum.isFinished(); } - private: - /** Enumerates arrays, which we convert to functions. */ - TypeEnumerator d_arrayEnum; - /** The bound variable list for the function type we are enumerating. - * All terms output by this enumerator are of the form (LAMBDA d_bvl t) for - * some term t. - */ - Node d_bvl; -}; /* class FunctionEnumerator */ + bool isFinished() override; +}; } // namespace builtin } // namespace theory diff --git a/src/theory/datatypes/kinds b/src/theory/datatypes/kinds index cb3a78cf2..5324e1c79 100644 --- a/src/theory/datatypes/kinds +++ b/src/theory/datatypes/kinds @@ -21,22 +21,22 @@ cardinality CONSTRUCTOR_TYPE \ operator SELECTOR_TYPE 2 "selector" # can re-use function cardinality cardinality SELECTOR_TYPE \ - "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \ - "theory/builtin/theory_builtin_type_rules.h" + "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \ + "theory/uf/theory_uf_type_rules.h" # tester type has a constructor type operator TESTER_TYPE 1 "tester" # can re-use function cardinality cardinality TESTER_TYPE \ - "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \ - "theory/builtin/theory_builtin_type_rules.h" + "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \ + "theory/uf/theory_uf_type_rules.h" # tester type has a constructor type operator UPDATER_TYPE 2 "datatype update" # can re-use function cardinality cardinality UPDATER_TYPE \ - "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \ - "theory/builtin/theory_builtin_type_rules.h" + "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \ + "theory/uf/theory_uf_type_rules.h" parameterized APPLY_CONSTRUCTOR APPLY_TYPE_ASCRIPTION 0: "constructor application; first parameter is the constructor, remaining parameters (if any) are parameters to the constructor" diff --git a/src/theory/uf/function_const.cpp b/src/theory/uf/function_const.cpp new file mode 100644 index 000000000..181cb20ca --- /dev/null +++ b/src/theory/uf/function_const.cpp @@ -0,0 +1,412 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utilities for function constants + */ + +#include "theory/uf/function_const.h" + +#include "expr/array_store_all.h" +#include "theory/rewriter.h" + +namespace cvc5 { +namespace theory { +namespace uf { + +TypeNode FunctionConst::getFunctionTypeForArrayType(TypeNode atn, Node bvl) +{ + std::vector children; + for (unsigned i = 0; i < bvl.getNumChildren(); i++) + { + Assert(atn.isArray()); + Assert(bvl[i].getType() == atn.getArrayIndexType()); + children.push_back(atn.getArrayIndexType()); + atn = atn.getArrayConstituentType(); + } + children.push_back(atn); + return NodeManager::currentNM()->mkFunctionType(children); +} + +TypeNode FunctionConst::getArrayTypeForFunctionType(TypeNode ftn) +{ + Assert(ftn.isFunction()); + // construct the curried array type + size_t nchildren = ftn.getNumChildren(); + TypeNode ret = ftn[nchildren - 1]; + for (size_t i = 0; i < nchildren - 1; i++) + { + size_t ii = nchildren - i - 2; + ret = NodeManager::currentNM()->mkArrayType(ftn[ii], ret); + } + return ret; +} + +Node FunctionConst::getLambdaForArrayRepresentationRec( + TNode a, + TNode bvl, + unsigned bvlIndex, + std::unordered_map& visited) +{ + std::unordered_map::iterator it = visited.find(a); + if (it != visited.end()) + { + return it->second; + } + Node ret; + if (bvlIndex < bvl.getNumChildren()) + { + Assert(a.getType().isArray()); + if (a.getKind() == kind::STORE) + { + // convert the array recursively + Node body = + getLambdaForArrayRepresentationRec(a[0], bvl, bvlIndex, visited); + if (!body.isNull()) + { + // convert the value recursively (bounded by the number of arguments + // in bvl) + Node val = getLambdaForArrayRepresentationRec( + a[2], bvl, bvlIndex + 1, visited); + if (!val.isNull()) + { + Assert(!TypeNode::leastCommonTypeNode(a[1].getType(), + bvl[bvlIndex].getType()) + .isNull()); + Assert(!TypeNode::leastCommonTypeNode(val.getType(), body.getType()) + .isNull()); + Node cond = bvl[bvlIndex].eqNode(a[1]); + ret = NodeManager::currentNM()->mkNode(kind::ITE, cond, val, body); + } + } + } + else if (a.getKind() == kind::STORE_ALL) + { + ArrayStoreAll storeAll = a.getConst(); + Node sa = storeAll.getValue(); + // convert the default value recursively (bounded by the number of + // arguments in bvl) + ret = getLambdaForArrayRepresentationRec(sa, bvl, bvlIndex + 1, visited); + } + } + else + { + ret = a; + } + visited[a] = ret; + return ret; +} + +Node FunctionConst::getLambdaForArrayRepresentation(TNode a, TNode bvl) +{ + Assert(a.getType().isArray()); + std::unordered_map visited; + Trace("builtin-rewrite-debug") + << "Get lambda for : " << a << ", with variables " << bvl << std::endl; + Node body = getLambdaForArrayRepresentationRec(a, bvl, 0, visited); + if (!body.isNull()) + { + body = Rewriter::rewrite(body); + Trace("builtin-rewrite-debug") + << "...got lambda body " << body << std::endl; + return NodeManager::currentNM()->mkNode(kind::LAMBDA, bvl, body); + } + Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl; + return Node::null(); +} + +Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n, + TypeNode retType) +{ + Assert(n.getKind() == kind::LAMBDA); + NodeManager* nm = NodeManager::currentNM(); + Trace("builtin-rewrite-debug") + << "Get array representation for : " << n << std::endl; + + Node first_arg = n[0][0]; + Node rec_bvl; + size_t size = n[0].getNumChildren(); + if (size > 1) + { + std::vector args; + for (size_t i = 1; i < size; i++) + { + args.push_back(n[0][i]); + } + rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args); + } + + Trace("builtin-rewrite-debug2") << " process body..." << std::endl; + std::vector conds; + std::vector vals; + Node curr = n[1]; + Kind ck = curr.getKind(); + while (ck == kind::ITE || ck == kind::OR || ck == kind::AND + || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE) + { + Node index_eq; + Node curr_val; + Node next; + // Each iteration of this loop infers an entry in the function, e.g. it + // has a value under some condition. + + // [1] We infer that the entry has value "curr_val" under condition + // "index_eq". We set "next" to the node that is the remainder of the + // function to process. + if (ck == kind::ITE) + { + Trace("builtin-rewrite-debug2") + << " process condition : " << curr[0] << std::endl; + index_eq = curr[0]; + curr_val = curr[1]; + next = curr[2]; + } + else if (ck == kind::OR || ck == kind::AND) + { + Trace("builtin-rewrite-debug2") + << " process base : " << curr << std::endl; + // curr = Rewriter::rewrite(curr); + // Trace("builtin-rewrite-debug2") + // << " rewriten base : " << curr << std::endl; + // Complex Boolean return cases, in which + // (1) lambda x. (= x v1) v ... becomes + // lambda x. (ite (= x v1) true [...]) + // + // (2) lambda x. (not (= x v1)) ^ ... becomes + // lambda x. (ite (= x v1) false [...]) + // + // Note the negated cases of the lhs of the OR/AND operators above are + // handled by pushing the recursion to the then-branch, with the + // else-branch being the constant value. For example, the negated (1) + // would be + // (1') lambda x. (not (= x v1)) v ... becomes + // lambda x. (ite (= x v1) [...] true) + // thus requiring the rest of the disjunction to be further processed in + // the then-branch as the current value. + bool pol = curr[0].getKind() != kind::NOT; + bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR); + index_eq = pol ? curr[0] : curr[0][0]; + // processed : the value that is determined by the first child of curr + // remainder : the remaining children of curr + Node processed, remainder; + // the value is the polarity of the first child or its inverse if we are + // in the inverted case + processed = nm->mkConst(!inverted ? pol : !pol); + // build an OR/AND with the remaining components + if (curr.getNumChildren() == 2) + { + remainder = curr[1]; + } + else + { + std::vector remainderNodes{curr.begin() + 1, curr.end()}; + remainder = nm->mkNode(ck, remainderNodes); + } + if (inverted) + { + curr_val = remainder; + next = processed; + // If the lambda contains more variables than the one being currently + // processed, the current value can be non-constant, since it'll be + // processed recursively below. Otherwise we fail. + if (rec_bvl.isNull() && !curr_val.isConst()) + { + Trace("builtin-rewrite-debug2") + << "...non-const curr_val " << curr_val << "\n"; + return Node::null(); + } + } + else + { + curr_val = processed; + next = remainder; + } + Trace("builtin-rewrite-debug2") << " index_eq : " << index_eq << "\n"; + Trace("builtin-rewrite-debug2") << " curr_val : " << curr_val << "\n"; + Trace("builtin-rewrite-debug2") << " next : " << next << std::endl; + } + else + { + Trace("builtin-rewrite-debug2") + << " process base : " << curr << std::endl; + // Simple Boolean return cases, in which + // (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false) + // (2) lambda x. v becomes lambda x. (ite (= x v) true false) + // Note the negateg cases of the bodies above are also handled. + bool pol = ck != kind::NOT; + index_eq = pol ? curr : curr[0]; + curr_val = nm->mkConst(pol); + next = nm->mkConst(!pol); + } + + // [2] We ensure that "index_eq" is an equality, if possible. + if (index_eq.getKind() != kind::EQUAL) + { + bool pol = index_eq.getKind() != kind::NOT; + Node indexEqAtom = pol ? index_eq : index_eq[0]; + if (indexEqAtom.getKind() == kind::BOUND_VARIABLE) + { + if (!indexEqAtom.getType().isBoolean()) + { + // Catches default case of non-Boolean variable, e.g. + // lambda x : Int. x. In this case, it is not canonical and we fail. + Trace("builtin-rewrite-debug2") + << " ...non-Boolean variable." << std::endl; + return Node::null(); + } + // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as + // lambda x. (ite (= x true) t s) + index_eq = indexEqAtom.eqNode(nm->mkConst(pol)); + } + else + { + // non-equality condition + Trace("builtin-rewrite-debug2") + << " ...non-equality condition." << std::endl; + return Node::null(); + } + } + else if (Rewriter::rewrite(index_eq) != index_eq) + { + // equality must be oriented correctly based on rewriter + Trace("builtin-rewrite-debug2") + << " ...equality not oriented properly." << std::endl; + return Node::null(); + } + + // [3] We ensure that "index_eq" is an equality that is equivalent to + // "first_arg" = "curr_index", where curr_index is a constant, and + // "first_arg" is the current argument we are processing, if possible. + Node curr_index; + for (unsigned r = 0; r < 2; r++) + { + Node arg = index_eq[r]; + Node val = index_eq[1 - r]; + if (arg == first_arg) + { + if (!val.isConst()) + { + // non-constant value + Trace("builtin-rewrite-debug2") + << " ...non-constant value for argument\n."; + return Node::null(); + } + else + { + curr_index = val; + Trace("builtin-rewrite-debug2") + << " arg " << arg << " -> " << val << std::endl; + break; + } + } + } + if (curr_index.isNull()) + { + Trace("builtin-rewrite-debug2") + << " ...could not infer index value." << std::endl; + return Node::null(); + } + + // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the + // remaining arguments (rec_bvl). + if (!rec_bvl.isNull()) + { + curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val); + Trace("builtin-rewrite-debug") << push; + Trace("builtin-rewrite-debug2") << push; + curr_val = getArrayRepresentationForLambdaRec(curr_val, retType); + Trace("builtin-rewrite-debug") << pop; + Trace("builtin-rewrite-debug2") << pop; + if (curr_val.isNull()) + { + Trace("builtin-rewrite-debug2") + << " ...failed to recursively find value." << std::endl; + return Node::null(); + } + } + Trace("builtin-rewrite-debug2") + << " ...condition is index " << curr_val << std::endl; + + // [5] Add the entry + conds.push_back(curr_index); + vals.push_back(curr_val); + + // we will now process the remainder + curr = next; + ck = curr.getKind(); + Trace("builtin-rewrite-debug2") + << " process remainder : " << curr << std::endl; + } + if (!rec_bvl.isNull()) + { + curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr); + Trace("builtin-rewrite-debug") << push; + Trace("builtin-rewrite-debug2") << push; + curr = getArrayRepresentationForLambdaRec(curr, retType); + Trace("builtin-rewrite-debug") << pop; + Trace("builtin-rewrite-debug2") << pop; + } + if (!curr.isNull() && curr.isConst()) + { + // compute the return type + TypeNode array_type = retType; + for (size_t i = 0; i < size; i++) + { + size_t index = (size - 1) - i; + array_type = nm->mkArrayType(n[0][index].getType(), array_type); + } + Trace("builtin-rewrite-debug2") + << " make array store all " << curr.getType() + << " annotated : " << array_type << std::endl; + Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType())); + curr = nm->mkConst(ArrayStoreAll(array_type, curr)); + Trace("builtin-rewrite-debug2") << " build array..." << std::endl; + // can only build if default value is constant (since array store all must + // be constant) + Trace("builtin-rewrite-debug2") + << " got constant base " << curr << std::endl; + Trace("builtin-rewrite-debug2") << " conditions " << conds << std::endl; + Trace("builtin-rewrite-debug2") << " values " << vals << std::endl; + // construct store chain + for (size_t i = 0, numCond = conds.size(); i < numCond; i++) + { + size_t ii = (numCond - 1) - i; + Assert(conds[ii].getType().isSubtypeOf(first_arg.getType())); + curr = nm->mkNode(kind::STORE, curr, conds[ii], vals[ii]); + } + Trace("builtin-rewrite-debug") + << "...got array " << curr << " for " << n << std::endl; + return curr; + } + Trace("builtin-rewrite-debug") + << "...failed to get array (cannot get constant default value)" + << std::endl; + return Node::null(); +} + +Node FunctionConst::getArrayRepresentationForLambda(TNode n) +{ + Assert(n.getKind() == kind::LAMBDA); + // must carry the overall return type to deal with cases like (lambda ((x Int) + // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else + // case above should be (arraystoreall (Array Int Real) 0.0) + Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType()); + if (anode.isNull()) + { + return anode; + } + // must rewrite it to make canonical + return Rewriter::rewrite(anode); +} + +} // namespace uf +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/uf/function_const.h b/src/theory/uf/function_const.h new file mode 100644 index 000000000..10d1bf89c --- /dev/null +++ b/src/theory/uf/function_const.h @@ -0,0 +1,110 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Utilities for function constants + */ + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__UF__FUNCTION_CONST_H +#define CVC5__THEORY__UF__FUNCTION_CONST_H + +#include + +#include "expr/node.h" + +namespace cvc5 { +namespace theory { +namespace uf { + +/** Conversion between lambda and array constants */ +class FunctionConst +{ + public: + /** Get function type for array type + * + * This returns the function type of terms returned by the function + * getLambdaForArrayRepresentation( t, bvl ), + * where t.getType()=atn. + * + * bvl should be a bound variable list whose variables correspond in-order + * to the index types of the (curried) Array type. For example, a bound + * variable list bvl whose variables have types (Int, Real) can be given as + * input when paired with atn = (Array Int (Array Real Bool)), or (Array Int + * (Array Real (Array Bool Bool))). This function returns (-> Int Real Bool) + * and (-> Int Real (Array Bool Bool)) respectively in these cases. + * On the other hand, the above bvl is not a proper input for + * atn = (Array Int (Array Bool Bool)) or (Array Int Int). + * If the types of bvl and atn do not match, we throw an assertion failure. + */ + static TypeNode getFunctionTypeForArrayType(TypeNode atn, Node bvl); + /** Get array type for function type + * + * This returns the array type of terms returned by + * getArrayRepresentationForLambda( t ), where t.getType()=ftn. + */ + static TypeNode getArrayTypeForFunctionType(TypeNode ftn); + /** + * Given an array constant a, returns a lambda expression that it corresponds + * to, with bound variable list bvl. + * Examples: + * + * (store (storeall (Array Int Int) 2) 0 1) + * becomes + * ((lambda x. (ite (= x 0) 1 2)) + * + * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4)) + * 0 (store (storeall (Array Int Int) 3) 1 2)) becomes (lambda xy. (ite (= x + * 0) (ite (= x 1) 2 3) 4)) + * + * (store (store (storeall (Array Int Bool) false) 2 true) 1 true) + * becomes + * (lambda x. (ite (= x 1) true (ite (= x 2) true false))) + * + * Notice that the return body of the lambda is rewritten to ensure that the + * representation is canonical. Hence the last + * example will in fact be returned as: + * (lambda x. (ite (= x 1) true (= x 2))) + */ + static Node getLambdaForArrayRepresentation(TNode a, TNode bvl); + /** + * Given a lambda expression n, returns an array term that corresponds to n. + * This does the opposite direction of the examples described above. + * + * We limit the return values of this method to be almost constant functions, + * that is, arrays of the form: + * (store ... (store (storeall _ b) i1 e1) ... in en) + * where b, i1, e1, ..., in, en are constants. + * Notice however that the return value of this form need not be a (canonical) + * array constant. + * + * If it is not possible to construct an array of this form that corresponds + * to n, this method returns null. + */ + static Node getArrayRepresentationForLambda(TNode n); + + private: + /** recursive helper for getLambdaForArrayRepresentation */ + static Node getLambdaForArrayRepresentationRec( + TNode a, + TNode bvl, + unsigned bvlIndex, + std::unordered_map& visited); + /** recursive helper for getArrayRepresentationForLambda */ + static Node getArrayRepresentationForLambdaRec(TNode n, TypeNode retType); +}; + +} // namespace uf +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__THEORY__UF__FUNCTION_CONST_H */ diff --git a/src/theory/uf/kinds b/src/theory/uf/kinds index 0faa5c672..a1db5120f 100644 --- a/src/theory/uf/kinds +++ b/src/theory/uf/kinds @@ -15,10 +15,28 @@ parameterized APPLY_UF VARIABLE 1: "application of an uninterpreted function; fi typerule APPLY_UF ::cvc5::theory::uf::UfTypeRule +operator FUNCTION_TYPE 2: "a function type" +cardinality FUNCTION_TYPE \ + "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \ + "theory/uf/theory_uf_type_rules.h" +well-founded FUNCTION_TYPE \ + "::cvc5::theory::uf::FunctionProperties::isWellFounded(%TYPE%)" \ + "::cvc5::theory::uf::FunctionProperties::mkGroundTerm(%TYPE%)" \ + "theory/uf/theory_uf_type_rules.h" +enumerator FUNCTION_TYPE \ + ::cvc5::theory::uf::FunctionEnumerator \ + "theory/uf/type_enumerator.h" + +operator LAMBDA 2 "a lambda expression; first parameter is a BOUND_VAR_LIST, second is lambda body" + +typerule LAMBDA ::cvc5::theory::uf::LambdaTypeRule + variable BOOLEAN_TERM_VARIABLE "Boolean term variable" -parameterized PARTIAL_APPLY_UF APPLY_UF 1: "partial uninterpreted function application" -typerule PARTIAL_APPLY_UF ::cvc5::theory::uf::PartialTypeRule +variable LAMBDA_VARIABLE "Lambda variable, used for lazy lambda lifting" + +# lambda expressions that are isomorphic to array constants can be considered constants +construle LAMBDA ::cvc5::theory::uf::LambdaTypeRule operator HO_APPLY 2 "higher-order (partial) function application" typerule HO_APPLY ::cvc5::theory::uf::HoApplyTypeRule diff --git a/src/theory/uf/theory_uf_rewriter.cpp b/src/theory/uf/theory_uf_rewriter.cpp index f4bedb4b8..ba00c316f 100644 --- a/src/theory/uf/theory_uf_rewriter.cpp +++ b/src/theory/uf/theory_uf_rewriter.cpp @@ -18,6 +18,7 @@ #include "expr/node_algorithm.h" #include "theory/rewriter.h" #include "theory/substitutions.h" +#include "theory/uf/function_const.h" namespace cvc5 { namespace theory { @@ -139,6 +140,11 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node) return RewriteResponse(REWRITE_AGAIN_FULL, new_body); } } + else if (node.getKind() == kind::LAMBDA) + { + Node ret = rewriteLambda(node); + return RewriteResponse(REWRITE_DONE, ret); + } return RewriteResponse(REWRITE_DONE, node); } @@ -204,6 +210,56 @@ Node TheoryUfRewriter::decomposeHoApply(TNode n, } bool TheoryUfRewriter::canUseAsApplyUfOperator(TNode n) { return n.isVar(); } +Node TheoryUfRewriter::rewriteLambda(Node node) +{ + Assert(node.getKind() == kind::LAMBDA); + // The following code ensures that if node is equivalent to a constant + // lambda, then we return the canonical representation for the lambda, which + // in turn ensures that two constant lambdas are equivalent if and only + // if they are the same node. + // We canonicalize lambdas by turning them into array constants, applying + // normalization on array constants, and then converting the array constant + // back to a lambda. + Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl; + Node anode = FunctionConst::getArrayRepresentationForLambda(node); + // Only rewrite constant array nodes, since these are the only cases + // where we require canonicalization of lambdas. Moreover, applying the + // below code is not correct if the arguments to the lambda occur + // in return values. For example, lambda x. ite( x=1, f(x), c ) would + // be converted to (store (storeall ... c) 1 f(x)), and then converted + // to lambda y. ite( y=1, f(x), c), losing the relation between x and y. + if (!anode.isNull() && anode.isConst()) + { + Assert(anode.getType().isArray()); + // must get the standard bound variable list + Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( + node.getType()); + Node retNode = + FunctionConst::getLambdaForArrayRepresentation(anode, varList); + if (!retNode.isNull() && retNode != node) + { + Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl; + Trace("builtin-rewrite") << " input : " << node << std::endl; + Trace("builtin-rewrite") + << " output : " << retNode << ", constant = " << retNode.isConst() + << std::endl; + Trace("builtin-rewrite") + << " array rep : " << anode << ", constant = " << anode.isConst() + << std::endl; + Assert(anode.isConst() == retNode.isConst()); + Assert(retNode.getType() == node.getType()); + Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode)); + return retNode; + } + } + else + { + Trace("builtin-rewrite-debug") + << "...failed to get array representation." << std::endl; + } + return node; +} + } // namespace uf } // namespace theory } // namespace cvc5 diff --git a/src/theory/uf/theory_uf_rewriter.h b/src/theory/uf/theory_uf_rewriter.h index dfa797f71..31a6f4669 100644 --- a/src/theory/uf/theory_uf_rewriter.h +++ b/src/theory/uf/theory_uf_rewriter.h @@ -35,11 +35,11 @@ class TheoryUfRewriter : public TheoryRewriter { public: TheoryUfRewriter(bool isHigherOrder = false); + /** post-rewrite */ RewriteResponse postRewrite(TNode node) override; - + /** pre-rewrite */ RewriteResponse preRewrite(TNode node) override; - - public: // conversion between HO_APPLY AND APPLY_UF + // conversion between HO_APPLY AND APPLY_UF // converts an APPLY_UF to a curried HO_APPLY e.g. (f a b) becomes (@ (@ f a) // b) static Node getHoApplyForApplyUf(TNode n); @@ -62,6 +62,10 @@ class TheoryUfRewriter : public TheoryRewriter * Then, f and g can be used as APPLY_UF operators, but (ite C f g), (lambda x1. (f x1)) as well as the variable x above are not. */ static bool canUseAsApplyUfOperator(TNode n); + + private: + /** Entry point for rewriting lambdas */ + static Node rewriteLambda(Node node); /** Is the logic higher-order? */ bool d_isHigherOrder; }; /* class TheoryUfRewriter */ diff --git a/src/theory/uf/theory_uf_type_rules.cpp b/src/theory/uf/theory_uf_type_rules.cpp index 5b132fc27..a05c76d4c 100644 --- a/src/theory/uf/theory_uf_type_rules.cpp +++ b/src/theory/uf/theory_uf_type_rules.cpp @@ -19,6 +19,8 @@ #include #include "expr/cardinality_constraint.h" +#include "theory/uf/function_const.h" +#include "util/cardinality.h" #include "util/rational.h" namespace cvc5 { @@ -160,6 +162,112 @@ TypeNode HoApplyTypeRule::computeType(NodeManager* nodeManager, } } +TypeNode LambdaTypeRule::computeType(NodeManager* nodeManager, + TNode n, + bool check) +{ + if (n[0].getType(check) != nodeManager->boundVarListType()) + { + std::stringstream ss; + ss << "expected a bound var list for LAMBDA expression, got `" + << n[0].getType().toString() << "'"; + throw TypeCheckingExceptionPrivate(n, ss.str()); + } + std::vector argTypes; + for (TNode::iterator i = n[0].begin(); i != n[0].end(); ++i) + { + argTypes.push_back((*i).getType()); + } + TypeNode rangeType = n[1].getType(check); + return nodeManager->mkFunctionType(argTypes, rangeType); +} + +bool LambdaTypeRule::computeIsConst(NodeManager* nodeManager, TNode n) +{ + Assert(n.getKind() == kind::LAMBDA); + // get array representation of this function, if possible + Node na = FunctionConst::getArrayRepresentationForLambda(n); + if (!na.isNull()) + { + Assert(na.getType().isArray()); + Trace("lambda-const") << "Array representation for " << n << " is " << na + << " " << na.getType() << std::endl; + // must have the standard bound variable list + Node bvl = + NodeManager::currentNM()->getBoundVarListForFunctionType(n.getType()); + if (bvl == n[0]) + { + // array must be constant + if (na.isConst()) + { + Trace("lambda-const") << "*** Constant lambda : " << n; + Trace("lambda-const") << " since its array representation : " << na + << " is constant." << std::endl; + return true; + } + else + { + Trace("lambda-const") << "Non-constant lambda : " << n + << " since array is not constant." << std::endl; + } + } + else + { + Trace("lambda-const") + << "Non-constant lambda : " << n + << " since its varlist is not standard." << std::endl; + Trace("lambda-const") << " standard : " << bvl << std::endl; + Trace("lambda-const") << " current : " << n[0] << std::endl; + } + } + else + { + Trace("lambda-const") << "Non-constant lambda : " << n + << " since it has no array representation." + << std::endl; + } + return false; +} + +Cardinality FunctionProperties::computeCardinality(TypeNode type) +{ + // Don't assert this; allow other theories to use this cardinality + // computation. + // + // Assert(type.getKind() == kind::FUNCTION_TYPE); + + Cardinality argsCard(1); + // get the largest cardinality of function arguments/return type + for (size_t i = 0, i_end = type.getNumChildren() - 1; i < i_end; ++i) + { + argsCard *= type[i].getCardinality(); + } + + Cardinality valueCard = type[type.getNumChildren() - 1].getCardinality(); + + return valueCard ^ argsCard; +} + +bool FunctionProperties::isWellFounded(TypeNode type) +{ + for (TypeNode::iterator i = type.begin(), i_end = type.end(); i != i_end; ++i) + { + if (!(*i).isWellFounded()) + { + return false; + } + } + return true; +} + +Node FunctionProperties::mkGroundTerm(TypeNode type) +{ + NodeManager* nm = NodeManager::currentNM(); + Node bvl = nm->getBoundVarListForFunctionType(type); + Node ret = type.getRangeType().mkGroundTerm(); + return nm->mkNode(kind::LAMBDA, bvl, ret); +} + } // namespace uf } // namespace theory } // namespace cvc5 diff --git a/src/theory/uf/theory_uf_type_rules.h b/src/theory/uf/theory_uf_type_rules.h index b9451a500..6f0374ae6 100644 --- a/src/theory/uf/theory_uf_type_rules.h +++ b/src/theory/uf/theory_uf_type_rules.h @@ -69,6 +69,30 @@ class HoApplyTypeRule static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); }; +class LambdaTypeRule +{ + public: + static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check); + // computes whether a lambda is a constant value, via conversion to array + // representation + static bool computeIsConst(NodeManager* nodeManager, TNode n); +}; /* class LambdaTypeRule */ + +class FunctionProperties +{ + public: + static Cardinality computeCardinality(TypeNode type); + + /** Function type is well-founded if its component sorts are */ + static bool isWellFounded(TypeNode type); + /** + * Ground term for function sorts is (lambda x. t) where x is the + * canonical variable list for its type and t is the canonical ground term of + * its range. + */ + static Node mkGroundTerm(TypeNode type); +}; /* class FuctionProperties */ + } // namespace uf } // namespace theory } // namespace cvc5 diff --git a/src/theory/uf/type_enumerator.cpp b/src/theory/uf/type_enumerator.cpp new file mode 100644 index 000000000..a7f1f3ec3 --- /dev/null +++ b/src/theory/uf/type_enumerator.cpp @@ -0,0 +1,51 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Enumerator for functions. + */ + +#include "theory/uf/type_enumerator.h" + +#include "theory/uf/function_const.h" + +namespace cvc5 { +namespace theory { +namespace uf { + +FunctionEnumerator::FunctionEnumerator(TypeNode type, + TypeEnumeratorProperties* tep) + : TypeEnumeratorBase(type), + d_arrayEnum(FunctionConst::getArrayTypeForFunctionType(type), tep) +{ + Assert(type.getKind() == kind::FUNCTION_TYPE); + d_bvl = NodeManager::currentNM()->getBoundVarListForFunctionType(type); +} + +Node FunctionEnumerator::operator*() +{ + if (isFinished()) + { + throw NoMoreValuesException(getType()); + } + Node a = *d_arrayEnum; + return FunctionConst::getLambdaForArrayRepresentation(a, d_bvl); +} + +FunctionEnumerator& FunctionEnumerator::operator++() +{ + ++d_arrayEnum; + return *this; +} + +} // namespace uf +} // namespace theory +} // namespace cvc5 diff --git a/src/theory/uf/type_enumerator.h b/src/theory/uf/type_enumerator.h new file mode 100644 index 000000000..dfbbc1924 --- /dev/null +++ b/src/theory/uf/type_enumerator.h @@ -0,0 +1,59 @@ +/****************************************************************************** + * Top contributors (to current version): + * Andrew Reynolds + * + * This file is part of the cvc5 project. + * + * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS + * in the top-level source directory and their institutional affiliations. + * All rights reserved. See the file COPYING in the top-level source + * directory for licensing information. + * **************************************************************************** + * + * Enumerator for functions. + */ + +#include "cvc5_private.h" + +#ifndef CVC5__THEORY__UF__TYPE_ENUMERATOR_H +#define CVC5__THEORY__UF__TYPE_ENUMERATOR_H + +#include "expr/kind.h" +#include "expr/type_node.h" +#include "theory/type_enumerator.h" +#include "util/integer.h" + +namespace cvc5 { +namespace theory { +namespace uf { + +/** FunctionEnumerator + * This enumerates function values, based on the enumerator for the + * array type corresponding to the given function type. + */ +class FunctionEnumerator : public TypeEnumeratorBase +{ + public: + FunctionEnumerator(TypeNode type, TypeEnumeratorProperties* tep = nullptr); + /** Get the current term of the enumerator. */ + Node operator*() override; + /** Increment the enumerator. */ + FunctionEnumerator& operator++() override; + /** is the enumerator finished? */ + bool isFinished() override { return d_arrayEnum.isFinished(); } + + private: + /** Enumerates arrays, which we convert to functions. */ + TypeEnumerator d_arrayEnum; + /** The bound variable list for the function type we are enumerating. + * All terms output by this enumerator are of the form (LAMBDA d_bvl t) for + * some term t. + */ + Node d_bvl; +}; /* class FunctionEnumerator */ + +} // namespace uf +} // namespace theory +} // namespace cvc5 + +#endif /* CVC5__THEORY__UF__TYPE_ENUMERATOR_H */