Move functions and lambdas from builtin to uf (#7570)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 5 Nov 2021 20:12:22 +0000 (15:12 -0500)
committerGitHub <noreply@github.com>
Fri, 5 Nov 2021 20:12:22 +0000 (20:12 +0000)
This is in preparation for adding better native support for handling lambdas in the higher-order extension of the UF theory.

We require that LAMBDA and function types belong to theory UF so that the theory solver is properly notified.

This also splits the utility methods for computing whether a function is "constant" to its own file.

This PR is code move only.

18 files changed:
src/CMakeLists.txt
src/theory/builtin/kinds
src/theory/builtin/theory_builtin_rewriter.cpp
src/theory/builtin/theory_builtin_rewriter.h
src/theory/builtin/theory_builtin_type_rules.cpp
src/theory/builtin/theory_builtin_type_rules.h
src/theory/builtin/type_enumerator.cpp
src/theory/builtin/type_enumerator.h
src/theory/datatypes/kinds
src/theory/uf/function_const.cpp [new file with mode: 0644]
src/theory/uf/function_const.h [new file with mode: 0644]
src/theory/uf/kinds
src/theory/uf/theory_uf_rewriter.cpp
src/theory/uf/theory_uf_rewriter.h
src/theory/uf/theory_uf_type_rules.cpp
src/theory/uf/theory_uf_type_rules.h
src/theory/uf/type_enumerator.cpp [new file with mode: 0644]
src/theory/uf/type_enumerator.h [new file with mode: 0644]

index 1adf406955e3739dadf54197bc80410804035e35..c526bd13b3d045dd3d9eac7ddfb60fe3eafd7490 100644 (file)
@@ -1117,6 +1117,8 @@ libcvc5_add_sources(
   theory/uf/equality_engine_types.h
   theory/uf/eq_proof.cpp
   theory/uf/eq_proof.h
+  theory/uf/function_const.cpp
+  theory/uf/function_const.h
   theory/uf/proof_checker.cpp
   theory/uf/proof_checker.h
   theory/uf/proof_equality_engine.cpp
@@ -1133,6 +1135,8 @@ libcvc5_add_sources(
   theory/uf/theory_uf_rewriter.h
   theory/uf/theory_uf_type_rules.cpp
   theory/uf/theory_uf_type_rules.h
+  theory/uf/type_enumerator.cpp
+  theory/uf/type_enumerator.h
   theory/valuation.cpp
   theory/valuation.h
 )
index d4a8782b522861e64d3c7b53bb4bfde2863b76c6..381573a12cde281cf91488f5b079dd9753e0d26b 100644 (file)
@@ -262,7 +262,8 @@ parameterized SORT_TYPE SORT_TAG 0: "specifies types of user-declared 'uninterpr
 cardinality SORT_TYPE "Cardinality(Cardinality::INTEGERS)"
 well-founded SORT_TYPE \
     "::cvc5::theory::builtin::SortProperties::isWellFounded(%TYPE%)" \
-    "::cvc5::theory::builtin::SortProperties::mkGroundTerm(%TYPE%)"
+    "::cvc5::theory::builtin::SortProperties::mkGroundTerm(%TYPE%)" \
+    "theory/builtin/theory_builtin_type_rules.h"
 
 constant UNINTERPRETED_CONSTANT \
     class \
@@ -301,8 +302,6 @@ variable BOUND_VARIABLE "a bound variable (permitted in bindings and the associa
 variable SKOLEM "a Skolem variable (internal only)"
 operator SEXPR 0: "a symbolic expression (any arity)"
 
-operator LAMBDA 2 "a lambda expression; first parameter is a BOUND_VAR_LIST, second is lambda body"
-
 operator WITNESS 2 "a witness expression; first parameter is a BOUND_VAR_LIST, second is the witness body"
 
 constant TYPE_CONSTANT \
@@ -311,17 +310,6 @@ constant TYPE_CONSTANT \
     ::cvc5::TypeConstantHashFunction \
     "expr/kind.h" \
     "a representation for basic types"
-operator FUNCTION_TYPE 2: "a function type"
-cardinality FUNCTION_TYPE \
-    "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \
-    "theory/builtin/theory_builtin_type_rules.h"
-well-founded FUNCTION_TYPE \
-    "::cvc5::theory::builtin::FunctionProperties::isWellFounded(%TYPE%)" \
-    "::cvc5::theory::builtin::FunctionProperties::mkGroundTerm(%TYPE%)" \
-    "theory/builtin/theory_builtin_type_rules.h"
-enumerator FUNCTION_TYPE \
-    ::cvc5::theory::builtin::FunctionEnumerator \
-    "theory/builtin/type_enumerator.h"
 sort SEXPR_TYPE \
     Cardinality::INTEGERS \
     not-well-founded \
@@ -330,10 +318,6 @@ sort SEXPR_TYPE \
 typerule EQUAL ::cvc5::theory::builtin::EqualityTypeRule
 typerule DISTINCT ::cvc5::theory::builtin::DistinctTypeRule
 typerule SEXPR ::cvc5::theory::builtin::SExprTypeRule
-typerule LAMBDA ::cvc5::theory::builtin::LambdaTypeRule
 typerule WITNESS ::cvc5::theory::builtin::WitnessTypeRule
 
-# lambda expressions that are isomorphic to array constants can be considered constants
-construle LAMBDA ::cvc5::theory::builtin::LambdaTypeRule
-
 endtheory
index 0ee72fc5f22df880ca7828a4429cae8bf8e43b33..b57f2bf423c441513d19e7f4b2d5c1a6962a47ca 100644 (file)
@@ -18,7 +18,6 @@
 
 #include "theory/builtin/theory_builtin_rewriter.h"
 
-#include "expr/array_store_all.h"
 #include "expr/attribute.h"
 #include "expr/node_algorithm.h"
 #include "theory/rewriter.h"
@@ -55,45 +54,6 @@ Node TheoryBuiltinRewriter::blastDistinct(TNode in) {
 }
 
 RewriteResponse TheoryBuiltinRewriter::postRewrite(TNode node) {
-  if( node.getKind()==kind::LAMBDA ){
-    // The following code ensures that if node is equivalent to a constant
-    // lambda, then we return the canonical representation for the lambda, which
-    // in turn ensures that two constant lambdas are equivalent if and only
-    // if they are the same node.
-    // We canonicalize lambdas by turning them into array constants, applying
-    // normalization on array constants, and then converting the array constant
-    // back to a lambda.
-    Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
-    Node anode = getArrayRepresentationForLambda( node );
-    // Only rewrite constant array nodes, since these are the only cases
-    // where we require canonicalization of lambdas. Moreover, applying the
-    // below code is not correct if the arguments to the lambda occur
-    // in return values. For example, lambda x. ite( x=1, f(x), c ) would
-    // be converted to (store (storeall ... c) 1 f(x)), and then converted
-    // to lambda y. ite( y=1, f(x), c), losing the relation between x and y.
-    if (!anode.isNull() && anode.isConst())
-    {
-      Assert(anode.getType().isArray());
-      //must get the standard bound variable list
-      Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType( node.getType() );
-      Node retNode = getLambdaForArrayRepresentation( anode, varList );
-      if( !retNode.isNull() && retNode!=node ){
-        Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
-        Trace("builtin-rewrite") << "     input  : " << node << std::endl;
-        Trace("builtin-rewrite") << "     output : " << retNode << ", constant = " << retNode.isConst() << std::endl;
-        Trace("builtin-rewrite") << "  array rep : " << anode << ", constant = " << anode.isConst() << std::endl;
-        Assert(anode.isConst() == retNode.isConst());
-        Assert(retNode.getType() == node.getType());
-        Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode));
-        return RewriteResponse(REWRITE_DONE, retNode);
-      }
-    }
-    else
-    {
-      Trace("builtin-rewrite-debug") << "...failed to get array representation." << std::endl;
-    }
-    return RewriteResponse(REWRITE_DONE, node);
-  }
   // otherwise, do the default call
   return doRewrite(node);
 }
@@ -117,346 +77,6 @@ RewriteResponse TheoryBuiltinRewriter::doRewrite(TNode node)
   }
 }
 
