Use TypeNode in UninterpretedConstant (#4748)
authorAndres Noetzli <andres.noetzli@gmail.com>
Wed, 15 Jul 2020 15:27:13 +0000 (08:27 -0700)
committerGitHub <noreply@github.com>
Wed, 15 Jul 2020 15:27:13 +0000 (10:27 -0500)
This commit changes UninterpretedConstant to use TypeNode instead of
Type.

12 files changed:
src/api/cvc4cpp.cpp
src/expr/uninterpreted_constant.cpp
src/expr/uninterpreted_constant.h
src/theory/builtin/theory_builtin_type_rules.h
src/theory/builtin/type_enumerator.h
src/theory/datatypes/datatypes_rewriter.cpp
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/type_enumerator.cpp
src/theory/rewriter.cpp
test/unit/theory/theory_sets_type_enumerator_white.h
test/unit/theory/type_enumerator_white.h
test/unit/util/array_store_all_white.h

index fcf0c028e4be1e3c569ed1c50701e202fdeac0cc..c4ba701b93816d760cca88753bc654e487698a12 100644 (file)
@@ -3596,7 +3596,7 @@ Term Solver::mkUninterpretedConst(Sort sort, int32_t index) const
   CVC4_API_SOLVER_CHECK_SORT(sort);
 
   return mkValHelper<CVC4::UninterpretedConstant>(
-      CVC4::UninterpretedConstant(*sort.d_type, index));
+      CVC4::UninterpretedConstant(TypeNode::fromType(*sort.d_type), index));
 
   CVC4_API_SOLVER_TRY_CATCH_END;
 }
index 2b66a3b7abeb998c46ca4919c970b50d75389147..c9cbcba20a59f58c3d3ab03d933587c263850834 100644 (file)
 #include <string>
 
 #include "base/check.h"
