From cb9539e1b150593e22128e1dda7d692e6444ec4b Mon Sep 17 00:00:00 2001 From: Andrew Reynolds Date: Thu, 5 Aug 2021 15:41:52 -0500 Subject: [PATCH] Generalize term canonizer for type classes (#6895) Initial work towards rewrite rule reconstruction. --- src/expr/term_canonize.cpp | 43 ++++++++++++++++++++++++------------- src/expr/term_canonize.h | 44 ++++++++++++++++++++++++++++++-------- 2 files changed, 63 insertions(+), 24 deletions(-) diff --git a/src/expr/term_canonize.cpp b/src/expr/term_canonize.cpp index 0480a046a..0e05a52ea 100644 --- a/src/expr/term_canonize.cpp +++ b/src/expr/term_canonize.cpp @@ -25,7 +25,10 @@ using namespace cvc5::kind; namespace cvc5 { namespace expr { -TermCanonize::TermCanonize() : d_op_id_count(0), d_typ_id_count(0) {} +TermCanonize::TermCanonize(TypeClassCallback* tcc) + : d_tcc(tcc), d_op_id_count(0), d_typ_id_count(0) +{ +} int TermCanonize::getIdForOperator(Node op) { @@ -93,11 +96,13 @@ bool TermCanonize::getTermOrder(Node a, Node b) return false; } -Node TermCanonize::getCanonicalFreeVar(TypeNode tn, unsigned i) +Node TermCanonize::getCanonicalFreeVar(TypeNode tn, unsigned i, uint32_t tc) { Assert(!tn.isNull()); NodeManager* nm = NodeManager::currentNM(); - while (d_cn_free_var[tn].size() <= i) + std::pair key(tn, tc); + std::vector& tvars = d_cn_free_var[key]; + while (tvars.size() <= i) { std::stringstream oss; oss << tn; @@ -109,10 +114,15 @@ Node TermCanonize::getCanonicalFreeVar(TypeNode tn, unsigned i) std::stringstream os; os << typ_name[0] << i; Node x = nm->mkBoundVar(os.str().c_str(), tn); - d_fvIndex[x] = d_cn_free_var[tn].size(); - d_cn_free_var[tn].push_back(x); + d_fvIndex[x] = tvars.size(); + tvars.push_back(x); } - return d_cn_free_var[tn][i]; + return tvars[i]; +} + +uint32_t TermCanonize::getTypeClass(TNode v) +{ + return d_tcc == nullptr ? 0 : d_tcc->getTypeClass(v); } size_t TermCanonize::getIndexForFreeVariable(Node v) const @@ -131,11 +141,12 @@ struct sortTermOrder bool operator()(Node i, Node j) { return d_tu->getTermOrder(i, j); } }; -Node TermCanonize::getCanonicalTerm(TNode n, - bool apply_torder, - bool doHoVar, - std::map& var_count, - std::map& visited) +Node TermCanonize::getCanonicalTerm( + TNode n, + bool apply_torder, + bool doHoVar, + std::map, unsigned>& var_count, + std::map& visited) { std::map::iterator it = visited.find(n); if (it != visited.end()) @@ -146,11 +157,13 @@ Node TermCanonize::getCanonicalTerm(TNode n, Trace("canon-term-debug") << "Get canonical term for " << n << std::endl; if (n.getKind() == BOUND_VARIABLE) { + uint32_t tc = getTypeClass(n); TypeNode tn = n.getType(); + std::pair key(tn, tc); // allocate variable - unsigned vn = var_count[tn]; - var_count[tn]++; - Node fv = getCanonicalFreeVar(tn, vn); + unsigned vn = var_count[key]; + var_count[key]++; + Node fv = getCanonicalFreeVar(tn, vn, tc); visited[n] = fv; Trace("canon-term-debug") << "...allocate variable." << std::endl; return fv; @@ -204,7 +217,7 @@ Node TermCanonize::getCanonicalTerm(TNode n, Node TermCanonize::getCanonicalTerm(TNode n, bool apply_torder, bool doHoVar) { - std::map var_count; + std::map, unsigned> var_count; std::map visited; return getCanonicalTerm(n, apply_torder, doHoVar, var_count, visited); } diff --git a/src/expr/term_canonize.h b/src/expr/term_canonize.h index 6b9cf9cb9..e1524cbf6 100644 --- a/src/expr/term_canonize.h +++ b/src/expr/term_canonize.h @@ -24,6 +24,21 @@ namespace cvc5 { namespace expr { +/** + * Generalization of types. This class is a simple callback for giving + * identifiers to variables that may be a more fine-grained way of classifying + * the variable than its type. An example usage of type classes are for + * distinguishing "list variables" for rewrite rule reconstruction. + */ +class TypeClassCallback +{ + public: + TypeClassCallback() {} + virtual ~TypeClassCallback() {} + /** Return the type class for variable v */ + virtual uint32_t getTypeClass(TNode v) = 0; +}; + /** TermCanonize * * This class contains utilities for canonizing terms with respect to @@ -34,7 +49,13 @@ namespace expr { class TermCanonize { public: - TermCanonize(); + /** + * @param tcc The type class callback. This class will canonize variables in + * a way that disinguishes variables that are given different type class + * identifiers. Otherwise, this class will assume all variables of the + * same type have the same type class. + */ + TermCanonize(TypeClassCallback* tcc = nullptr); ~TermCanonize() {} /** Maps operators to an identifier, useful for ordering. */ @@ -50,7 +71,7 @@ class TermCanonize */ bool getTermOrder(Node a, Node b); /** get canonical free variable #i of type tn */ - Node getCanonicalFreeVar(TypeNode tn, unsigned i); + Node getCanonicalFreeVar(TypeNode tn, unsigned i, uint32_t tc = 0); /** get canonical term * * This returns a canonical (alpha-equivalent) version of n, where @@ -68,6 +89,8 @@ class TermCanonize bool doHoVar = true); private: + /** The (optional) type class callback */ + TypeClassCallback* d_tcc; /** the number of ids we have allocated for operators */ int d_op_id_count; /** map from operators to id */ @@ -76,12 +99,14 @@ class TermCanonize int d_typ_id_count; /** map from type to id */ std::map d_typ_id; - /** free variables for each type */ - std::map > d_cn_free_var; + /** free variables for each type / type class pair */ + std::map, std::vector > d_cn_free_var; /** * Map from each free variable above to their index in their respective vector */ std::map d_fvIndex; + /** Get type class */ + uint32_t getTypeClass(TNode v); /** * Return the range of the free variable in the above map, or 0 if it does not * exist. @@ -93,11 +118,12 @@ class TermCanonize * counter of how many variables we have allocated for each type (var_count), * and a cache of visited nodes (visited). */ - Node getCanonicalTerm(TNode n, - bool apply_torder, - bool doHoVar, - std::map& var_count, - std::map& visited); + Node getCanonicalTerm( + TNode n, + bool apply_torder, + bool doHoVar, + std::map, unsigned>& var_count, + std::map& visited); }; } // namespace expr -- 2.30.2