-TypeNode TheoryBuiltinRewriter::getFunctionTypeForArrayType(TypeNode atn,
-                                                            Node bvl)
-{
-  std::vector<TypeNode> children;
-  for (unsigned i = 0; i < bvl.getNumChildren(); i++)
-  {
-    Assert(atn.isArray());
-    Assert(bvl[i].getType() == atn.getArrayIndexType());
-    children.push_back(atn.getArrayIndexType());
-    atn = atn.getArrayConstituentType();
-  }
-  children.push_back(atn);
-  return NodeManager::currentNM()->mkFunctionType(children);
-}
-
-TypeNode TheoryBuiltinRewriter::getArrayTypeForFunctionType(TypeNode ftn)
-{
-  Assert(ftn.isFunction());
-  // construct the curried array type
-  unsigned nchildren = ftn.getNumChildren();
-  TypeNode ret = ftn[nchildren - 1];
-  for (int i = (static_cast<int>(nchildren) - 2); i >= 0; i--)
-  {
-    ret = NodeManager::currentNM()->mkArrayType(ftn[i], ret);
-  }
-  return ret;
-}
-
-Node TheoryBuiltinRewriter::getLambdaForArrayRepresentationRec(
-    TNode a,
-    TNode bvl,
-    unsigned bvlIndex,
-    std::unordered_map<TNode, Node>& visited)
-{
-  std::unordered_map<TNode, Node>::iterator it = visited.find(a);
-  if( it==visited.end() ){
-    Node ret;
-    if( bvlIndex<bvl.getNumChildren() ){
-      Assert(a.getType().isArray());
-      if( a.getKind()==kind::STORE ){
-        // convert the array recursively
-        Node body = getLambdaForArrayRepresentationRec( a[0], bvl, bvlIndex, visited );
-        if( !body.isNull() ){
-          // convert the value recursively (bounded by the number of arguments in bvl)
-          Node val = getLambdaForArrayRepresentationRec( a[2], bvl, bvlIndex+1, visited );
-          if( !val.isNull() ){
-            Assert(!TypeNode::leastCommonTypeNode(a[1].getType(),
-                                                  bvl[bvlIndex].getType())
-                        .isNull());
-            Assert(!TypeNode::leastCommonTypeNode(val.getType(), body.getType())
-                        .isNull());
-            Node cond = bvl[bvlIndex].eqNode( a[1] );
-            ret = NodeManager::currentNM()->mkNode( kind::ITE, cond, val, body );
-          }
-        }
-      }else if( a.getKind()==kind::STORE_ALL ){
-        ArrayStoreAll storeAll = a.getConst<ArrayStoreAll>();
-        Node sa = storeAll.getValue();
-        // convert the default value recursively (bounded by the number of arguments in bvl)
-        ret = getLambdaForArrayRepresentationRec( sa, bvl, bvlIndex+1, visited );
-      }
-    }else{
-      ret = a;
-    }
-    visited[a] = ret;
-    return ret;
-  }else{
-    return it->second;
-  }
-}
-
-Node TheoryBuiltinRewriter::getLambdaForArrayRepresentation( TNode a, TNode bvl ){
-  Assert(a.getType().isArray());
-  std::unordered_map<TNode, Node> visited;
-  Trace("builtin-rewrite-debug") << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
-  Node body = getLambdaForArrayRepresentationRec( a, bvl, 0, visited );
-  if( !body.isNull() ){
-    body = Rewriter::rewrite( body );
-    Trace("builtin-rewrite-debug") << "...got lambda body " << body << std::endl;
-    return NodeManager::currentNM()->mkNode( kind::LAMBDA, bvl, body );
-  }else{
-    Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
-    return Node::null();
-  }
-}
-
-Node TheoryBuiltinRewriter::getArrayRepresentationForLambdaRec(TNode n,
-                                                               TypeNode retType)
-{
-  Assert(n.getKind() == kind::LAMBDA);
-  NodeManager* nm = NodeManager::currentNM();
-  Trace("builtin-rewrite-debug") << "Get array representation for : " << n << std::endl;
-
-  Node first_arg = n[0][0];
-  Node rec_bvl;
-  unsigned size = n[0].getNumChildren();
-  if (size > 1)
-  {
-    std::vector< Node > args;
-    for (unsigned i = 1; i < size; i++)
-    {
-      args.push_back( n[0][i] );
-    }
-    rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args);
-  }
-
-  Trace("builtin-rewrite-debug2") << "  process body..." << std::endl;
-  std::vector< Node > conds;
-  std::vector< Node > vals;
-  Node curr = n[1];
-  Kind ck = curr.getKind();
-  while (ck == kind::ITE || ck == kind::OR || ck == kind::AND
-         || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE)
-  {
-    Node index_eq;
-    Node curr_val;
-    Node next;
-    // Each iteration of this loop infers an entry in the function, e.g. it
-    // has a value under some condition.
-
-    // [1] We infer that the entry has value "curr_val" under condition
-    // "index_eq". We set "next" to the node that is the remainder of the
-    // function to process.
-    if (ck == kind::ITE)
-    {
-      Trace("builtin-rewrite-debug2")
-          << "  process condition : " << curr[0] << std::endl;
-      index_eq = curr[0];
-      curr_val = curr[1];
-      next = curr[2];
-    }
-    else if (ck == kind::OR || ck == kind::AND)
-    {
-      Trace("builtin-rewrite-debug2")
-          << "  process base : " << curr << std::endl;
-      // curr = Rewriter::rewrite(curr);
-      // Trace("builtin-rewrite-debug2")
-      //     << "  rewriten base : " << curr << std::endl;
-      // Complex Boolean return cases, in which
-      //  (1) lambda x. (= x v1) v ... becomes
-      //      lambda x. (ite (= x v1) true [...])
-      //
-      //  (2) lambda x. (not (= x v1)) ^ ... becomes
-      //      lambda x. (ite (= x v1) false [...])
-      //
-      // Note the negated cases of the lhs of the OR/AND operators above are
-      // handled by pushing the recursion to the then-branch, with the
-      // else-branch being the constant value. For example, the negated (1)
-      // would be
-      //  (1') lambda x. (not (= x v1)) v ... becomes
-      //       lambda x. (ite (= x v1) [...] true)
-      // thus requiring the rest of the disjunction to be further processed in
-      // the then-branch as the current value.
-      bool pol = curr[0].getKind() != kind::NOT;
-      bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR);
-      index_eq = pol ? curr[0] : curr[0][0];
-      // processed : the value that is determined by the first child of curr
-      // remainder : the remaining children of curr
-      Node processed, remainder;
-      // the value is the polarity of the first child or its inverse if we are
-      // in the inverted case
-      processed = nm->mkConst(!inverted? pol : !pol);
-      // build an OR/AND with the remaining components
-      if (curr.getNumChildren() == 2)
-      {
-        remainder = curr[1];
-      }
-      else
-      {
-        std::vector<Node> remainderNodes{curr.begin() + 1, curr.end()};
-        remainder = nm->mkNode(ck, remainderNodes);
-      }
-      if (inverted)
-      {
-        curr_val = remainder;
-        next = processed;
-        // If the lambda contains more variables than the one being currently
-        // processed, the current value can be non-constant, since it'll be
-        // processed recursively below. Otherwise we fail.
-        if (rec_bvl.isNull() && !curr_val.isConst())
-        {
-          Trace("builtin-rewrite-debug2")
-              << "...non-const curr_val " << curr_val << "\n";
-          return Node::null();
-        }
-      }
-      else
-      {
-        curr_val = processed;
-        next = remainder;
-      }
-      Trace("builtin-rewrite-debug2") << "  index_eq : " << index_eq << "\n";
-      Trace("builtin-rewrite-debug2") << "  curr_val : " << curr_val << "\n";
-      Trace("builtin-rewrite-debug2") << "  next : " << next << std::endl;
-    }
-    else
-    {
-      Trace("builtin-rewrite-debug2")
-          << "  process base : " << curr << std::endl;
-      // Simple Boolean return cases, in which
-      //  (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false)
-      //  (2) lambda x. v becomes lambda x. (ite (= x v) true false)
-      // Note the negateg cases of the bodies above are also handled.
-      bool pol = ck != kind::NOT;
-      index_eq = pol ? curr : curr[0];
-      curr_val = nm->mkConst(pol);
-      next = nm->mkConst(!pol);
-    }
-
-    // [2] We ensure that "index_eq" is an equality, if possible.
-    if (index_eq.getKind() != kind::EQUAL)
-    {
-      bool pol = index_eq.getKind() != kind::NOT;
-      Node indexEqAtom = pol ? index_eq : index_eq[0];
-      if (indexEqAtom.getKind() == kind::BOUND_VARIABLE)
-      {
-        if (!indexEqAtom.getType().isBoolean())
-        {
-          // Catches default case of non-Boolean variable, e.g.
-          // lambda x : Int. x. In this case, it is not canonical and we fail.
-          Trace("builtin-rewrite-debug2")
-              << "  ...non-Boolean variable." << std::endl;
-          return Node::null();
-        }
-        // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as
-        // lambda x. (ite (= x true) t s)
-        index_eq = indexEqAtom.eqNode(nm->mkConst(pol));
-      }
-      else
-      {
-        // non-equality condition
-        Trace("builtin-rewrite-debug2")
-            << "  ...non-equality condition." << std::endl;
-        return Node::null();
-      }
-    }
-    else if (Rewriter::rewrite(index_eq) != index_eq)
-    {
-      // equality must be oriented correctly based on rewriter
-      Trace("builtin-rewrite-debug2") << "  ...equality not oriented properly." << std::endl;
-      return Node::null();
-    }
-
-    // [3] We ensure that "index_eq" is an equality that is equivalent to
-    // "first_arg" = "curr_index", where curr_index is a constant, and
-    // "first_arg" is the current argument we are processing, if possible.
-    Node curr_index;
-    for( unsigned r=0; r<2; r++ ){
-      Node arg = index_eq[r];
-      Node val = index_eq[1-r];
-      if( arg==first_arg ){
-        if (!val.isConst())
-        {
-          // non-constant value
-          Trace("builtin-rewrite-debug2")
-              << "  ...non-constant value for argument\n.";
-          return Node::null();
-        }else{
-          curr_index = val;
-          Trace("builtin-rewrite-debug2")
-              << "  arg " << arg << " -> " << val << std::endl;
-          break;
-        }
-      }
-    }
-    if (curr_index.isNull())
-    {
-      Trace("builtin-rewrite-debug2")
-          << "  ...could not infer index value." << std::endl;
-      return Node::null();
-    }
-
-    // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the
-    // remaining arguments (rec_bvl).
-    if (!rec_bvl.isNull())
-    {
-      curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val);
-      Trace("builtin-rewrite-debug") << push;
-      Trace("builtin-rewrite-debug2") << push;
-      curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
-      Trace("builtin-rewrite-debug") << pop;
-      Trace("builtin-rewrite-debug2") << pop;
-      if (curr_val.isNull())
-      {
-        Trace("builtin-rewrite-debug2")
-            << "  ...failed to recursively find value." << std::endl;
-        return Node::null();
-      }
-    }
-    Trace("builtin-rewrite-debug2")
-        << "  ...condition is index " << curr_val << std::endl;
-
-    // [5] Add the entry
-    conds.push_back( curr_index );
-    vals.push_back( curr_val );
-
-    // we will now process the remainder
-    curr = next;
-    ck = curr.getKind();
-    Trace("builtin-rewrite-debug2")
-        << "  process remainder : " << curr << std::endl;
-  }
-  if( !rec_bvl.isNull() ){
-    curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr);
-    Trace("builtin-rewrite-debug") << push;
-    Trace("builtin-rewrite-debug2") << push;
-    curr = getArrayRepresentationForLambdaRec(curr, retType);
-    Trace("builtin-rewrite-debug") << pop;
-    Trace("builtin-rewrite-debug2") << pop;
-  }
-  if( !curr.isNull() && curr.isConst() ){
-    // compute the return type
-    TypeNode array_type = retType;
-    for (unsigned i = 0; i < size; i++)
-    {
-      unsigned index = (size - 1) - i;
-      array_type = nm->mkArrayType(n[0][index].getType(), array_type);
-    }
-    Trace("builtin-rewrite-debug2") << "  make array store all " << curr.getType() << " annotated : " << array_type << std::endl;
-    Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType()));
-    curr = nm->mkConst(ArrayStoreAll(array_type, curr));
-    Trace("builtin-rewrite-debug2") << "  build array..." << std::endl;
-    // can only build if default value is constant (since array store all must be constant)
-    Trace("builtin-rewrite-debug2") << "  got constant base " << curr << std::endl;
-    Trace("builtin-rewrite-debug2") << "  conditions " << conds << std::endl;
-    Trace("builtin-rewrite-debug2") << "  values " << vals << std::endl;
-    // construct store chain
-    for (int i = static_cast<int>(conds.size()) - 1; i >= 0; i--)
-    {
-      Assert(conds[i].getType().isSubtypeOf(first_arg.getType()));
-      curr = nm->mkNode(kind::STORE, curr, conds[i], vals[i]);
-    }
-    Trace("builtin-rewrite-debug") << "...got array " << curr << " for " << n << std::endl;
-    return curr;
-  }else{
-    Trace("builtin-rewrite-debug") << "...failed to get array (cannot get constant default value)" << std::endl;
-    return Node::null();    
-  }
-}
-
 Node TheoryBuiltinRewriter::rewriteWitness(TNode node)
 {
   Assert(node.getKind() == kind::WITNESS);
@@ -493,21 +113,6 @@ Node TheoryBuiltinRewriter::rewriteWitness(TNode node)
   return node;
 }
 