+#include "expr/type_node.h"
 
 using namespace std;
 
 namespace CVC4 {
 
-UninterpretedConstant::UninterpretedConstant(Type type, Integer index)
-    : d_type(type), d_index(index)
+UninterpretedConstant::UninterpretedConstant(const TypeNode& type,
+                                             Integer index)
+    : d_type(new TypeNode(type)), d_index(index)
 {
   //PrettyCheckArgument(type.isSort(), type, "uninterpreted constants can only be created for uninterpreted sorts, not `%s'", type.toString().c_str());
   PrettyCheckArgument(index >= 0, index, "index >= 0 required for uninterpreted constant index, not `%s'", index.toString().c_str());
 }
 
+UninterpretedConstant::~UninterpretedConstant() {}
+
+UninterpretedConstant::UninterpretedConstant(const UninterpretedConstant& other)
+    : d_type(new TypeNode(other.getType())), d_index(other.getIndex())
+{
+}
+
+const TypeNode& UninterpretedConstant::getType() const { return *d_type; }
+const Integer& UninterpretedConstant::getIndex() const { return d_index; }
+bool UninterpretedConstant::operator==(const UninterpretedConstant& uc) const
+{
+  return getType() == uc.getType() && d_index == uc.d_index;
+}
+bool UninterpretedConstant::operator!=(const UninterpretedConstant& uc) const
+{
+  return !(*this == uc);
+}
+
+bool UninterpretedConstant::operator<(const UninterpretedConstant& uc) const
+{
+  return getType() < uc.getType()
+         || (getType() == uc.getType() && d_index < uc.d_index);
+}
+bool UninterpretedConstant::operator<=(const UninterpretedConstant& uc) const
+{
+  return getType() < uc.getType()
+         || (getType() == uc.getType() && d_index <= uc.d_index);
+}
+bool UninterpretedConstant::operator>(const UninterpretedConstant& uc) const
+{
+  return !(*this <= uc);
+}
+bool UninterpretedConstant::operator>=(const UninterpretedConstant& uc) const
+{
+  return !(*this < uc);
+}
+
 std::ostream& operator<<(std::ostream& out, const UninterpretedConstant& uc) {
   std::stringstream ss;
   ss << uc.getType();
@@ -49,4 +88,11 @@ std::ostream& operator<<(std::ostream& out, const UninterpretedConstant& uc) {
   return out << "uc_" << st.c_str() << "_" << uc.getIndex();
 }
 
+size_t UninterpretedConstantHashFunction::operator()(
+    const UninterpretedConstant& uc) const
+{
+  return TypeNodeHashFunction()(uc.getType())
+         * IntegerHashFunction()(uc.getIndex());
+}
+
 }/* CVC4 namespace */
index 46d9a2800f4128fc9e86f0596d5fc3b62409f43c..eb6cc203af13e4f2e0399a06f954fd1e5f9e21c1 100644 (file)
 
 #include "cvc4_public.h"
 
-#pragma once
+#ifndef CVC4__UNINTERPRETED_CONSTANT_H
+#define CVC4__UNINTERPRETED_CONSTANT_H
 
 #include <iosfwd>
+#include <memory>
 
-#include "expr/type.h"
+#include "util/integer.h"
 
 namespace CVC4 {
 
-class CVC4_PUBLIC UninterpretedConstant {
+class TypeNode;
+
+class UninterpretedConstant
+{
  public:
-  UninterpretedConstant(Type type, Integer index);
+  UninterpretedConstant(const TypeNode& type, Integer index);
+  ~UninterpretedConstant();
 
-  Type getType() const { return d_type; }
-  const Integer& getIndex() const { return d_index; }
-  bool operator==(const UninterpretedConstant& uc) const
-  {
-    return d_type == uc.d_type && d_index == uc.d_index;
-  }
-  bool operator!=(const UninterpretedConstant& uc) const
-  {
-    return !(*this == uc);
-  }
+  UninterpretedConstant(const UninterpretedConstant& other);
 
-  bool operator<(const UninterpretedConstant& uc) const
-  {
-    return d_type < uc.d_type ||
-           (d_type == uc.d_type && d_index < uc.d_index);
-  }
-  bool operator<=(const UninterpretedConstant& uc) const
-  {
-    return d_type < uc.d_type ||
-           (d_type == uc.d_type && d_index <= uc.d_index);
-  }
-  bool operator>(const UninterpretedConstant& uc) const
-  {
-    return !(*this <= uc);
-  }
-  bool operator>=(const UninterpretedConstant& uc) const
-  {
-    return !(*this < uc);
-  }
+  const TypeNode& getType() const;
+  const Integer& getIndex() const;
+  bool operator==(const UninterpretedConstant& uc) const;
+  bool operator!=(const UninterpretedConstant& uc) const;
+  bool operator<(const UninterpretedConstant& uc) const;
+  bool operator<=(const UninterpretedConstant& uc) const;
+  bool operator>(const UninterpretedConstant& uc) const;
+  bool operator>=(const UninterpretedConstant& uc) const;
 
  private:
-  const Type d_type;
+  std::unique_ptr<TypeNode> d_type;
   const Integer d_index;
-};/* class UninterpretedConstant */
+}; /* class UninterpretedConstant */
 
-std::ostream& operator<<(std::ostream& out, const UninterpretedConstant& uc) CVC4_PUBLIC;
+std::ostream& operator<<(std::ostream& out, const UninterpretedConstant& uc);
 
 /**
  * Hash function for the BitVector constants.
  */
-struct CVC4_PUBLIC UninterpretedConstantHashFunction {
-  inline size_t operator()(const UninterpretedConstant& uc) const {
-    return TypeHashFunction()(uc.getType()) * IntegerHashFunction()(uc.getIndex());
-  }
-};/* struct UninterpretedConstantHashFunction */
+struct CVC4_PUBLIC UninterpretedConstantHashFunction
+{
+  size_t operator()(const UninterpretedConstant& uc) const;
+}; /* struct UninterpretedConstantHashFunction */
+
+}  // namespace CVC4
 
-}/* CVC4 namespace */
+#endif /* CVC4__UNINTERPRETED_CONSTANT_H */
index 3a6b9bfff3143dd4b99c076928dd57d2a51d5a7f..29ac4f2d168b1ccfb6e1ed5ebf8a768dc801b8f7 100644 (file)
@@ -96,7 +96,7 @@ class SExprTypeRule {
 class UninterpretedConstantTypeRule {
  public:
   inline static TypeNode computeType(NodeManager* nodeManager, TNode n, bool check) {
-    return TypeNode::fromType(n.getConst<UninterpretedConstant>().getType());
+    return n.getConst<UninterpretedConstant>().getType();
   }
 };/* class UninterpretedConstantTypeRule */
 
index edac15d8650da554983874acf5c14502a4e4ba35..18dcf4521b19c17299c570c647771f8f548ef110 100644 (file)
@@ -60,7 +60,8 @@ class UninterpretedSortEnumerator : public TypeEnumeratorBase<UninterpretedSortE
     if(isFinished()) {
       throw NoMoreValuesException(getType());
     }
-    return NodeManager::currentNM()->mkConst(UninterpretedConstant(getType().toType(), d_count));
+    return NodeManager::currentNM()->mkConst(
+        UninterpretedConstant(getType(), d_count));
   }
 
   UninterpretedSortEnumerator& operator++() override
index 699e26d21e2924b3bfd1bd6ac2885593e9a6147d..450a0fd370979214692d2bf7c47561ba49f62f5a 100644 (file)
@@ -720,7 +720,7 @@ Node DatatypesRewriter::normalizeCodatatypeConstantEqc(
     {
       int debruijn = depth - it->second - 1;
       return NodeManager::currentNM()->mkConst(
-          UninterpretedConstant(n.getType().toType(), debruijn));
+          UninterpretedConstant(n.getType(), debruijn));
     }
     std::vector<Node> children;
     bool childChanged = false;
