Generalize term canonizer for type classes (#6895)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 5 Aug 2021 20:41:52 +0000 (15:41 -0500)
committerGitHub <noreply@github.com>
Thu, 5 Aug 2021 20:41:52 +0000 (20:41 +0000)
Initial work towards rewrite rule reconstruction.

src/expr/term_canonize.cpp
src/expr/term_canonize.h

index 0480a046a084e89f6bfcf53f465ce3f02a6b394a..0e05a52eabae6d429c03c8049d393908d3ff9236 100644 (file)
@@ -25,7 +25,10 @@ using namespace cvc5::kind;
 namespace cvc5 {
 namespace expr {
 
-TermCanonize::TermCanonize() : d_op_id_count(0), d_typ_id_count(0) {}
+TermCanonize::TermCanonize(TypeClassCallback* tcc)
+    : d_tcc(tcc), d_op_id_count(0), d_typ_id_count(0)
+{
+}
 
 int TermCanonize::getIdForOperator(Node op)
 {
@@ -93,11 +96,13 @@ bool TermCanonize::getTermOrder(Node a, Node b)
   return false;
 }
 
-Node TermCanonize::getCanonicalFreeVar(TypeNode tn, unsigned i)
+Node TermCanonize::getCanonicalFreeVar(TypeNode tn, unsigned i, uint32_t tc)
 {
   Assert(!tn.isNull());
   NodeManager* nm = NodeManager::currentNM();
-  while (d_cn_free_var[tn].size() <= i)
+  std::pair<TypeNode, uint32_t> key(tn, tc);
+  std::vector<Node>& tvars = d_cn_free_var[key];
+  while (tvars.size() <= i)
   {
     std::stringstream oss;
     oss << tn;
@@ -109,10 +114,15 @@ Node TermCanonize::getCanonicalFreeVar(TypeNode tn, unsigned i)
     std::stringstream os;
     os << typ_name[0] << i;
     Node x = nm->mkBoundVar(os.str().c_str(), tn);
-    d_fvIndex[x] = d_cn_free_var[tn].size();
-    d_cn_free_var[tn].push_back(x);
+    d_fvIndex[x] = tvars.size();
+    tvars.push_back(x);
   }
-  return d_cn_free_var[tn][i];
+  return tvars[i];
+}
+
+uint32_t TermCanonize::getTypeClass(TNode v)
+{
+  return d_tcc == nullptr ? 0 : d_tcc->getTypeClass(v);
 }
 
 size_t TermCanonize::getIndexForFreeVariable(Node v) const
@@ -131,11 +141,12 @@ struct sortTermOrder
   bool operator()(Node i, Node j) { return d_tu->getTermOrder(i, j); }
 };
 
-Node TermCanonize::getCanonicalTerm(TNode n,
-                                    bool apply_torder,
-                                    bool doHoVar,
-                                    std::map<TypeNode, unsigned>& var_count,
-                                    std::map<TNode, Node>& visited)
+Node TermCanonize::getCanonicalTerm(
+    TNode n,
+    bool apply_torder,
+    bool doHoVar,
+    std::map<std::pair<TypeNode, uint32_t>, unsigned>& var_count,
+    std::map<TNode, Node>& visited)
 {
   std::map<TNode, Node>::iterator it = visited.find(n);
   if (it != visited.end())
@@ -146,11 +157,13 @@ Node TermCanonize::getCanonicalTerm(TNode n,
   Trace("canon-term-debug") << "Get canonical term for " << n << std::endl;
   if (n.getKind() == BOUND_VARIABLE)
   {
+    uint32_t tc = getTypeClass(n);
     TypeNode tn = n.getType();
+    std::pair<TypeNode, uint32_t> key(tn, tc);
     // allocate variable
-    unsigned vn = var_count[tn];
-    var_count[tn]++;
-    Node fv = getCanonicalFreeVar(tn, vn);
+    unsigned vn = var_count[key];
+    var_count[key]++;
+    Node fv = getCanonicalFreeVar(tn, vn, tc);
     visited[n] = fv;
     Trace("canon-term-debug") << "...allocate variable." << std::endl;
     return fv;
@@ -204,7 +217,7 @@ Node TermCanonize::getCanonicalTerm(TNode n,
 
 Node TermCanonize::getCanonicalTerm(TNode n, bool apply_torder, bool doHoVar)
 {
-  std::map<TypeNode, unsigned> var_count;
+  std::map<std::pair<TypeNode, uint32_t>, unsigned> var_count;
   std::map<TNode, Node> visited;
   return getCanonicalTerm(n, apply_torder, doHoVar, var_count, visited);
 }
index 6b9cf9cb91814b08d2018463ad66e16cbc408563..e1524cbf66e84ac95c20028ad724e9ec86032a6e 100644 (file)
 namespace cvc5 {
 namespace expr {
 
+/**
+ * Generalization of types. This class is a simple callback for giving
+ * identifiers to variables that may be a more fine-grained way of classifying
+ * the variable than its type. An example usage of type classes are for
+ * distinguishing "list variables" for rewrite rule reconstruction.
+ */
+class TypeClassCallback
+{
+ public:
+  TypeClassCallback() {}
+  virtual ~TypeClassCallback() {}
+  /** Return the type class for variable v */
+  virtual uint32_t getTypeClass(TNode v) = 0;
+};
+
 /** TermCanonize
  *
  * This class contains utilities for canonizing terms with respect to
@@ -34,7 +49,13 @@ namespace expr {
 class TermCanonize
 {
  public:
-  TermCanonize();
+  /**
+   * @param tcc The type class callback. This class will canonize variables in
+   * a way that disinguishes variables that are given different type class
+   * identifiers. Otherwise, this class will assume all variables of the
+   * same type have the same type class.
+   */
+  TermCanonize(TypeClassCallback* tcc = nullptr);
   ~TermCanonize() {}
 
   /** Maps operators to an identifier, useful for ordering. */
@@ -50,7 +71,7 @@ class TermCanonize
    */
   bool getTermOrder(Node a, Node b);
   /** get canonical free variable #i of type tn */
-  Node getCanonicalFreeVar(TypeNode tn, unsigned i);
+  Node getCanonicalFreeVar(TypeNode tn, unsigned i, uint32_t tc = 0);
   /** get canonical term
    *
    * This returns a canonical (alpha-equivalent) version of n, where
@@ -68,6 +89,8 @@ class TermCanonize
                         bool doHoVar = true);
 
  private:
+  /** The (optional) type class callback */
+  TypeClassCallback* d_tcc;
   /** the number of ids we have allocated for operators */
   int d_op_id_count;
   /** map from operators to id */
@@ -76,12 +99,14 @@ class TermCanonize
   int d_typ_id_count;
   /** map from type to id */
   std::map<TypeNode, int> d_typ_id;
-  /** free variables for each type */
-  std::map<TypeNode, std::vector<Node> > d_cn_free_var;
+  /** free variables for each type / type class pair */
+  std::map<std::pair<TypeNode, uint32_t>, std::vector<Node> > d_cn_free_var;
   /**
    * Map from each free variable above to their index in their respective vector
    */
   std::map<Node, size_t> d_fvIndex;
+  /** Get type class */
+  uint32_t getTypeClass(TNode v);
   /**
    * Return the range of the free variable in the above map, or 0 if it does not
    * exist.
@@ -93,11 +118,12 @@ class TermCanonize
    * counter of how many variables we have allocated for each type (var_count),
    * and a cache of visited nodes (visited).
    */
-  Node getCanonicalTerm(TNode n,
-                        bool apply_torder,
-                        bool doHoVar,
-                        std::map<TypeNode, unsigned>& var_count,
-                        std::map<TNode, Node>& visited);
+  Node getCanonicalTerm(
+      TNode n,
+      bool apply_torder,
+      bool doHoVar,
+      std::map<std::pair<TypeNode, uint32_t>, unsigned>& var_count,
+      std::map<TNode, Node>& visited);
 };
 
 }  // namespace expr