-Node TheoryBuiltinRewriter::getArrayRepresentationForLambda(TNode n)
-{
-  Assert(n.getKind() == kind::LAMBDA);
-  // must carry the overall return type to deal with cases like (lambda ((x Int)
-  // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else
-  // case above should be (arraystoreall (Array Int Real) 0.0)
-  Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType());
-  if (anode.isNull())
-  {
-    return anode;
-  }
-  // must rewrite it to make canonical
-  return Rewriter::rewrite(anode);
-}
-
 }  // namespace builtin
 }  // namespace theory
 }  // namespace cvc5
index f528ed43cf47bdd64a076421937d64b80abd6863..0f903bc448a544a51a0e05fa24f610863929215c 100644 (file)
@@ -37,17 +37,6 @@ class TheoryBuiltinRewriter : public TheoryRewriter
 
   RewriteResponse preRewrite(TNode node) override { return doRewrite(node); }
 
-  // conversions between lambdas and arrays
- private:
-  /** recursive helper for getLambdaForArrayRepresentation */
-  static Node getLambdaForArrayRepresentationRec(
-      TNode a,
-      TNode bvl,
-      unsigned bvlIndex,
-      std::unordered_map<TNode, Node>& visited);
-  /** recursive helper for getArrayRepresentationForLambda */
-  static Node getArrayRepresentationForLambdaRec(TNode n, TypeNode retType);
-
  public:
   /**
    * The default rewriter for rewrites that occur at both pre and post rewrite.
@@ -58,67 +47,6 @@ class TheoryBuiltinRewriter : public TheoryRewriter
    * Returns the rewritten form of node.
    */
   static Node rewriteWitness(TNode node);
-  /** Get function type for array type
-   *
-   * This returns the function type of terms returned by the function
-   * getLambdaForArrayRepresentation( t, bvl ),
-   * where t.getType()=atn.
-   *
-   * bvl should be a bound variable list whose variables correspond in-order
-   * to the index types of the (curried) Array type. For example, a bound
-   * variable list bvl whose variables have types (Int, Real) can be given as
-   * input when paired with atn = (Array Int (Array Real Bool)), or (Array Int
-   * (Array Real (Array Bool Bool))). This function returns (-> Int Real Bool)
-   * and (-> Int Real (Array Bool Bool)) respectively in these cases.
-   * On the other hand, the above bvl is not a proper input for
-   * atn = (Array Int (Array Bool Bool)) or (Array Int Int).
-   * If the types of bvl and atn do not match, we throw an assertion failure.
-   */
-  static TypeNode getFunctionTypeForArrayType(TypeNode atn, Node bvl);
-  /** Get array type for function type
-   *
-   * This returns the array type of terms returned by
-   * getArrayRepresentationForLambda( t ), where t.getType()=ftn.
-   */
-  static TypeNode getArrayTypeForFunctionType(TypeNode ftn);
-  /**
-   * Given an array constant a, returns a lambda expression that it corresponds
-   * to, with bound variable list bvl.
-   * Examples:
-   *
-   * (store (storeall (Array Int Int) 2) 0 1)
-   * becomes
-   * ((lambda x. (ite (= x 0) 1 2))
-   *
-   * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4))
-   * 0 (store (storeall (Array Int Int) 3) 1 2)) becomes (lambda xy. (ite (= x
-   * 0) (ite (= x 1) 2 3) 4))
-   *
-   * (store (store (storeall (Array Int Bool) false) 2 true) 1 true)
-   * becomes
-   * (lambda x. (ite (= x 1) true (ite (= x 2) true false)))
-   *
-   * Notice that the return body of the lambda is rewritten to ensure that the
-   * representation is canonical. Hence the last
-   * example will in fact be returned as:
-   * (lambda x. (ite (= x 1) true (= x 2)))
-   */
-  static Node getLambdaForArrayRepresentation(TNode a, TNode bvl);
-  /**
-   * Given a lambda expression n, returns an array term that corresponds to n.
-   * This does the opposite direction of the examples described above.
-   *
-   * We limit the return values of this method to be almost constant functions,
-   * that is, arrays of the form:
-   *   (store ... (store (storeall _ b) i1 e1) ... in en)
-   * where b, i1, e1, ..., in, en are constants.
-   * Notice however that the return value of this form need not be a (canonical)
-   * array constant.
-   *
-   * If it is not possible to construct an array of this form that corresponds
-   * to n, this method returns null.
-   */
-  static Node getArrayRepresentationForLambda(TNode n);
 }; /* class TheoryBuiltinRewriter */
 
 }  // namespace builtin
index 1888069bce8fde8c8fce56250068d0f0a9168200..636952be5a0ed5515b1dc66e774873dbf2a705ef 100644 (file)
@@ -18,7 +18,6 @@
 #include "expr/attribute.h"
 #include "expr/skolem_manager.h"
 #include "expr/uninterpreted_constant.h"
-#include "util/cardinality.h"
 
 namespace cvc5 {
 namespace theory {
@@ -56,25 +55,6 @@ Node SortProperties::mkGroundTerm(TypeNode type)
   return k;
 }
 
-Cardinality FunctionProperties::computeCardinality(TypeNode type)
-{
-  // Don't assert this; allow other theories to use this cardinality
-  // computation.
-  //
-  // Assert(type.getKind() == kind::FUNCTION_TYPE);
-
-  Cardinality argsCard(1);
-  // get the largest cardinality of function arguments/return type
-  for (size_t i = 0, i_end = type.getNumChildren() - 1; i < i_end; ++i)
-  {
-    argsCard *= type[i].getCardinality();
-  }
-
-  Cardinality valueCard = type[type.getNumChildren() - 1].getCardinality();
-
-  return valueCard ^ argsCard;
-}
-
 }  // namespace builtin
 }  // namespace theory
 }  // namespace cvc5
