congruence closure module now supports things other than APPLY_UF; ported from "array...
authorMorgan Deters <mdeters@gmail.com>
Tue, 14 Dec 2010 21:07:46 +0000 (21:07 +0000)
committerMorgan Deters <mdeters@gmail.com>
Tue, 14 Dec 2010 21:07:46 +0000 (21:07 +0000)
src/expr/node_manager.cpp
src/expr/node_manager.h
src/expr/type_constant.h
src/printer/Makefile.am
src/theory/uf/morgan/theory_uf_morgan.h
src/util/congruence_closure.h
test/unit/util/congruence_closure_white.h

index 280c552544bf3a962b1d6e403da9916ca108c6e9..9006bf4d96934f16712e90a7a7e6c5ff684f9e7b 100644 (file)
@@ -228,6 +228,9 @@ TypeNode NodeManager::computeType(TNode n, bool check)
 
   // Infer the type
   switch(n.getKind()) {
+  case kind::BUILTIN:
+    typeNode = builtinOperatorType();
+    break;
   case kind::SORT_TYPE:
     typeNode = kindType();
     break;
index 206cf35d5ae17a8e821d2d332a96849eedcbc9cc..04de81b1c88872b66bbe594c3c9decc845bef815 100644 (file)
@@ -471,6 +471,12 @@ public:
   /** Get the (singleton) type for sorts. */
   inline TypeNode kindType();
 
+  /**
+   * Get the (singleton) type for builtin operators (that is, the type
+   * of the Node returned from Node::getOperator() when the operator
+   * is built-in, like EQUAL). */
+  inline TypeNode builtinOperatorType();
+
   /**
    * Make a function type from domain to range.
    *
@@ -680,6 +686,11 @@ inline TypeNode NodeManager::kindType() {
   return TypeNode(mkTypeConst<TypeConstant>(KIND_TYPE));
 }
 
+/** Get the (singleton) type for builtin operators. */
+inline TypeNode NodeManager::builtinOperatorType() {
+  return TypeNode(mkTypeConst<TypeConstant>(BUILTIN_OPERATOR_TYPE));
+}
+
 /** Make a function type from domain to range. */
 inline TypeNode NodeManager::mkFunctionType(const TypeNode& domain, const TypeNode& range) {
   std::vector<TypeNode> sorts;
index 3001d4513248ef3f313a0a1e337e2306421f5938..23c23cc9fb163fcef9edf19ecd1f48f12b97b766 100644 (file)
@@ -36,7 +36,9 @@ enum TypeConstant {
   /** The real type */
   REAL_TYPE,
   /** The kind type (type of types) */
-  KIND_TYPE
+  KIND_TYPE,
+  /** The builtin operator type (type of non-PARAMETERIZED operators) */
+  BUILTIN_OPERATOR_TYPE
 };/* enum TypeConstant */
 
 /**
index 8fd50d823739a66b9f4497ff79d5dbff88352d86..ca3cbacf139f70a7f12e52cacb6a89d9d6d17733 100644 (file)
@@ -16,6 +16,3 @@ libprinter_la_SOURCES = \
        smt2/smt2_printer.cpp \
        cvc/cvc_printer.h \
        cvc/cvc_printer.cpp
-
-libprinter_la_LIBADD = \
-       @builddir@/../lib/libreplacements.la
index 99e6f5fbce7024ffe9760fd7166d8e4a85ffc5a8..cbc5f1eabe290c773217e3b1ec588991e97583c5 100644 (file)
@@ -74,7 +74,7 @@ private:
   /**
    * Instance of the congruence closure module.
    */
-  CongruenceClosure<CongruenceChannel> d_cc;
+  CongruenceClosure<CongruenceChannel, CongruenceOperator<kind::APPLY_UF> > d_cc;
 
   /**
    * Our union find for equalities.
index 8a13e35873af76d440475184e1d51bc5946e4363..0968b39edc45f48ccec225ec96436ffae1ed5ddb 100644 (file)
 
 namespace CVC4 {
 
-template <class OutputChannel>
+template <class OutputChannel, class CongruenceOperatorList>
 class CongruenceClosure;
 
-template <class OutputChannel>
+template <class OutputChannel, class CongruenceOperatorList>
 std::ostream& operator<<(std::ostream& out,
-                         const CongruenceClosure<OutputChannel>& cc);
+                         const CongruenceClosure<OutputChannel, CongruenceOperatorList>& cc);
 
 /**
  * A CongruenceClosureException is thrown by
@@ -59,6 +59,36 @@ public:
     Exception(std::string("Congruence closure exception: ") + msg) {}
 };/* class CongruenceClosureException */
 
+struct EndOfCongruenceOpList;
+template <Kind kind_, class Tail_ = EndOfCongruenceOpList>
+struct CongruenceOperator {
+  enum { kind = kind_ };
+  typedef Tail_ Tail;
+};/* class CongruenceOperator<> */
+
+#define CONGRUENCE_OPERATORS_1(kind1) CongruenceOperator<kind1, EndOfCongruenceOpList>
+#define CONGRUENCE_OPERATORS_2(kind1, kind2) CongruenceOperator<kind1, CONGRUENCE_OPERATORS_1(kind2)>
+#define CONGRUENCE_OPERATORS_3(kind1, kind2, kind3) CongruenceOperator<kind1, CONGRUENCE_OPERATORS_2(kind2, kind3)>
+#define CONGRUENCE_OPERATORS_4(kind1, kind2, kind3, kind4) CongruenceOperator<kind1, CONGRUENCE_OPERATORS_3(kind2, kind3, kind4)>
+#define CONGRUENCE_OPERATORS_5(kind1, kind2, kind3, kind4, kind5) CongruenceOperator<kind1, CONGRUENCE_OPERATORS_4(kind2, kind3, kind4, kind5)>
+
+/**
+ * Returns true if the kind k is registered as a congruence operator
+ * for this CongruenceClosure.  (That is, if it's in the
+ * CongruenceOperatorList template parameter.)  False otherwise.
+ */
+template <class CongruenceOperatorList>
+inline bool isInCongruenceOperatorList(Kind k) {
+  typedef typename CongruenceOperatorList::Tail Tail;
+  return k == Kind(CongruenceOperatorList::kind) ||
+    isInCongruenceOperatorList<Tail>(k);
+}
+
+// specialization for empty list
+template <>
+inline bool isInCongruenceOperatorList<EndOfCongruenceOpList>(Kind k) {
+  return false;
+}
 
 /**
  * Congruence closure module for CVC4.
@@ -95,8 +125,12 @@ public:
  *       // interrupt you.
  *     }
  *   };
+ *
+ * CongruenceOperatorList is a typelist of congruence Kinds,
+ * e.g., CONGRUENCE_OPERATORS_1(kind::APPLY_UF)
+ * or CONGRUENCE_OPERATORS_2(kind::SELECT, kind::STORE)
  */
-template <class OutputChannel>
+template <class OutputChannel, class CongruenceOperatorList>
 class CongruenceClosure {
   /** The context */
   context::Context* d_context;
@@ -147,6 +181,10 @@ class CongruenceClosure {
   AverageStat d_explanationLength;/*! average explanation length */
   IntStat d_newSkolemVars;/*! new vars created */
 
+  static inline bool isCongruenceOperator(Kind k) {
+    return isInCongruenceOperatorList<CongruenceOperatorList>(k);
+  }
+
 public:
   /** Construct a congruence closure module instance */
   CongruenceClosure(context::Context* ctxt, OutputChannel* out)
@@ -187,8 +225,8 @@ public:
     Assert(inputEq.getKind() == kind::EQUAL ||
            inputEq.getKind() == kind::IFF);
     NodeBuilder<> eqb(inputEq.getKind());
-    if(inputEq[1].getKind() == kind::APPLY_UF &&
-       inputEq[0].getKind() != kind::APPLY_UF) {
+    if(isCongruenceOperator(inputEq[1].getKind()) &&
+       !isCongruenceOperator(inputEq[0].getKind())) {
       eqb << flatten(inputEq[1]) << inputEq[0];
     } else {
       eqb << flatten(inputEq[0]) << replace(flatten(inputEq[1]));
@@ -199,11 +237,13 @@ public:
   void addEq(TNode eq, TNode inputEq);
 
   Node flatten(TNode t) {
-    if(t.getKind() == kind::APPLY_UF) {
-      NodeBuilder<> appb(kind::APPLY_UF);
+    if(isCongruenceOperator(t.getKind())) {
+      NodeBuilder<> appb(t.getKind());
       Assert(replace(flatten(t.getOperator())) == t.getOperator(),
              "CongruenceClosure:: bad state: higher-order term ??");
-      appb << t.getOperator();
+      if(t.getMetaKind() == kind::metakind::PARAMETERIZED) {
+       appb << t.getOperator();
+      }
       for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
         appb << replace(flatten(*i));
       }
@@ -214,7 +254,7 @@ public:
   }
 
   Node replace(TNode t) {
-    if(t.getKind() == kind::APPLY_UF) {
+    if(isCongruenceOperator(t.getKind())) {
       EqMap::iterator i = d_eqMap.find(t);
       if(i == d_eqMap.end()) {
         ++d_newSkolemVars;
@@ -315,7 +355,7 @@ public:
 private:
 
   friend std::ostream& operator<< <>(std::ostream& out,
-                                     const CongruenceClosure<OutputChannel>& cc);
+                                     const CongruenceClosure<OutputChannel, CongruenceOperatorList>& cc);
 
   /**
    * Internal propagation of information.  Propagation tends to
@@ -414,8 +454,8 @@ public:
 };/* class CongruenceClosure */
 
 
-template <class OutputChannel>
-void CongruenceClosure<OutputChannel>::addTerm(TNode t) {
+template <class OutputChannel, class CongruenceOperatorList>
+void CongruenceClosure<OutputChannel, CongruenceOperatorList>::addTerm(TNode t) {
   Node trm = replace(flatten(t));
   Node trmp = find(trm);
 
@@ -445,8 +485,8 @@ void CongruenceClosure<OutputChannel>::addTerm(TNode t) {
 }
 
 
-template <class OutputChannel>
-void CongruenceClosure<OutputChannel>::addEq(TNode eq, TNode inputEq) {
+template <class OutputChannel, class CongruenceOperatorList>
+void CongruenceClosure<OutputChannel, CongruenceOperatorList>::addEq(TNode eq, TNode inputEq) {
   Assert(!eq[0].getType().isFunction() && !eq[1].getType().isFunction(),
          "CongruenceClosure:: equality between function symbols not allowed");
 
@@ -457,7 +497,7 @@ void CongruenceClosure<OutputChannel>::addEq(TNode eq, TNode inputEq) {
   }
   Assert(eq.getKind() == kind::EQUAL ||
          eq.getKind() == kind::IFF);
-  Assert(eq[1].getKind() != kind::APPLY_UF);
+  Assert(!isCongruenceOperator(eq[1].getKind()));
   if(areCongruent(eq[0], eq[1])) {
     Trace("cc") << "CC -- redundant, ignoring...\n";
     return;
@@ -472,7 +512,7 @@ void CongruenceClosure<OutputChannel>::addEq(TNode eq, TNode inputEq) {
   // change from paper: do this whether or not s, t are applications
   Trace("cc:detail") << "CC        propagating the eq" << std::endl;
 
-  if(s.getKind() != kind::APPLY_UF) {
+  if(!isCongruenceOperator(s.getKind())) {
     // s, t are constants
     propagate(eq);
   } else {
@@ -500,16 +540,18 @@ void CongruenceClosure<OutputChannel>::addEq(TNode eq, TNode inputEq) {
 }/* addEq() */
 
 
-template <class OutputChannel>
-Node CongruenceClosure<OutputChannel>::buildRepresentativesOfApply(TNode apply,
+template <class OutputChannel, class CongruenceOperatorList>
+Node CongruenceClosure<OutputChannel, CongruenceOperatorList>::buildRepresentativesOfApply(TNode apply,
                                                               Kind kindToBuild)
   throw(AssertionException) {
-  Assert(apply.getKind() == kind::APPLY_UF);
+  Assert(isCongruenceOperator(apply.getKind()));
   NodeBuilder<> argspb(kindToBuild);
-  // FIXME probably don't have to do find() of operator
   Assert(find(apply.getOperator()) == apply.getOperator(),
-         "CongruenceClosure:: bad state: function symbol merged with another");
-  argspb << apply.getOperator();
+         "CongruenceClosure:: bad state: "
+         "function symbol (or other congruence operator) merged with another");
+  if(apply.getMetaKind() == kind::metakind::PARAMETERIZED) {
+    argspb << apply.getOperator();
+  }
   for(TNode::iterator i = apply.begin(); i != apply.end(); ++i) {
     argspb << find(*i);
   }
@@ -517,8 +559,8 @@ Node CongruenceClosure<OutputChannel>::buildRepresentativesOfApply(TNode apply,
 }/* buildRepresentativesOfApply() */
 
 
-template <class OutputChannel>
-void CongruenceClosure<OutputChannel>::propagate(TNode seed) {
+template <class OutputChannel, class CongruenceOperatorList>
+void CongruenceClosure<OutputChannel, CongruenceOperatorList>::propagate(TNode seed) {
   Trace("cc:detail") << "=== doing a round of propagation ===" << std::endl
                      << "the \"seed\" propagation is: " << seed << std::endl;
 
@@ -559,8 +601,8 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) {
       a = e[0][1];
       b = e[1][1];
 
-      Assert(a.getKind() != kind::APPLY_UF);
-      Assert(b.getKind() != kind::APPLY_UF);
+      Assert(!isCongruenceOperator(a.getKind()));
+      Assert(!isCongruenceOperator(b.getKind()));
 
       Trace("cc") << "                 ( " << a << " , " << b << " )" << std::endl;
     }
@@ -631,7 +673,7 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) {
             Trace("cc:detail") << "calling merge2 " << c << bp << std::endl;
             merge(c, bp);
             // move c from classList(ap) to classlist(bp);
-            //i = cl.erase(i);// FIXME do we need to?
+            //i = cl.erase(i);// difference from paper: don't need to erase
             Trace("cc") << " adding c to class list of " << bp << std::endl;
             cl_bp->push_back(c);
           }
@@ -658,10 +700,11 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) {
                    eq.getKind() == kind::IFF);
             // change from paper
             // use list elts can have form (apply c..) = x  OR  x = (apply c..)
-            Assert(eq[0].getKind() == kind::APPLY_UF || eq[1].getKind() == kind::APPLY_UF);
+            Assert(isCongruenceOperator(eq[0].getKind()) ||
+                   isCongruenceOperator(eq[1].getKind()));
             // do for each side that is an application
             for(int side = 0; side <= 1; ++side) {
-              if(eq[side].getKind() != kind::APPLY_UF) {
+              if(!isCongruenceOperator(eq[side].getKind())) {
                 continue;
               }
 
@@ -681,14 +724,14 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) {
                 // remove f(c1,c2)=c from UseList(ap)
                 Trace("cc:detail") << "supposed to remove " << eq << std::endl
                                    << "  from UseList of " << ap << std::endl;
-                //i = ul.erase(i);// FIXME do we need to?
+                //i = ul.erase(i);// difference from paper: don't need to erase
               } else {
                 Trace("cc") << "CC     -- lookup(c') is null" << std::endl;
                 Trace("cc") << "CC     -- setlookup(c') to " << eq << std::endl;
                 // set lookup(c1',c2') to f(c1,c2)=c
                 setLookup(cp, eq);
                 // move f(c1,c2)=c from UseList(ap) to UseList(b')
-                //i = ul.erase(i);// FIXME do we need to remove from UseList(ap) ?
+                //i = ul.erase(i);// difference from paper: don't need to erase
                 appendToUseList(bp, eq);
               }
             }
@@ -713,8 +756,8 @@ void CongruenceClosure<OutputChannel>::propagate(TNode seed) {
 }/* propagate() */
 
 
-template <class OutputChannel>
-void CongruenceClosure<OutputChannel>::merge(TNode ec1, TNode ec2) {
+template <class OutputChannel, class CongruenceOperatorList>
+void CongruenceClosure<OutputChannel, CongruenceOperatorList>::merge(TNode ec1, TNode ec2) {
   /*
   if(Debug.isOn("cc:detail")) {
     Debug("cc:detail") << "  -- merging " << ec1
@@ -730,8 +773,8 @@ void CongruenceClosure<OutputChannel>::merge(TNode ec1, TNode ec2) {
   Trace("cc") << "CC             to " << ec2 << std::endl;
 
   /* can now be applications
-  Assert(ec1.getKind() != kind::APPLY_UF);
-  Assert(ec2.getKind() != kind::APPLY_UF);
+  Assert(!isCongruenceOperator(ec1.getKind()));
+  Assert(!isCongruenceOperator(ec2.getKind()));
   */
 
   Assert(find(ec1) != ec2);
@@ -747,8 +790,8 @@ void CongruenceClosure<OutputChannel>::merge(TNode ec1, TNode ec2) {
 }/* merge() */
 
 
-template <class OutputChannel>
-void CongruenceClosure<OutputChannel>::mergeProof(TNode a, TNode b, TNode e) {
+template <class OutputChannel, class CongruenceOperatorList>
+void CongruenceClosure<OutputChannel, CongruenceOperatorList>::mergeProof(TNode a, TNode b, TNode e) {
   Trace("cc") << "  -- merge-proofing " << a << "\n"
               << "                and " << b << "\n"
               << "               with " << e << "\n";
@@ -786,26 +829,29 @@ void CongruenceClosure<OutputChannel>::mergeProof(TNode a, TNode b, TNode e) {
 }/* mergeProof() */
 
 
-template <class OutputChannel>
-Node CongruenceClosure<OutputChannel>::normalize(TNode t) const
+template <class OutputChannel, class CongruenceOperatorList>
+Node CongruenceClosure<OutputChannel, CongruenceOperatorList>::normalize(TNode t) const
   throw(AssertionException) {
   Trace("cc:detail") << "normalize " << t << std::endl;
-  if(t.getKind() != kind::APPLY_UF) {// t is a constant
+  if(!isCongruenceOperator(t.getKind())) {// t is a constant
     t = find(t);
     Trace("cc:detail") << "  find " << t << std::endl;
     return t;
   } else {// t is an apply
     NodeBuilder<> apb(kind::TUPLE);
     Assert(normalize(t.getOperator()) == t.getOperator(),
-           "CongruenceClosure:: bad state: function symbol merged with another");
-    apb << t.getOperator();
+           "CongruenceClosure:: bad state: "
+           "function symbol merged with another");
+    if(t.getMetaKind() == kind::metakind::PARAMETERIZED) {
+      apb << t.getOperator();
+    }
     Node n;
-    bool allConstants = (n.getKind() != kind::APPLY_UF);
+    bool allConstants = (!isCongruenceOperator(n.getKind()));
     for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
       TNode c = *i;
       n = normalize(c);
       apb << n;
-      allConstants = (allConstants && n.getKind() != kind::APPLY_UF);
+      allConstants = (allConstants && !isCongruenceOperator(n.getKind()));
     }
 
     Node ap = apb;
@@ -815,15 +861,15 @@ Node CongruenceClosure<OutputChannel>::normalize(TNode t) const
     if(allConstants && !theLookup.isNull()) {
       Assert(theLookup.getKind() == kind::EQUAL ||
              theLookup.getKind() == kind::IFF);
-      Assert(theLookup[0].getKind() == kind::APPLY_UF);
-      Assert(theLookup[1].getKind() != kind::APPLY_UF);
+      Assert(isCongruenceOperator(theLookup[0].getKind()));
+      Assert(!isCongruenceOperator(theLookup[1].getKind()));
       return find(theLookup[1]);
     } else {
-      NodeBuilder<> fa(kind::APPLY_UF);
+      NodeBuilder<> fa(t.getKind());
       for(Node::iterator i = ap.begin(); i != ap.end(); ++i) {
         fa << *i;
       }
-      // ensure a hard Node link exists during the call
+      // ensure a hard Node link exists for the return
       Node n = fa;
       return n;
     }
@@ -834,8 +880,8 @@ Node CongruenceClosure<OutputChannel>::normalize(TNode t) const
 // This is the find() operation for the auxiliary union-find.  This
 // union-find is not context-dependent, as it's used only during
 // explain().  It does path compression.
-template <class OutputChannel>
-Node CongruenceClosure<OutputChannel>::highestNode(TNode a, UnionFind_t& unionFind) const
+template <class OutputChannel, class CongruenceOperatorList>
+Node CongruenceClosure<OutputChannel, CongruenceOperatorList>::highestNode(TNode a, UnionFind_t& unionFind) const
   throw(AssertionException) {
   UnionFind_t::iterator i = unionFind.find(a);
   if(i == unionFind.end()) {
@@ -846,8 +892,8 @@ Node CongruenceClosure<OutputChannel>::highestNode(TNode a, UnionFind_t& unionFi
 }/* highestNode() */
 
 
-template <class OutputChannel>
-void CongruenceClosure<OutputChannel>::explainAlongPath(TNode a, TNode c, PendingProofList_t& pending, UnionFind_t& unionFind, std::list<Node>& pf)
+template <class OutputChannel, class CongruenceOperatorList>
+void CongruenceClosure<OutputChannel, CongruenceOperatorList>::explainAlongPath(TNode a, TNode c, PendingProofList_t& pending, UnionFind_t& unionFind, std::list<Node>& pf)
   throw(AssertionException) {
 
   a = highestNode(a, unionFind);
@@ -863,10 +909,10 @@ void CongruenceClosure<OutputChannel>::explainAlongPath(TNode a, TNode c, Pendin
       Assert(e.getKind() == kind::TUPLE);
       pf.push_back(e[0]);
       pf.push_back(e[1]);
-      Assert(e[0][0].getKind() == kind::APPLY_UF);
-      Assert(e[0][1].getKind() != kind::APPLY_UF);
-      Assert(e[1][0].getKind() == kind::APPLY_UF);
-      Assert(e[1][1].getKind() != kind::APPLY_UF);
+      Assert(isCongruenceOperator(e[0][0].getKind()));
+      Assert(!isCongruenceOperator(e[0][1].getKind()));
+      Assert(isCongruenceOperator(e[1][0].getKind()));
+      Assert(!isCongruenceOperator(e[1][1].getKind()));
       Assert(e[0][0].getNumChildren() == e[1][0].getNumChildren());
       Assert(e[0][0].getOperator() == e[1][0].getOperator(),
              "CongruenceClosure:: bad state: function symbols should be equal");
@@ -882,8 +928,8 @@ void CongruenceClosure<OutputChannel>::explainAlongPath(TNode a, TNode c, Pendin
 }/* explainAlongPath() */
 
 
-template <class OutputChannel>
-Node CongruenceClosure<OutputChannel>::nearestCommonAncestor(TNode a, TNode b, UnionFind_t& unionFind)
+template <class OutputChannel, class CongruenceOperatorList>
+Node CongruenceClosure<OutputChannel, CongruenceOperatorList>::nearestCommonAncestor(TNode a, TNode b, UnionFind_t& unionFind)
   throw(AssertionException) {
   SeenSet_t seen;
 
@@ -907,8 +953,8 @@ Node CongruenceClosure<OutputChannel>::nearestCommonAncestor(TNode a, TNode b, U
 }/* nearestCommonAncestor() */
 
 
-template <class OutputChannel>
-Node CongruenceClosure<OutputChannel>::explain(Node a, Node b)
+template <class OutputChannel, class CongruenceOperatorList>
+Node CongruenceClosure<OutputChannel, CongruenceOperatorList>::explain(Node a, Node b)
   throw(CongruenceClosureException, AssertionException) {
 
   Assert(a != b);
@@ -918,10 +964,10 @@ Node CongruenceClosure<OutputChannel>::explain(Node a, Node b)
                                      "that aren't congruent");
   }
 
-  if(a.getKind() == kind::APPLY_UF) {
+  if(isCongruenceOperator(a.getKind())) {
     a = replace(flatten(a));
   }
-  if(b.getKind() == kind::APPLY_UF) {
+  if(isCongruenceOperator(b.getKind())) {
     b = replace(flatten(b));
   }
 
@@ -976,44 +1022,44 @@ Node CongruenceClosure<OutputChannel>::explain(Node a, Node b)
 }/* explain() */
 
 
-template <class OutputChannel>
+template <class OutputChannel, class CongruenceOperatorList>
 std::ostream& operator<<(std::ostream& out,
-                         const CongruenceClosure<OutputChannel>& cc) {
+                         const CongruenceClosure<OutputChannel, CongruenceOperatorList>& cc) {
   out << "==============================================" << std::endl;
 
   /*out << "Representatives:" << std::endl;
-  for(typename CongruenceClosure<OutputChannel>::RepresentativeMap::const_iterator i = cc.d_representative.begin(); i != cc.d_representative.end(); ++i) {
+  for(typename CongruenceClosure<OutputChannel, CongruenceOperatorList>::RepresentativeMap::const_iterator i = cc.d_representative.begin(); i != cc.d_representative.end(); ++i) {
     out << "  " << (*i).first << " => " << (*i).second << std::endl;
   }*/
 
   out << "ClassLists:" << std::endl;
-  for(typename CongruenceClosure<OutputChannel>::ClassLists::const_iterator i = cc.d_classList.begin(); i != cc.d_classList.end(); ++i) {
+  for(typename CongruenceClosure<OutputChannel, CongruenceOperatorList>::ClassLists::const_iterator i = cc.d_classList.begin(); i != cc.d_classList.end(); ++i) {
     if(cc.find((*i).first) == (*i).first) {
       out << "  " << (*i).first << " =>" << std::endl;
-      for(typename CongruenceClosure<OutputChannel>::ClassList::const_iterator j = (*i).second->begin(); j != (*i).second->end(); ++j) {
+      for(typename CongruenceClosure<OutputChannel, CongruenceOperatorList>::ClassList::const_iterator j = (*i).second->begin(); j != (*i).second->end(); ++j) {
         out << "      " << *j << std::endl;
       }
     }
   }
 
   out << "UseLists:" << std::endl;
-  for(typename CongruenceClosure<OutputChannel>::UseLists::const_iterator i = cc.d_useList.begin(); i != cc.d_useList.end(); ++i) {
+  for(typename CongruenceClosure<OutputChannel, CongruenceOperatorList>::UseLists::const_iterator i = cc.d_useList.begin(); i != cc.d_useList.end(); ++i) {
     if(cc.find((*i).first) == (*i).first) {
       out << "  " << (*i).first << " =>" << std::endl;
-      for(typename CongruenceClosure<OutputChannel>::UseList::const_iterator j = (*i).second->begin(); j != (*i).second->end(); ++j) {
+      for(typename CongruenceClosure<OutputChannel, CongruenceOperatorList>::UseList::const_iterator j = (*i).second->begin(); j != (*i).second->end(); ++j) {
         out << "      " << *j << std::endl;
       }
     }
   }
 
   out << "Lookup:" << std::endl;
-  for(typename CongruenceClosure<OutputChannel>::LookupMap::const_iterator i = cc.d_lookup.begin(); i != cc.d_lookup.end(); ++i) {
+  for(typename CongruenceClosure<OutputChannel, CongruenceOperatorList>::LookupMap::const_iterator i = cc.d_lookup.begin(); i != cc.d_lookup.end(); ++i) {
     TNode n = (*i).second;
     out << "  " << (*i).first << " => " << n << std::endl;
   }
 
   out << "Care set:" << std::endl;
-  for(typename CongruenceClosure<OutputChannel>::CareSet::const_iterator i = cc.d_careSet.begin(); i != cc.d_careSet.end(); ++i) {
+  for(typename CongruenceClosure<OutputChannel, CongruenceOperatorList>::CareSet::const_iterator i = cc.d_careSet.begin(); i != cc.d_careSet.end(); ++i) {
     out << "  " << *i << std::endl;
   }
 
index 29a104a8ea9eefc6f357e5f11be7f66b69f7fc66..a12cb79ead4c1b342d7d73c6d92ea9e0513cbdff 100644 (file)
@@ -92,13 +92,18 @@ class CongruenceClosureWhite : public CxxTest::TestSuite {
   NodeManager* d_nm;
   NodeManagerScope* d_scope;
   MyOutputChannel* d_out;
-  CongruenceClosure<MyOutputChannel>* d_cc;
+  CongruenceClosure<MyOutputChannel, CongruenceOperator<kind::APPLY_UF> >* d_cc;
+  CongruenceClosure<MyOutputChannel, CONGRUENCE_OPERATORS_2(kind::SELECT, kind::STORE)>* d_ccArray;
 
   TypeNode U;
   Node a, f, fa, ffa, fffa, ffffa, b, fb, ffb, fffb, ffffb;
   Node g, gab, gba, gfafb, gfbfa, gfaa, gbfb;
   Node h, hab, hba, hfaa;
   Node a_eq_b, fa_eq_b, a_eq_fb, fa_eq_fb, h_eq_g;
+
+  Node ar, ar_a, ar_b;
+  Node arar, arar_a, arar_b;
+
 public:
 
   void setUp() {
@@ -106,7 +111,8 @@ public:
     d_nm = new NodeManager(d_context);
     d_scope = new NodeManagerScope(d_nm);
     d_out = new MyOutputChannel(d_context, d_nm);
-    d_cc = new CongruenceClosure<MyOutputChannel>(d_context, d_out);
+    d_cc = new CongruenceClosure<MyOutputChannel, CongruenceOperator<kind::APPLY_UF> >(d_context, d_out);
+    d_ccArray = new CongruenceClosure<MyOutputChannel, CONGRUENCE_OPERATORS_2(kind::SELECT, kind::STORE)>(d_context, d_out);
 
     U = d_nm->mkSort("U");
 
@@ -140,10 +146,22 @@ public:
     fa_eq_fb = d_nm->mkNode(kind::EQUAL, fa, fb);
 
     h_eq_g = d_nm->mkNode(kind::EQUAL, h, g);
+
+    // arrays
+    ar = d_nm->mkVar("ar", d_nm->mkArrayType(U, U));
+    ar_a = d_nm->mkNode(kind::SELECT, ar, a);
+    ar_b = d_nm->mkNode(kind::SELECT, ar, b);
+
+    arar = d_nm->mkVar("arar", d_nm->mkArrayType(U, d_nm->mkArrayType(U, U)));
+    arar_a = d_nm->mkNode(kind::SELECT, arar, a);
+    arar_b = d_nm->mkNode(kind::SELECT, arar, b);
   }
 
   void tearDown() {
     try {
+      arar = arar_a = arar_b = Node::null();
+      ar = ar_a = ar_b = Node::null();
+
       f = a = fa = ffa = fffa = ffffa = Node::null();
       b = fb = ffb = fffb = ffffb = Node::null();
       g = gab = gba = gfafb = gfbfa = gfaa = gbfb = Node::null();
@@ -151,14 +169,14 @@ public:
       a_eq_b = fa_eq_b = a_eq_fb = fa_eq_fb = h_eq_g = Node::null();
       U = TypeNode::null();
 
+      delete d_ccArray;
       delete d_cc;
       delete d_out;
       delete d_scope;
       delete d_nm;
       delete d_context;
-
     } catch(Exception& e) {
-      Warning() << std::endl << e << std::endl;
+      cout << "\n\n" << e << "\n\n";
       throw;
     }
   }
@@ -382,7 +400,7 @@ public:
     TS_ASSERT(d_out->areCongruent(gfaa, gba));
     TS_ASSERT(d_cc->areCongruent(gfaa, gba));
 
-    } catch(Exception e) {
+    } catch(Exception& e) {
       cout << "\n\n" << e << "\n\n";
       throw;
     }
@@ -430,7 +448,10 @@ public:
     d_cc->addEquality(d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::APPLY_UF, f1,c_1,c_0),c_1));
     d_cc->addEquality(d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::APPLY_UF, f1,c_0,c_1),c_0));
     d_cc->addEquality(d_nm->mkNode(kind::EQUAL, d_nm->mkNode(kind::APPLY_UF, f1,c_0,c_0),c_0));
-    }catch(Exception &e) { cout << e << "\n"; throw e; }
+    } catch(Exception& e) {
+      cout << "\n\n" << e << "\n\n";
+      throw e;
+    }
   }
 
   void testUF2() {
@@ -450,4 +471,59 @@ public:
     d_cc->addEquality(ffffx_eq_z);
   }
 
+  void testSimpleArray() {
+    //Debug.on("cc");
+    // add terms, then add equalities
+    try {
+    d_ccArray->addTerm(ar_a);
+    d_ccArray->addTerm(ar_b);
+
+    d_ccArray->addEquality(a_eq_b);
+
+    TS_ASSERT(d_out->areCongruent(ar_a, ar_b));
+    TS_ASSERT(d_ccArray->areCongruent(ar_a, ar_b));
+
+    TS_ASSERT(!d_out->areCongruent(a, b));
+    TS_ASSERT(d_ccArray->areCongruent(a, b));
+    } catch(Exception& e) {
+      cout << "\n\n" << e << "\n\n";
+      throw;
+    }
+  }
+
+  void testSimpleReverseArray() {
+    // add equalities, then add terms; should get the same as
+    // testSimple()
+
+    d_ccArray->addEquality(a_eq_b);
+
+    d_ccArray->addTerm(ar_a);
+    d_ccArray->addTerm(ar_b);
+
+    TS_ASSERT(d_out->areCongruent(ar_a, ar_b));
+    TS_ASSERT(d_ccArray->areCongruent(ar_a, ar_b));
+
+    TS_ASSERT(!d_out->areCongruent(a, b));
+    TS_ASSERT(d_ccArray->areCongruent(a, b));
+  }
+
+  void testArray() {
+    Node ar_eq_arar_a = d_nm->mkNode(kind::EQUAL, ar, arar_a);
+    Node ar2 = d_nm->mkVar("ar2", d_nm->mkArrayType(U, U));
+    Node store_arar_b_ar2 = d_nm->mkNode(kind::STORE, arar, b, ar2);
+    Node select__store_arar_b_ar2__a =
+      d_nm->mkNode(kind::SELECT, store_arar_b_ar2, a);
+    Node select__store_arar_b_ar2__a__eq__arar_a =
+      d_nm->mkNode(kind::EQUAL, select__store_arar_b_ar2__a, arar_a);
+
+    d_ccArray->addTerm(ar);
+    d_ccArray->addTerm(select__store_arar_b_ar2__a);
+
+    d_ccArray->addEquality(ar_eq_arar_a);
+    d_ccArray->addEquality(select__store_arar_b_ar2__a__eq__arar_a);
+
+    TS_ASSERT(d_ccArray->areCongruent(ar, select__store_arar_b_ar2__a));
+    TS_ASSERT(d_out->areCongruent(ar, select__store_arar_b_ar2__a));
+  }
+
 };/* class CongruenceClosureWhite */