index d6cb3b37e94c55e7556e00ce9a1b1990a612eb51..505d08c38f2bda6e1fd14be25429c891cd4a0ee7 100644 (file)
@@ -1642,7 +1642,8 @@ Node TheoryDatatypes::getCodatatypesValue( Node n, std::map< Node, Node >& eqc_c
   std::map< Node, int >::iterator itv = vmap.find( n );
   if( itv!=vmap.end() ){
     int debruijn = depth - 1 - itv->second;
-    return NodeManager::currentNM()->mkConst(UninterpretedConstant(n.getType().toType(), debruijn));
+    return NodeManager::currentNM()->mkConst(
+        UninterpretedConstant(n.getType(), debruijn));
   }else if( n.getType().isDatatype() ){
     Node nc = eqc_cons[n];
     if( !nc.isNull() ){
index ccaba009df16343e5703810d270b6899f08444e8..3dca74b190bf591883f85fd3b3489864f6557d8e 100644 (file)
@@ -107,7 +107,7 @@ Node DatatypesEnumerator::getTermEnum( TypeNode tn, unsigned i ){
      if (d_child_enum)
      {
        ret = NodeManager::currentNM()->mkConst(
-           UninterpretedConstant(d_type.toType(), d_size_limit));
+           UninterpretedConstant(d_type, d_size_limit));
      }
      else
      {
index d77f6fe836674bcc56c3ebac969a71806c6c4d2b..f2e13d1e0ce21c895655250e45a7d10aa50b45b1 100644 (file)
@@ -333,7 +333,8 @@ Node Rewriter::rewriteTo(theory::TheoryId theoryId,
 #ifdef CVC4_ASSERTIONS
           RewriteResponse r2 =
               d_theoryRewriters[newTheoryId]->postRewrite(response.d_node);
-          Assert(r2.d_node == response.d_node);
+          Assert(r2.d_node == response.d_node)
+              << r2.d_node << " != " << response.d_node;
 #endif
           rewriteStackTop.d_node = response.d_node;
           break;
index 67a36200f7e8941fc0c9ff6b4869efe821542568..0a97f4c427003e97ab4e90eed69cddceb1d76ce5 100644 (file)
@@ -86,13 +86,11 @@ class SetEnumeratorWhite : public CxxTest::TestSuite
 
   void testSetOfUF()
   {
-    TypeNode typeNode = d_nm->mkSort("Atom");
-    Type sort = typeNode.toType();
-    SetEnumerator setEnumerator(d_nm->mkSetType(typeNode));
+    TypeNode sort = d_nm->mkSort("Atom");
+    SetEnumerator setEnumerator(d_nm->mkSetType(sort));
 
     Node actual0 = *setEnumerator;
-    Node expected0 =
-        d_nm->mkConst(EmptySet(d_nm->mkSetType(typeNode)));
+    Node expected0 = d_nm->mkConst(EmptySet(d_nm->mkSetType(sort)));
     TS_ASSERT_EQUALS(expected0, actual0);
     TS_ASSERT(!setEnumerator.isFinished());
 
index b996919ee80d3b805182545b69cd0af3e9af7ea2..cf1f002aa772902e4638f82c1ec71da1d49472e2 100644 (file)
@@ -70,11 +70,9 @@ class TypeEnumeratorWhite : public CxxTest::TestSuite {
   }
 
   void testUF() {
-    TypeNode sortn = d_nm->mkSort("T");
-    Type sort = sortn.toType();
-    TypeNode sortn2 = d_nm->mkSort("U");
-    Type sort2 = sortn2.toType();
-    TypeEnumerator te(sortn);
+    TypeNode sort = d_nm->mkSort("T");
+    TypeNode sort2 = d_nm->mkSort("U");
+    TypeEnumerator te(sort);
     TS_ASSERT_EQUALS(*te, d_nm->mkConst(UninterpretedConstant(sort, 0)));
     for(size_t i = 1; i < 100; ++i) {
       TS_ASSERT_DIFFERS(*te, d_nm->mkConst(UninterpretedConstant(sort, i)));
index fb7857003457755358b6f3b365f81f5f9ac7186f..a027d86ce7932fd592963cc565a181875a3ed322 100644 (file)
@@ -50,7 +50,7 @@ class ArrayStoreAllWhite : public CxxTest::TestSuite
     ArrayStoreAll(d_nm->mkArrayType(d_nm->integerType(), d_nm->realType()),
                   d_nm->mkConst(Rational(9, 2)));
     ArrayStoreAll(d_nm->mkArrayType(d_nm->mkSort("U"), usort),
-                  d_nm->mkConst(UninterpretedConstant(usort.toType(), 0)));
+                  d_nm->mkConst(UninterpretedConstant(usort, 0)));
     ArrayStoreAll(d_nm->mkArrayType(d_nm->mkBitVectorType(8), d_nm->realType()),
                   d_nm->mkConst(Rational(0)));
     ArrayStoreAll(
@@ -62,7 +62,7 @@ class ArrayStoreAllWhite : public CxxTest::TestSuite
   {
     TS_ASSERT_THROWS(ArrayStoreAll(d_nm->integerType(),
                                    d_nm->mkConst(UninterpretedConstant(
-                                       d_nm->mkSort("U").toType(), 0))),
+                                       d_nm->mkSort("U"), 0))),
                      IllegalArgumentException&);
     TS_ASSERT_THROWS(
         ArrayStoreAll(d_nm->integerType(), d_nm->mkConst(Rational(9, 2))),