index 54139c4337aa91ce1bb03e5cd5663c0bd76de0c0..2117249c2609bcde3cdd3ce785bd6dee17222a2f 100644 (file)
@@ -20,7 +20,6 @@
 
 #include "expr/node.h"
 #include "expr/type_node.h"
-#include "theory/builtin/theory_builtin_rewriter.h" // for array and lambda representation
 
 #include <sstream>
 
@@ -107,54 +106,6 @@ class AbstractValueTypeRule {
   }
 };/* class AbstractValueTypeRule */
 
-class LambdaTypeRule {
- public:
-  inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) {
-    if( n[0].getType(check) != nodeManager->boundVarListType() ) {
-      std::stringstream ss;
-      ss << "expected a bound var list for LAMBDA expression, got `"
-         << n[0].getType().toString() << "'";
-      throw TypeCheckingExceptionPrivate(n, ss.str());
-    }
-    std::vector<TypeNode> argTypes;
-    for(TNode::iterator i = n[0].begin(); i != n[0].end(); ++i) {
-      argTypes.push_back((*i).getType());
-    }
-    TypeNode rangeType = n[1].getType(check);
-    return nodeManager->mkFunctionType(argTypes, rangeType);
-  }
-  // computes whether a lambda is a constant value, via conversion to array representation
-  inline static bool computeIsConst(NodeManager* nodeManager, TNode n)
-  {
-    Assert(n.getKind() == kind::LAMBDA);
-    //get array representation of this function, if possible
-    Node na = TheoryBuiltinRewriter::getArrayRepresentationForLambda(n);
-    if( !na.isNull() ){
-      Assert(na.getType().isArray());
-      Trace("lambda-const") << "Array representation for " << n << " is " << na << " " << na.getType() << std::endl;
-      // must have the standard bound variable list
-      Node bvl = NodeManager::currentNM()->getBoundVarListForFunctionType( n.getType() );
-      if( bvl==n[0] ){
-        //array must be constant
-        if( na.isConst() ){
-          Trace("lambda-const") << "*** Constant lambda : " << n;
-          Trace("lambda-const") << " since its array representation : " << na << " is constant." << std::endl;
-          return true;
-        }else{
-          Trace("lambda-const") << "Non-constant lambda : " << n << " since array is not constant." << std::endl;
-        } 
-      }else{
-        Trace("lambda-const") << "Non-constant lambda : " << n << " since its varlist is not standard." << std::endl;
-        Trace("lambda-const") << "  standard : " << bvl << std::endl;
-        Trace("lambda-const") << "   current : " << n[0] << std::endl;
-      } 
-    }else{
-      Trace("lambda-const") << "Non-constant lambda : " << n << " since it has no array representation." << std::endl;
-    } 
-    return false;
-  }
-};/* class LambdaTypeRule */
-
 class WitnessTypeRule
 {
  public:
@@ -198,37 +149,6 @@ class SortProperties {
   static Node mkGroundTerm(TypeNode type);
 };/* class SortProperties */
 
-class FunctionProperties {
- public:
-  static Cardinality computeCardinality(TypeNode type);
-
-  /** Function type is well-founded if its component sorts are */
-  static bool isWellFounded(TypeNode type)
-  {
-    for (TypeNode::iterator i = type.begin(), i_end = type.end(); i != i_end;
-         ++i)
-    {
-      if (!(*i).isWellFounded())
-      {
-        return false;
-      }
-    }
-    return true;
-  }
-  /**
-   * Ground term for function sorts is (lambda x. t) where x is the
-   * canonical variable list for its type and t is the canonical ground term of
-   * its range.
-   */
-  static Node mkGroundTerm(TypeNode type)
-  {
-    NodeManager* nm = NodeManager::currentNM();
-    Node bvl = nm->getBoundVarListForFunctionType(type);
-    Node ret = type.getRangeType().mkGroundTerm();
-    return nm->mkNode(kind::LAMBDA, bvl, ret);
-  }
-};/* class FuctionProperties */
-
 }  // namespace builtin
 }  // namespace theory
 }  // namespace cvc5
index 0ef1d3ec7e340a2684205bd5d8731c430d639e3c..2e919810bde427654b2fb5505e77d70ec8891d03 100644 (file)
@@ -21,31 +21,54 @@ namespace cvc5 {
 namespace theory {
 namespace builtin {
 
-FunctionEnumerator::FunctionEnumerator(TypeNode type,
-                                       TypeEnumeratorProperties* tep)
-    : TypeEnumeratorBase<FunctionEnumerator>(type),
-      d_arrayEnum(TheoryBuiltinRewriter::getArrayTypeForFunctionType(type), tep)
+UninterpretedSortEnumerator::UninterpretedSortEnumerator(
+    TypeNode type, TypeEnumeratorProperties* tep)
+    : TypeEnumeratorBase<UninterpretedSortEnumerator>(type), d_count(0)
 {
-  Assert(type.getKind() == kind::FUNCTION_TYPE);
-  d_bvl = NodeManager::currentNM()->getBoundVarListForFunctionType(type);
+  Assert(type.getKind() == kind::SORT_TYPE);
+  d_has_fixed_bound = false;
+  Trace("uf-type-enum") << "UF enum " << type << ", tep = " << tep << std::endl;
+  if (tep && tep->d_fixed_usort_card)
+  {
+    d_has_fixed_bound = true;
+    std::map<TypeNode, Integer>::iterator it = tep->d_fixed_card.find(type);
+    if (it != tep->d_fixed_card.end())
+    {
+      d_fixed_bound = it->second;
+    }
+    else
+    {
+      d_fixed_bound = Integer(1);
+    }
+    Trace("uf-type-enum") << "...fixed bound : " << d_fixed_bound << std::endl;
+  }
 }
 
-Node FunctionEnumerator::operator*()
+Node UninterpretedSortEnumerator::operator*()
 {
   if (isFinished())
   {
     throw NoMoreValuesException(getType());
   }
-  Node a = *d_arrayEnum;
-  return TheoryBuiltinRewriter::getLambdaForArrayRepresentation(a, d_bvl);
+  return NodeManager::currentNM()->mkConst(
+      UninterpretedConstant(getType(), d_count));
 }
 
-FunctionEnumerator& FunctionEnumerator::operator++()
+UninterpretedSortEnumerator& UninterpretedSortEnumerator::operator++()
 {
-  ++d_arrayEnum;
+  d_count += 1;
   return *this;
 }
 
+bool UninterpretedSortEnumerator::isFinished()
+{
+  if (d_has_fixed_bound)
+  {
+    return d_count >= d_fixed_bound;
+  }
+  return false;
+}
+
 }  // namespace builtin
 }  // namespace theory
 }  // namespace cvc5
index 980792f940d72aadea281d1f868dd0659c3dc16b..711752e2366752db1e3fe25c0d3ed8c92d9a0556 100644 (file)
@@ -35,73 +35,14 @@ class UninterpretedSortEnumerator : public TypeEnumeratorBase<UninterpretedSortE
 
  public:
   UninterpretedSortEnumerator(TypeNode type,
-                              TypeEnumeratorProperties* tep = nullptr)
-      : TypeEnumeratorBase<UninterpretedSortEnumerator>(type), d_count(0)
-  {
-    Assert(type.getKind() == kind::SORT_TYPE);
-    d_has_fixed_bound = false;
-    Trace("uf-type-enum") << "UF enum " << type << ", tep = " << tep << std::endl;
-    if( tep && tep->d_fixed_usort_card ){
-      d_has_fixed_bound = true;
-      std::map< TypeNode, Integer >::iterator it = tep->d_fixed_card.find( type );
-      if( it!=tep->d_fixed_card.end() ){
-        d_fixed_bound = it->second;
-      }else{
-        d_fixed_bound = Integer(1);
-      }
-      Trace("uf-type-enum") << "...fixed bound : " << d_fixed_bound << std::endl;
-    }
-  }
+                              TypeEnumeratorProperties* tep = nullptr);
 
-  Node operator*() override
-  {
-    if(isFinished()) {
-      throw NoMoreValuesException(getType());
-    }
-    return NodeManager::currentNM()->mkConst(
-        UninterpretedConstant(getType(), d_count));
-  }
-
-  UninterpretedSortEnumerator& operator++() override
-  {
-    d_count += 1;
-    return *this;
-  }
-
-  bool isFinished() override
-  {
-    if( d_has_fixed_bound ){
-      return d_count>=d_fixed_bound;
-    }else{
-      return false;
-    }
-  }
+  Node operator*() override;
 
-};/* class UninterpretedSortEnumerator */
+  UninterpretedSortEnumerator& operator++() override;
 
-/** FunctionEnumerator
-* This enumerates function values, based on the enumerator for the
-* array type corresponding to the given function type.
-*/
-class FunctionEnumerator : public TypeEnumeratorBase<FunctionEnumerator>
-{
- public:
-  FunctionEnumerator(TypeNode type, TypeEnumeratorProperties* tep = nullptr);
-  /** Get the current term of the enumerator. */
-  Node operator*() override;
-  /** Increment the enumerator. */
-  FunctionEnumerator& operator++() override;
-  /** is the enumerator finished? */
-  bool isFinished() override { return d_arrayEnum.isFinished(); }
- private:
-  /** Enumerates arrays, which we convert to functions. */
-  TypeEnumerator d_arrayEnum;
-  /** The bound variable list for the function type we are enumerating.
-  * All terms output by this enumerator are of the form (LAMBDA d_bvl t) for
-  * some term t.
-  */
-  Node d_bvl;
-}; /* class FunctionEnumerator */
+  bool isFinished() override;
+};
 
 }  // namespace builtin
 }  // namespace theory
index cb3a78cf275bda9cce64e367362d6a61f3ba2ef1..5324e1c79d4be1d81e5fe5c92b8470d6fb231f8f 100644 (file)
@@ -21,22 +21,22 @@ cardinality CONSTRUCTOR_TYPE \
 operator SELECTOR_TYPE 2 "selector"
 # can re-use function cardinality
 cardinality SELECTOR_TYPE \
-    "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \
-    "theory/builtin/theory_builtin_type_rules.h"
+    "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \
+    "theory/uf/theory_uf_type_rules.h"
 
 # tester type has a constructor type
 operator TESTER_TYPE 1 "tester"
 # can re-use function cardinality
 cardinality TESTER_TYPE \
-    "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \
-    "theory/builtin/theory_builtin_type_rules.h"
+    "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \
+    "theory/uf/theory_uf_type_rules.h"
 
 # tester type has a constructor type
 operator UPDATER_TYPE 2 "datatype update"
 # can re-use function cardinality
 cardinality UPDATER_TYPE \
-    "::cvc5::theory::builtin::FunctionProperties::computeCardinality(%TYPE%)" \
-    "theory/builtin/theory_builtin_type_rules.h"
+    "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \
+    "theory/uf/theory_uf_type_rules.h"
 
 parameterized APPLY_CONSTRUCTOR APPLY_TYPE_ASCRIPTION 0: "constructor application; first parameter is the constructor, remaining parameters (if any) are parameters to the constructor"
 
diff --git a/src/theory/uf/function_const.cpp b/src/theory/uf/function_const.cpp
new file mode 100644 (file)
index 0000000..181cb20
--- /dev/null
@@ -0,0 +1,412 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utilities for function constants
+ */
+
+#include "theory/uf/function_const.h"
+
+#include "expr/array_store_all.h"
+#include "theory/rewriter.h"
+
+namespace cvc5 {
+namespace theory {
+namespace uf {
+
+TypeNode FunctionConst::getFunctionTypeForArrayType(TypeNode atn, Node bvl)
+{
+  std::vector<TypeNode> children;
+  for (unsigned i = 0; i < bvl.getNumChildren(); i++)
+  {
+    Assert(atn.isArray());
+    Assert(bvl[i].getType() == atn.getArrayIndexType());
+    children.push_back(atn.getArrayIndexType());
+    atn = atn.getArrayConstituentType();
+  }
+  children.push_back(atn);
+  return NodeManager::currentNM()->mkFunctionType(children);
+}
+
+TypeNode FunctionConst::getArrayTypeForFunctionType(TypeNode ftn)
+{
+  Assert(ftn.isFunction());
+  // construct the curried array type
+  size_t nchildren = ftn.getNumChildren();
+  TypeNode ret = ftn[nchildren - 1];
+  for (size_t i = 0; i < nchildren - 1; i++)
+  {
+    size_t ii = nchildren - i - 2;
+    ret = NodeManager::currentNM()->mkArrayType(ftn[ii], ret);
+  }
+  return ret;
+}
+
+Node FunctionConst::getLambdaForArrayRepresentationRec(
+    TNode a,
+    TNode bvl,
+    unsigned bvlIndex,
+    std::unordered_map<TNode, Node>& visited)
+{
+  std::unordered_map<TNode, Node>::iterator it = visited.find(a);
+  if (it != visited.end())
+  {
+    return it->second;
+  }
+  Node ret;
+  if (bvlIndex < bvl.getNumChildren())
+  {
+    Assert(a.getType().isArray());
+    if (a.getKind() == kind::STORE)
+    {
+      // convert the array recursively
+      Node body =
+          getLambdaForArrayRepresentationRec(a[0], bvl, bvlIndex, visited);
+      if (!body.isNull())
+      {
+        // convert the value recursively (bounded by the number of arguments
+        // in bvl)
+        Node val = getLambdaForArrayRepresentationRec(
+            a[2], bvl, bvlIndex + 1, visited);
+        if (!val.isNull())
+        {
+          Assert(!TypeNode::leastCommonTypeNode(a[1].getType(),
+                                                bvl[bvlIndex].getType())
+                      .isNull());
+          Assert(!TypeNode::leastCommonTypeNode(val.getType(), body.getType())
+                      .isNull());
+          Node cond = bvl[bvlIndex].eqNode(a[1]);
+          ret = NodeManager::currentNM()->mkNode(kind::ITE, cond, val, body);
+        }
+      }
+    }
+    else if (a.getKind() == kind::STORE_ALL)
+    {
+      ArrayStoreAll storeAll = a.getConst<ArrayStoreAll>();
+      Node sa = storeAll.getValue();
+      // convert the default value recursively (bounded by the number of
+      // arguments in bvl)
+      ret = getLambdaForArrayRepresentationRec(sa, bvl, bvlIndex + 1, visited);
+    }
+  }
+  else
+  {
+    ret = a;
+  }
+  visited[a] = ret;
+  return ret;
+}
+
+Node FunctionConst::getLambdaForArrayRepresentation(TNode a, TNode bvl)
+{
+  Assert(a.getType().isArray());
+  std::unordered_map<TNode, Node> visited;
+  Trace("builtin-rewrite-debug")
+      << "Get lambda for : " << a << ", with variables " << bvl << std::endl;
+  Node body = getLambdaForArrayRepresentationRec(a, bvl, 0, visited);
+  if (!body.isNull())
+  {
+    body = Rewriter::rewrite(body);
+    Trace("builtin-rewrite-debug")
+        << "...got lambda body " << body << std::endl;
+    return NodeManager::currentNM()->mkNode(kind::LAMBDA, bvl, body);
+  }
+  Trace("builtin-rewrite-debug") << "...failed to get lambda body" << std::endl;
+  return Node::null();
+}
+
+Node FunctionConst::getArrayRepresentationForLambdaRec(TNode n,
+                                                       TypeNode retType)
+{
+  Assert(n.getKind() == kind::LAMBDA);
+  NodeManager* nm = NodeManager::currentNM();
+  Trace("builtin-rewrite-debug")
+      << "Get array representation for : " << n << std::endl;
+
+  Node first_arg = n[0][0];
+  Node rec_bvl;
+  size_t size = n[0].getNumChildren();
+  if (size > 1)
+  {
+    std::vector<Node> args;
+    for (size_t i = 1; i < size; i++)
+    {
+      args.push_back(n[0][i]);
+    }
+    rec_bvl = nm->mkNode(kind::BOUND_VAR_LIST, args);
+  }
+
+  Trace("builtin-rewrite-debug2") << "  process body..." << std::endl;
+  std::vector<Node> conds;
+  std::vector<Node> vals;
+  Node curr = n[1];
+  Kind ck = curr.getKind();
+  while (ck == kind::ITE || ck == kind::OR || ck == kind::AND
+         || ck == kind::EQUAL || ck == kind::NOT || ck == kind::BOUND_VARIABLE)
+  {
+    Node index_eq;
+    Node curr_val;
+    Node next;
+    // Each iteration of this loop infers an entry in the function, e.g. it
+    // has a value under some condition.
+
+    // [1] We infer that the entry has value "curr_val" under condition
+    // "index_eq". We set "next" to the node that is the remainder of the
+    // function to process.
+    if (ck == kind::ITE)
+    {
+      Trace("builtin-rewrite-debug2")
+          << "  process condition : " << curr[0] << std::endl;
+      index_eq = curr[0];
+      curr_val = curr[1];
+      next = curr[2];
+    }
+    else if (ck == kind::OR || ck == kind::AND)
+    {
+      Trace("builtin-rewrite-debug2")
+          << "  process base : " << curr << std::endl;
+      // curr = Rewriter::rewrite(curr);
+      // Trace("builtin-rewrite-debug2")
+      //     << "  rewriten base : " << curr << std::endl;
+      // Complex Boolean return cases, in which
+      //  (1) lambda x. (= x v1) v ... becomes
+      //      lambda x. (ite (= x v1) true [...])
+      //
+      //  (2) lambda x. (not (= x v1)) ^ ... becomes
+      //      lambda x. (ite (= x v1) false [...])
+      //
+      // Note the negated cases of the lhs of the OR/AND operators above are
+      // handled by pushing the recursion to the then-branch, with the
+      // else-branch being the constant value. For example, the negated (1)
+      // would be
+      //  (1') lambda x. (not (= x v1)) v ... becomes
+      //       lambda x. (ite (= x v1) [...] true)
+      // thus requiring the rest of the disjunction to be further processed in
+      // the then-branch as the current value.
+      bool pol = curr[0].getKind() != kind::NOT;
+      bool inverted = (pol && ck == kind::AND) || (!pol && ck == kind::OR);
+      index_eq = pol ? curr[0] : curr[0][0];
+      // processed : the value that is determined by the first child of curr
+      // remainder : the remaining children of curr
+      Node processed, remainder;
+      // the value is the polarity of the first child or its inverse if we are
+      // in the inverted case
+      processed = nm->mkConst(!inverted ? pol : !pol);
+      // build an OR/AND with the remaining components
+      if (curr.getNumChildren() == 2)
+      {
+        remainder = curr[1];
+      }
+      else
+      {
+        std::vector<Node> remainderNodes{curr.begin() + 1, curr.end()};
+        remainder = nm->mkNode(ck, remainderNodes);
+      }
+      if (inverted)
+      {
+        curr_val = remainder;
+        next = processed;
+        // If the lambda contains more variables than the one being currently
+        // processed, the current value can be non-constant, since it'll be
+        // processed recursively below. Otherwise we fail.
+        if (rec_bvl.isNull() && !curr_val.isConst())
+        {
+          Trace("builtin-rewrite-debug2")
+              << "...non-const curr_val " << curr_val << "\n";
+          return Node::null();
+        }
+      }
+      else
+      {
+        curr_val = processed;
+        next = remainder;
+      }
+      Trace("builtin-rewrite-debug2") << "  index_eq : " << index_eq << "\n";
+      Trace("builtin-rewrite-debug2") << "  curr_val : " << curr_val << "\n";
+      Trace("builtin-rewrite-debug2") << "  next : " << next << std::endl;
+    }
+    else
+    {
+      Trace("builtin-rewrite-debug2")
+          << "  process base : " << curr << std::endl;
+      // Simple Boolean return cases, in which
+      //  (1) lambda x. (= x v) becomes lambda x. (ite (= x v) true false)
+      //  (2) lambda x. v becomes lambda x. (ite (= x v) true false)
+      // Note the negateg cases of the bodies above are also handled.
+      bool pol = ck != kind::NOT;
+      index_eq = pol ? curr : curr[0];
+      curr_val = nm->mkConst(pol);
+      next = nm->mkConst(!pol);
+    }
+
+    // [2] We ensure that "index_eq" is an equality, if possible.
+    if (index_eq.getKind() != kind::EQUAL)
+    {
+      bool pol = index_eq.getKind() != kind::NOT;
+      Node indexEqAtom = pol ? index_eq : index_eq[0];
+      if (indexEqAtom.getKind() == kind::BOUND_VARIABLE)
+      {
+        if (!indexEqAtom.getType().isBoolean())
+        {
+          // Catches default case of non-Boolean variable, e.g.
+          // lambda x : Int. x. In this case, it is not canonical and we fail.
+          Trace("builtin-rewrite-debug2")
+              << "  ...non-Boolean variable." << std::endl;
+          return Node::null();
+        }
+        // Boolean argument case, e.g. lambda x. ite( x, t, s ) is processed as
+        // lambda x. (ite (= x true) t s)
+        index_eq = indexEqAtom.eqNode(nm->mkConst(pol));
+      }
+      else
+      {
+        // non-equality condition
+        Trace("builtin-rewrite-debug2")
+            << "  ...non-equality condition." << std::endl;
+        return Node::null();
+      }
+    }
+    else if (Rewriter::rewrite(index_eq) != index_eq)
+    {
+      // equality must be oriented correctly based on rewriter
+      Trace("builtin-rewrite-debug2")
+          << "  ...equality not oriented properly." << std::endl;
+      return Node::null();
+    }
+
+    // [3] We ensure that "index_eq" is an equality that is equivalent to
+    // "first_arg" = "curr_index", where curr_index is a constant, and
+    // "first_arg" is the current argument we are processing, if possible.
+    Node curr_index;
+    for (unsigned r = 0; r < 2; r++)
+    {
+      Node arg = index_eq[r];
+      Node val = index_eq[1 - r];
+      if (arg == first_arg)
+      {
+        if (!val.isConst())
+        {
+          // non-constant value
+          Trace("builtin-rewrite-debug2")
+              << "  ...non-constant value for argument\n.";
+          return Node::null();
+        }
+        else
+        {
+          curr_index = val;
+          Trace("builtin-rewrite-debug2")
+              << "  arg " << arg << " -> " << val << std::endl;
+          break;
+        }
+      }
+    }
+    if (curr_index.isNull())
+    {
+      Trace("builtin-rewrite-debug2")
+          << "  ...could not infer index value." << std::endl;
+      return Node::null();
+    }
+
+    // [4] Recurse to ensure that "curr_val" has been normalized w.r.t. the
+    // remaining arguments (rec_bvl).
+    if (!rec_bvl.isNull())
+    {
+      curr_val = nm->mkNode(kind::LAMBDA, rec_bvl, curr_val);
+      Trace("builtin-rewrite-debug") << push;
+      Trace("builtin-rewrite-debug2") << push;
+      curr_val = getArrayRepresentationForLambdaRec(curr_val, retType);
+      Trace("builtin-rewrite-debug") << pop;
+      Trace("builtin-rewrite-debug2") << pop;
+      if (curr_val.isNull())
+      {
+        Trace("builtin-rewrite-debug2")
+            << "  ...failed to recursively find value." << std::endl;
+        return Node::null();
+      }
+    }
+    Trace("builtin-rewrite-debug2")
+        << "  ...condition is index " << curr_val << std::endl;
+
+    // [5] Add the entry
+    conds.push_back(curr_index);
+    vals.push_back(curr_val);
+
+    // we will now process the remainder
+    curr = next;
+    ck = curr.getKind();
+    Trace("builtin-rewrite-debug2")
+        << "  process remainder : " << curr << std::endl;
+  }
+  if (!rec_bvl.isNull())
+  {
+    curr = nm->mkNode(kind::LAMBDA, rec_bvl, curr);
+    Trace("builtin-rewrite-debug") << push;
+    Trace("builtin-rewrite-debug2") << push;
+    curr = getArrayRepresentationForLambdaRec(curr, retType);
+    Trace("builtin-rewrite-debug") << pop;
+    Trace("builtin-rewrite-debug2") << pop;
+  }
+  if (!curr.isNull() && curr.isConst())
+  {
+    // compute the return type
+    TypeNode array_type = retType;
+    for (size_t i = 0; i < size; i++)
+    {
+      size_t index = (size - 1) - i;
+      array_type = nm->mkArrayType(n[0][index].getType(), array_type);
+    }
+    Trace("builtin-rewrite-debug2")
+        << "  make array store all " << curr.getType()
+        << " annotated : " << array_type << std::endl;
+    Assert(curr.getType().isSubtypeOf(array_type.getArrayConstituentType()));
+    curr = nm->mkConst(ArrayStoreAll(array_type, curr));
+    Trace("builtin-rewrite-debug2") << "  build array..." << std::endl;
+    // can only build if default value is constant (since array store all must
+    // be constant)
+    Trace("builtin-rewrite-debug2")
+        << "  got constant base " << curr << std::endl;
+    Trace("builtin-rewrite-debug2") << "  conditions " << conds << std::endl;
+    Trace("builtin-rewrite-debug2") << "  values " << vals << std::endl;
+    // construct store chain
+    for (size_t i = 0, numCond = conds.size(); i < numCond; i++)
+    {
+      size_t ii = (numCond - 1) - i;
+      Assert(conds[ii].getType().isSubtypeOf(first_arg.getType()));
+      curr = nm->mkNode(kind::STORE, curr, conds[ii], vals[ii]);
+    }
+    Trace("builtin-rewrite-debug")
+        << "...got array " << curr << " for " << n << std::endl;
+    return curr;
+  }
+  Trace("builtin-rewrite-debug")
+      << "...failed to get array (cannot get constant default value)"
+      << std::endl;
+  return Node::null();
+}
+
+Node FunctionConst::getArrayRepresentationForLambda(TNode n)
+{
+  Assert(n.getKind() == kind::LAMBDA);
+  // must carry the overall return type to deal with cases like (lambda ((x Int)
+  // (y Int)) (ite (= x _) 0.5 0.0)), where the inner construction for the else
+  // case above should be (arraystoreall (Array Int Real) 0.0)
+  Node anode = getArrayRepresentationForLambdaRec(n, n[1].getType());
+  if (anode.isNull())
+  {
+    return anode;
+  }
+  // must rewrite it to make canonical
+  return Rewriter::rewrite(anode);
+}
+
+}  // namespace uf
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/uf/function_const.h b/src/theory/uf/function_const.h
new file mode 100644 (file)
index 0000000..10d1bf8
--- /dev/null
@@ -0,0 +1,110 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Utilities for function constants
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__UF__FUNCTION_CONST_H
+#define CVC5__THEORY__UF__FUNCTION_CONST_H
+
+#include <unordered_map>
+
+#include "expr/node.h"
+
+namespace cvc5 {
+namespace theory {
+namespace uf {
+
+/** Conversion between lambda and array constants */
+class FunctionConst
+{
+ public:
+  /** Get function type for array type
+   *
+   * This returns the function type of terms returned by the function
+   * getLambdaForArrayRepresentation( t, bvl ),
+   * where t.getType()=atn.
+   *
+   * bvl should be a bound variable list whose variables correspond in-order
+   * to the index types of the (curried) Array type. For example, a bound
+   * variable list bvl whose variables have types (Int, Real) can be given as
+   * input when paired with atn = (Array Int (Array Real Bool)), or (Array Int
+   * (Array Real (Array Bool Bool))). This function returns (-> Int Real Bool)
+   * and (-> Int Real (Array Bool Bool)) respectively in these cases.
+   * On the other hand, the above bvl is not a proper input for
+   * atn = (Array Int (Array Bool Bool)) or (Array Int Int).
+   * If the types of bvl and atn do not match, we throw an assertion failure.
+   */
+  static TypeNode getFunctionTypeForArrayType(TypeNode atn, Node bvl);
+  /** Get array type for function type
+   *
+   * This returns the array type of terms returned by
+   * getArrayRepresentationForLambda( t ), where t.getType()=ftn.
+   */
+  static TypeNode getArrayTypeForFunctionType(TypeNode ftn);
+  /**
+   * Given an array constant a, returns a lambda expression that it corresponds
+   * to, with bound variable list bvl.
+   * Examples:
+   *
+   * (store (storeall (Array Int Int) 2) 0 1)
+   * becomes
+   * ((lambda x. (ite (= x 0) 1 2))
+   *
+   * (store (storeall (Array Int (Array Int Int)) (storeall (Array Int Int) 4))
+   * 0 (store (storeall (Array Int Int) 3) 1 2)) becomes (lambda xy. (ite (= x
+   * 0) (ite (= x 1) 2 3) 4))
+   *
+   * (store (store (storeall (Array Int Bool) false) 2 true) 1 true)
+   * becomes
+   * (lambda x. (ite (= x 1) true (ite (= x 2) true false)))
+   *
+   * Notice that the return body of the lambda is rewritten to ensure that the
+   * representation is canonical. Hence the last
+   * example will in fact be returned as:
+   * (lambda x. (ite (= x 1) true (= x 2)))
+   */
+  static Node getLambdaForArrayRepresentation(TNode a, TNode bvl);
+  /**
+   * Given a lambda expression n, returns an array term that corresponds to n.
+   * This does the opposite direction of the examples described above.
+   *
+   * We limit the return values of this method to be almost constant functions,
+   * that is, arrays of the form:
+   *   (store ... (store (storeall _ b) i1 e1) ... in en)
+   * where b, i1, e1, ..., in, en are constants.
+   * Notice however that the return value of this form need not be a (canonical)
+   * array constant.
+   *
+   * If it is not possible to construct an array of this form that corresponds
+   * to n, this method returns null.
+   */
+  static Node getArrayRepresentationForLambda(TNode n);
+
+ private:
+  /** recursive helper for getLambdaForArrayRepresentation */
+  static Node getLambdaForArrayRepresentationRec(
+      TNode a,
+      TNode bvl,
+      unsigned bvlIndex,
+      std::unordered_map<TNode, Node>& visited);
+  /** recursive helper for getArrayRepresentationForLambda */
+  static Node getArrayRepresentationForLambdaRec(TNode n, TypeNode retType);
+};
+
+}  // namespace uf
+}  // namespace theory
+}  // namespace cvc5
+
+#endif /* CVC5__THEORY__UF__FUNCTION_CONST_H */
index 0faa5c67203bd02aedde9033d659bc43e8e14989..a1db5120ff45dd4ba0cb012af4cbae773c5fee7e 100644 (file)
@@ -15,10 +15,28 @@ parameterized APPLY_UF VARIABLE 1: "application of an uninterpreted function; fi
 
 typerule APPLY_UF ::cvc5::theory::uf::UfTypeRule
 
+operator FUNCTION_TYPE 2: "a function type"
+cardinality FUNCTION_TYPE \
+    "::cvc5::theory::uf::FunctionProperties::computeCardinality(%TYPE%)" \
+    "theory/uf/theory_uf_type_rules.h"
+well-founded FUNCTION_TYPE \
+    "::cvc5::theory::uf::FunctionProperties::isWellFounded(%TYPE%)" \
+    "::cvc5::theory::uf::FunctionProperties::mkGroundTerm(%TYPE%)" \
+    "theory/uf/theory_uf_type_rules.h"
+enumerator FUNCTION_TYPE \
+    ::cvc5::theory::uf::FunctionEnumerator \
+    "theory/uf/type_enumerator.h"
+
+operator LAMBDA 2 "a lambda expression; first parameter is a BOUND_VAR_LIST, second is lambda body"
+
+typerule LAMBDA ::cvc5::theory::uf::LambdaTypeRule
+
 variable BOOLEAN_TERM_VARIABLE "Boolean term variable"
 
-parameterized PARTIAL_APPLY_UF APPLY_UF 1: "partial uninterpreted function application"
-typerule PARTIAL_APPLY_UF ::cvc5::theory::uf::PartialTypeRule
+variable LAMBDA_VARIABLE "Lambda variable, used for lazy lambda lifting"
+
+# lambda expressions that are isomorphic to array constants can be considered constants
+construle LAMBDA ::cvc5::theory::uf::LambdaTypeRule
 
 operator HO_APPLY 2 "higher-order (partial) function application"
 typerule HO_APPLY ::cvc5::theory::uf::HoApplyTypeRule
index f4bedb4b840cecce17844a9ef7780815e6e0e3ba..ba00c316fed14ab058ddf82996329a29d6b5bd55 100644 (file)
@@ -18,6 +18,7 @@
 #include "expr/node_algorithm.h"
 #include "theory/rewriter.h"
 #include "theory/substitutions.h"
+#include "theory/uf/function_const.h"
 
 namespace cvc5 {
 namespace theory {
@@ -139,6 +140,11 @@ RewriteResponse TheoryUfRewriter::postRewrite(TNode node)
       return RewriteResponse(REWRITE_AGAIN_FULL, new_body);
     }
   }
+  else if (node.getKind() == kind::LAMBDA)
+  {
+    Node ret = rewriteLambda(node);
+    return RewriteResponse(REWRITE_DONE, ret);
+  }
   return RewriteResponse(REWRITE_DONE, node);
 }
 
@@ -204,6 +210,56 @@ Node TheoryUfRewriter::decomposeHoApply(TNode n,
 }
 bool TheoryUfRewriter::canUseAsApplyUfOperator(TNode n) { return n.isVar(); }
 
+Node TheoryUfRewriter::rewriteLambda(Node node)
+{
+  Assert(node.getKind() == kind::LAMBDA);
+  // The following code ensures that if node is equivalent to a constant
+  // lambda, then we return the canonical representation for the lambda, which
+  // in turn ensures that two constant lambdas are equivalent if and only
+  // if they are the same node.
+  // We canonicalize lambdas by turning them into array constants, applying
+  // normalization on array constants, and then converting the array constant
+  // back to a lambda.
+  Trace("builtin-rewrite") << "Rewriting lambda " << node << "..." << std::endl;
+  Node anode = FunctionConst::getArrayRepresentationForLambda(node);
+  // Only rewrite constant array nodes, since these are the only cases
+  // where we require canonicalization of lambdas. Moreover, applying the
+  // below code is not correct if the arguments to the lambda occur
+  // in return values. For example, lambda x. ite( x=1, f(x), c ) would
+  // be converted to (store (storeall ... c) 1 f(x)), and then converted
+  // to lambda y. ite( y=1, f(x), c), losing the relation between x and y.
+  if (!anode.isNull() && anode.isConst())
+  {
+    Assert(anode.getType().isArray());
+    // must get the standard bound variable list
+    Node varList = NodeManager::currentNM()->getBoundVarListForFunctionType(
+        node.getType());
+    Node retNode =
+        FunctionConst::getLambdaForArrayRepresentation(anode, varList);
+    if (!retNode.isNull() && retNode != node)
+    {
+      Trace("builtin-rewrite") << "Rewrote lambda : " << std::endl;
+      Trace("builtin-rewrite") << "     input  : " << node << std::endl;
+      Trace("builtin-rewrite")
+          << "     output : " << retNode << ", constant = " << retNode.isConst()
+          << std::endl;
+      Trace("builtin-rewrite")
+          << "  array rep : " << anode << ", constant = " << anode.isConst()
+          << std::endl;
+      Assert(anode.isConst() == retNode.isConst());
+      Assert(retNode.getType() == node.getType());
+      Assert(expr::hasFreeVar(node) == expr::hasFreeVar(retNode));
+      return retNode;
+    }
+  }
+  else
+  {
+    Trace("builtin-rewrite-debug")
+        << "...failed to get array representation." << std::endl;
+  }
+  return node;
+}
+
 }  // namespace uf
 }  // namespace theory
 }  // namespace cvc5
index dfa797f71d1eafea2a259c7d05362b2bac8bf4fa..31a6f4669b0b548a3cda9e3ecbfda6777d4ad3a2 100644 (file)
@@ -35,11 +35,11 @@ class TheoryUfRewriter : public TheoryRewriter
 {
  public:
   TheoryUfRewriter(bool isHigherOrder = false);
+  /** post-rewrite */
   RewriteResponse postRewrite(TNode node) override;
-
+  /** pre-rewrite */
   RewriteResponse preRewrite(TNode node) override;
-
- public:  // conversion between HO_APPLY AND APPLY_UF
+  // conversion between HO_APPLY AND APPLY_UF
   // converts an APPLY_UF to a curried HO_APPLY e.g. (f a b) becomes (@ (@ f a)
   // b)
   static Node getHoApplyForApplyUf(TNode n);
@@ -62,6 +62,10 @@ class TheoryUfRewriter : public TheoryRewriter
    * Then, f and g can be used as APPLY_UF operators, but (ite C f g), (lambda x1. (f x1)) as well as the variable x above are not.
    */
   static bool canUseAsApplyUfOperator(TNode n);
+
+ private:
+  /** Entry point for rewriting lambdas */
+  static Node rewriteLambda(Node node);
   /** Is the logic higher-order? */
   bool d_isHigherOrder;
 }; /* class TheoryUfRewriter */
index 5b132fc2772bf221929eea5a4d69d2d17ab09673..a05c76d4c37aa847fbc0e9e88e44b2761ba82c58 100644 (file)
@@ -19,6 +19,8 @@
 #include <sstream>
 
 #include "expr/cardinality_constraint.h"
+#include "theory/uf/function_const.h"
+#include "util/cardinality.h"
 #include "util/rational.h"
 
 namespace cvc5 {
@@ -160,6 +162,112 @@ TypeNode HoApplyTypeRule::computeType(NodeManager* nodeManager,
   }
 }
 
+TypeNode LambdaTypeRule::computeType(NodeManager* nodeManager,
+                                     TNode n,
+                                     bool check)
+{
+  if (n[0].getType(check) != nodeManager->boundVarListType())
+  {
+    std::stringstream ss;
+    ss << "expected a bound var list for LAMBDA expression, got `"
+       << n[0].getType().toString() << "'";
+    throw TypeCheckingExceptionPrivate(n, ss.str());
+  }
+  std::vector<TypeNode> argTypes;
+  for (TNode::iterator i = n[0].begin(); i != n[0].end(); ++i)
+  {
+    argTypes.push_back((*i).getType());
+  }
+  TypeNode rangeType = n[1].getType(check);
+  return nodeManager->mkFunctionType(argTypes, rangeType);
+}
+
+bool LambdaTypeRule::computeIsConst(NodeManager* nodeManager, TNode n)
+{
+  Assert(n.getKind() == kind::LAMBDA);
+  // get array representation of this function, if possible
+  Node na = FunctionConst::getArrayRepresentationForLambda(n);
+  if (!na.isNull())
+  {
+    Assert(na.getType().isArray());
+    Trace("lambda-const") << "Array representation for " << n << " is " << na
+                          << " " << na.getType() << std::endl;
+    // must have the standard bound variable list
+    Node bvl =
+        NodeManager::currentNM()->getBoundVarListForFunctionType(n.getType());
+    if (bvl == n[0])
+    {
+      // array must be constant
+      if (na.isConst())
+      {
+        Trace("lambda-const") << "*** Constant lambda : " << n;
+        Trace("lambda-const") << " since its array representation : " << na
+                              << " is constant." << std::endl;
+        return true;
+      }
+      else
+      {
+        Trace("lambda-const") << "Non-constant lambda : " << n
+                              << " since array is not constant." << std::endl;
+      }
+    }
+    else
+    {
+      Trace("lambda-const")
+          << "Non-constant lambda : " << n
+          << " since its varlist is not standard." << std::endl;
+      Trace("lambda-const") << "  standard : " << bvl << std::endl;
+      Trace("lambda-const") << "   current : " << n[0] << std::endl;
+    }
+  }
+  else
+  {
+    Trace("lambda-const") << "Non-constant lambda : " << n
+                          << " since it has no array representation."
+                          << std::endl;
+  }
+  return false;
+}
+
+Cardinality FunctionProperties::computeCardinality(TypeNode type)
+{
+  // Don't assert this; allow other theories to use this cardinality
+  // computation.
+  //
+  // Assert(type.getKind() == kind::FUNCTION_TYPE);
+
+  Cardinality argsCard(1);
+  // get the largest cardinality of function arguments/return type
+  for (size_t i = 0, i_end = type.getNumChildren() - 1; i < i_end; ++i)
+  {
+    argsCard *= type[i].getCardinality();
+  }
+
+  Cardinality valueCard = type[type.getNumChildren() - 1].getCardinality();
+
+  return valueCard ^ argsCard;
+}
+
+bool FunctionProperties::isWellFounded(TypeNode type)
+{
+  for (TypeNode::iterator i = type.begin(), i_end = type.end(); i != i_end; ++i)
+  {
+    if (!(*i).isWellFounded())
+    {
+      return false;
+    }
+  }
+  return true;
+}
+
+Node FunctionProperties::mkGroundTerm(TypeNode type)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Node bvl = nm->getBoundVarListForFunctionType(type);
+  Node ret = type.getRangeType().mkGroundTerm();
+  return nm->mkNode(kind::LAMBDA, bvl, ret);
+}
+
 }  // namespace uf
 }  // namespace theory
 }  // namespace cvc5
index b9451a50094bfc4dc4044f9773beaf9748b36b36..6f0374ae63f8c1b77ee1595381b3e489ed811bfc 100644 (file)
@@ -69,6 +69,30 @@ class HoApplyTypeRule
   static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
 };
 
+class LambdaTypeRule
+{
+ public:
+  static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check);
+  // computes whether a lambda is a constant value, via conversion to array
+  // representation
+  static bool computeIsConst(NodeManager* nodeManager, TNode n);
+}; /* class LambdaTypeRule */
+
+class FunctionProperties
+{
+ public:
+  static Cardinality computeCardinality(TypeNode type);
+
+  /** Function type is well-founded if its component sorts are */
+  static bool isWellFounded(TypeNode type);
+  /**
+   * Ground term for function sorts is (lambda x. t) where x is the
+   * canonical variable list for its type and t is the canonical ground term of
+   * its range.
+   */
+  static Node mkGroundTerm(TypeNode type);
+}; /* class FuctionProperties */
+
 }  // namespace uf
 }  // namespace theory
 }  // namespace cvc5
diff --git a/src/theory/uf/type_enumerator.cpp b/src/theory/uf/type_enumerator.cpp
new file mode 100644 (file)
index 0000000..a7f1f3e
--- /dev/null
@@ -0,0 +1,51 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Enumerator for functions.
+ */
+
+#include "theory/uf/type_enumerator.h"
+
+#include "theory/uf/function_const.h"
+
+namespace cvc5 {
+namespace theory {
+namespace uf {
+
+FunctionEnumerator::FunctionEnumerator(TypeNode type,
+                                       TypeEnumeratorProperties* tep)
+    : TypeEnumeratorBase<FunctionEnumerator>(type),
+      d_arrayEnum(FunctionConst::getArrayTypeForFunctionType(type), tep)
+{
+  Assert(type.getKind() == kind::FUNCTION_TYPE);
+  d_bvl = NodeManager::currentNM()->getBoundVarListForFunctionType(type);
+}
+
+Node FunctionEnumerator::operator*()
+{
+  if (isFinished())
+  {
+    throw NoMoreValuesException(getType());
+  }
+  Node a = *d_arrayEnum;
+  return FunctionConst::getLambdaForArrayRepresentation(a, d_bvl);
+}
+
+FunctionEnumerator& FunctionEnumerator::operator++()
+{
+  ++d_arrayEnum;
+  return *this;
+}
+
+}  // namespace uf
+}  // namespace theory
+}  // namespace cvc5
diff --git a/src/theory/uf/type_enumerator.h b/src/theory/uf/type_enumerator.h
new file mode 100644 (file)
index 0000000..dfbbc19
--- /dev/null
@@ -0,0 +1,59 @@
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * Enumerator for functions.
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__UF__TYPE_ENUMERATOR_H
+#define CVC5__THEORY__UF__TYPE_ENUMERATOR_H
+
+#include "expr/kind.h"
+#include "expr/type_node.h"
+#include "theory/type_enumerator.h"
+#include "util/integer.h"
+
+namespace cvc5 {
+namespace theory {
+namespace uf {
+
+/** FunctionEnumerator
+ * This enumerates function values, based on the enumerator for the
+ * array type corresponding to the given function type.
+ */
+class FunctionEnumerator : public TypeEnumeratorBase<FunctionEnumerator>
+{
+ public:
+  FunctionEnumerator(TypeNode type, TypeEnumeratorProperties* tep = nullptr);
+  /** Get the current term of the enumerator. */
+  Node operator*() override;
+  /** Increment the enumerator. */
+  FunctionEnumerator& operator++() override;
+  /** is the enumerator finished? */
+  bool isFinished() override { return d_arrayEnum.isFinished(); }
+
+ private:
+  /** Enumerates arrays, which we convert to functions. */
+  TypeEnumerator d_arrayEnum;
+  /** The bound variable list for the function type we are enumerating.
+   * All terms output by this enumerator are of the form (LAMBDA d_bvl t) for
+   * some term t.
+   */
+  Node d_bvl;
+}; /* class FunctionEnumerator */
+
+}  // namespace uf
+}  // namespace theory
+}  // namespace cvc5
+
+#endif /* CVC5__THEORY__UF__TYPE_ENUMERATOR_H */