Eliminate the use of CAST_TO_REAL (#8759)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sun, 15 May 2022 17:11:42 +0000 (12:11 -0500)
committerGitHub <noreply@github.com>
Sun, 15 May 2022 17:11:42 +0000 (17:11 +0000)
This simplifies the implementation of the API by not relying on CAST_TO_REAL. This was used as a way of manually marking integral reals as having real type.

12 files changed:
src/api/cpp/cvc5.cpp
src/api/cpp/cvc5.h
src/expr/subtype_elim_node_converter.cpp
src/expr/subtype_elim_node_converter.h
src/printer/smt2/smt2_printer.cpp
src/printer/smt2/smt2_printer.h
src/proof/alethe/alethe_nosubtype_node_converter.cpp
src/proof/lfsc/lfsc_node_converter.cpp
src/theory/arith/arith_rewriter.cpp
src/theory/arith/kinds
src/theory/arith/theory_arith_type_rules.cpp
src/theory/evaluator.cpp

index 665bbdc6b102ed47f1ba21ee5b7cc95a183923bd..05e247fd6547dbcb43e2697ae16d9eb1841e4a99 100644 (file)
@@ -2212,10 +2212,6 @@ size_t Term::getNumChildren() const
   {
     return d_node->getNumChildren() + 1;
   }
-  if (isCastedReal())
-  {
-    return 0;
-  }
   return d_node->getNumChildren();
   ////////
   CVC5_API_TRY_CATCH_END;
@@ -2603,8 +2599,6 @@ const internal::Rational& getRational(const internal::Node& node)
 {
   switch (node.getKind())
   {
-    case internal::Kind::CAST_TO_REAL:
-      return node[0].getConst<internal::Rational>();
     case internal::Kind::CONST_INTEGER:
     case internal::Kind::CONST_RATIONAL:
       return node.getConst<internal::Rational>();
@@ -2637,8 +2631,7 @@ bool checkReal64Bounds(const internal::Rational& r)
 bool isReal(const internal::Node& node)
 {
   return node.getKind() == internal::Kind::CONST_RATIONAL
-         || node.getKind() == internal::Kind::CONST_INTEGER
-         || node.getKind() == internal::Kind::CAST_TO_REAL;
+         || node.getKind() == internal::Kind::CONST_INTEGER;
 }
 bool isReal32(const internal::Node& node)
 {
@@ -3341,22 +3334,9 @@ Kind Term::getKindHelper() const
   }
   // Notice that kinds like APPLY_TYPE_ASCRIPTION will be converted to
   // INTERNAL_KIND.
-  if (isCastedReal())
-  {
-    return CONST_RATIONAL;
-  }
   return intToExtKind(d_node->getKind());
 }
 
-bool Term::isCastedReal() const
-{
-  if (d_node->getKind() == internal::kind::CAST_TO_REAL)
-  {
-    return (*d_node)[0].isConst() && (*d_node)[0].getType().isInteger();
-  }
-  return false;
-}
-
 /* -------------------------------------------------------------------------- */
 /* Datatypes                                                                  */
 /* -------------------------------------------------------------------------- */
@@ -4923,13 +4903,7 @@ Term Solver::mkRationalValHelper(const internal::Rational& r, bool isInt) const
   internal::NodeManager* nm = getNodeManager();
   internal::Node res = isInt ? nm->mkConstInt(r) : nm->mkConstReal(r);
   (void)res.getType(true); /* kick off type checking */
-  Term t = Term(this, res);
-  // NOTE: this block will be eliminated when arithmetic subtyping is eliminated
-  if (!isInt)
-  {
-    t = ensureRealSort(t);
-  }
-  return t;
+  return Term(this, res);
 }
 
 Term Solver::mkRealOrIntegerFromStrHelper(const std::string& s,
@@ -5000,12 +4974,7 @@ Term Solver::getValueHelper(const Term& term) const
   //////// all checks before this line
   internal::Node value = d_slv->getValue(*term.d_node);
   Term res = Term(this, value);
-  // May need to wrap in real cast so that user know this is a real.
-  internal::TypeNode tn = (*term.d_node).getType();
-  if (!tn.isInteger() && value.getType().isInteger())
-  {
-    return ensureRealSort(res);
-  }
+  Assert(res.getSort() == term.getSort());
   return res;
 }
 
@@ -5208,23 +5177,6 @@ Term Solver::ensureTermSort(const Term& term, const Sort& sort) const
   return res;
 }
 
-Term Solver::ensureRealSort(const Term& t) const
-{
-  Assert(this == t.d_solver);
-  CVC5_API_ARG_CHECK_EXPECTED(
-      t.getSort() == getIntegerSort() || t.getSort() == getRealSort(),
-      " an integer or real term");
-  // Note: Term is checked in the caller to avoid double checks
-  //////// all checks before this line
-  if (t.getSort() == getIntegerSort())
-  {
-    internal::Node n =
-        getNodeManager()->mkNode(internal::kind::CAST_TO_REAL, *t.d_node);
-    return Term(this, n);
-  }
-  return t;
-}
-
 bool Solver::isValidInteger(const std::string& s) const
 {
   //////// all checks before this line
@@ -5862,14 +5814,7 @@ Term Solver::mkConstArray(const Sort& sort, const Term& val) const
   CVC5_API_CHECK(val.getSort() == sort.getArrayElementSort())
       << "Value does not match element sort";
   //////// all checks before this line
-
-  // handle the special case of (CAST_TO_REAL n) where n is an integer
   internal::Node n = *val.d_node;
-  if (val.isCastedReal())
-  {
-    // this is safe because the constant array stores its type
-    n = n[0];
-  }
   Term res = mkValHelper(internal::ArrayStoreAll(*sort.d_type, n));
   return res;
   ////////
index 0e7dd5da33da54938b49a7c60c3e99145e2f6c99..975093e8d33ee6a23394d813384246eebcd30585 100644 (file)
@@ -1726,11 +1726,6 @@ class CVC5_EXPORT Term
    */
   Kind getKindHelper() const;
 
-  /**
-   * @return True if the current term is a constant integer that is casted into
-   *         real using the operator CAST_TO_REAL, and returns false otherwise
-   */
-  bool isCastedReal() const;
   /**
    * The internal node wrapped by this term.
    * @note This is a ``std::shared_ptr`` rather than a ``std::unique_ptr`` to
@@ -5016,14 +5011,6 @@ class CVC5_EXPORT Solver
   /** Get value helper, which accounts for subtyping */
   Term getValueHelper(const Term& term) const;
 
-  /**
-   * Helper function that ensures that a given term is of sort real (as opposed
-   * to being of sort integer).
-   * @param t A term of sort integer or real.
-   * @return A term of sort real.
-   */
-  Term ensureRealSort(const Term& t) const;
-
   /**
    * Create n-ary term of given kind. This handles the cases of left/right
    * associative operators, chainable operators, and cases when the number of
index 1f86605cf9d3fbaa3798169a64b0f698e725ee36..5e6362a1adf3193547e2c88feabdcfa54e8a7c07 100644 (file)
@@ -34,11 +34,12 @@ Node SubtypeElimNodeConverter::postConvert(Node n)
   {
     convertToRealChildren = isRealTypeStrict(n.getType());
   }
-  else if (k == EQUAL || k == GEQ)
+  else if (k == GEQ || k == GT || k == LEQ || k == LT)
   {
     convertToRealChildren =
         isRealTypeStrict(n[0].getType()) || isRealTypeStrict(n[1].getType());
   }
+  // note that EQUAL is strictly typed so we don't need to handle it here
   if (convertToRealChildren)
   {
     NodeManager* nm = NodeManager::currentNM();
@@ -47,10 +48,16 @@ Node SubtypeElimNodeConverter::postConvert(Node n)
     {
       if (nc.getType().isInteger())
       {
-        // we use CAST_TO_REAL for constants, so that e.g. 5 is printed as
-        // 5.0 not (to_real 5)
-        Kind nk = nc.isConst() ? CAST_TO_REAL : TO_REAL;
-        children.push_back(nm->mkNode(nk, nc));
+        if (nc.isConst())
+        {
+          // we convert constant integers to constant reals
+          children.push_back(nm->mkConstReal(nc.getConst<Rational>()));
+        }
+        else
+        {
+          // otherwise, use TO_REAL
+          children.push_back(nm->mkNode(TO_REAL, nc));
+        }
       }
       else
       {
index ea7d3f8f44126b3067f68a7023c13bee122d84f4..efce4a9d785f23b1bad2e8247a96cba264e1ea1f 100644 (file)
@@ -30,6 +30,10 @@ namespace cvc5::internal {
  * This converts a node into one that does not involve (arithmetic) subtyping.
  * In particular, all applications of arithmetic symbols that involve at least
  * one (strict) Real child are such that all children are cast to real.
+ *
+ * Note this converter is necessary since our type rules for arithmetic
+ * operators are more permissive internally than in SMT-LIB, since e.g. ADD
+ * can mix Int and Real children.
  */
 class SubtypeElimNodeConverter : public NodeConverter
 {
index 49e0ed5256ef09bd5d5faed8d1e4c01a2eb6968c..1276c89937df59a8130f912831770141dc01fbdf 100644 (file)
@@ -287,7 +287,7 @@ void Smt2Printer::toStream(std::ostream& out,
         for (const Node& snvc : snvec)
         {
           out << " (seq.unit ";
-          toStreamCastToType(out, snvc, toDepth, elemType);
+          toStream(out, snvc, toDepth);
           out << ")";
         }
         out << ")";
@@ -295,7 +295,7 @@ void Smt2Printer::toStream(std::ostream& out,
       else
       {
         out << "(seq.unit ";
-        toStreamCastToType(out, snvec[0], toDepth, elemType);
+        toStream(out, snvec[0], toDepth);
         out << ")";
       }
       break;
@@ -306,10 +306,7 @@ void Smt2Printer::toStream(std::ostream& out,
       out << "((as const ";
       toStreamType(out, asa.getType());
       out << ") ";
-      toStreamCastToType(out,
-                         asa.getValue(),
-                         toDepth < 0 ? toDepth : toDepth - 1,
-                         asa.getType().getArrayConstituentType());
+      toStream(out, asa.getValue(), toDepth < 0 ? toDepth : toDepth - 1);
       out << ")";
       break;
     }
@@ -501,11 +498,6 @@ void Smt2Printer::toStream(std::ostream& out,
     force_nt = n.getOperator().getConst<AscriptionType>().getType();
     type_asc_arg = n[0];
   }
-  else if (k == kind::CAST_TO_REAL)
-  {
-    force_nt = nm->realType();
-    type_asc_arg = n[0];
-  }
   if (!type_asc_arg.isNull())
   {
     if (force_nt.isRealOrInt())
@@ -729,9 +721,7 @@ void Smt2Printer::toStream(std::ostream& out,
   case kind::SEQ_UNIT:
   {
     out << smtKindString(k, d_variant) << " ";
-    TypeNode elemType = n.getType().getSequenceElementType();
-    toStreamCastToType(
-        out, n[0], toDepth < 0 ? toDepth : toDepth - 1, elemType);
+    toStream(out, n[0], toDepth < 0 ? toDepth : toDepth - 1);
     out << ")";
     return;
   }
@@ -741,9 +731,7 @@ void Smt2Printer::toStream(std::ostream& out,
   case kind::SET_SINGLETON:
   {
     out << smtKindString(k, d_variant) << " ";
-    TypeNode elemType = n.getType().getSetElementType();
-    toStreamCastToType(
-        out, n[0], toDepth < 0 ? toDepth : toDepth - 1, elemType);
+    toStream(out, n[0], toDepth < 0 ? toDepth : toDepth - 1);
     out << ")";
     return;
   }
@@ -755,9 +743,7 @@ void Smt2Printer::toStream(std::ostream& out,
   {
     // print (bag (BAG_MAKE_OP Real) 1 3) as (bag 1.0 3)
     out << smtKindString(k, d_variant) << " ";
-    TypeNode elemType = n.getType().getBagElementType();
-    toStreamCastToType(
-        out, n[0], toDepth < 0 ? toDepth : toDepth - 1, elemType);
+    toStream(out, n[0], toDepth < 0 ? toDepth : toDepth - 1);
     out << " " << n[1] << ")";
     return;
   }
@@ -1042,25 +1028,6 @@ void Smt2Printer::toStream(std::ostream& out,
   }
 }
 
-void Smt2Printer::toStreamCastToType(std::ostream& out,
-                                     TNode n,
-                                     int toDepth,
-                                     TypeNode tn) const
-{
-  Node nasc;
-  if (n.getType().isInteger() && !tn.isInteger())
-  {
-    Assert(tn.isReal());
-    // probably due to subtyping integers and reals, cast it
-    nasc = NodeManager::currentNM()->mkNode(kind::CAST_TO_REAL, n);
-  }
-  else
-  {
-    nasc = n;
-  }
-  toStream(out, nasc, toDepth);
-}
-
 std::string Smt2Printer::smtKindString(Kind k, Variant v)
 {
   switch(k) {
@@ -1530,14 +1497,14 @@ void Smt2Printer::toStreamModelTerm(std::ostream& out,
     TypeNode rangeType = n.getType().getRangeType();
     out << "(define-fun " << n << " " << value[0] << " " << rangeType << " ";
     // call toStream and force its type to be proper
-    toStreamCastToType(out, value[1], -1, rangeType);
+    toStream(out, value[1], -1);
     out << ")" << endl;
   }
   else
   {
     out << "(define-fun " << n << " () " << n.getType() << " ";
     // call toStream and force its type to be proper
-    toStreamCastToType(out, value, -1, n.getType());
+    toStream(out, value, -1);
     out << ")" << endl;
   }
 }
index 57688255d0befd24b0e328be3eec444fb5dd0cea..9198c4628ddb9f97a4ee19b70d8dec48ffe513fc 100644 (file)
@@ -298,15 +298,7 @@ class Smt2Printer : public cvc5::internal::Printer
   void toStreamDeclareType(std::ostream& out, TypeNode tn) const;
   /** To stream type node, which ensures tn is printed in smt2 format */
   void toStreamType(std::ostream& out, TypeNode tn) const;
-  /**
-   * To stream, with a forced type. This method is used in some corner cases
-   * to force a node n to be printed as if it had type tn. This is used e.g.
-   * for the body of define-fun commands and arguments of singleton terms.
-   */
-  void toStreamCastToType(std::ostream& out,
-                          TNode n,
-                          int toDepth,
-                          TypeNode tn) const;
+  /** To stream datatype */
   void toStream(std::ostream& out, const DType& dt) const;
   /**
    * To stream model sort. This prints the appropriate output for type
index f780b6034368300bf99c4e3dee3fdb2460e1cc73..0f5c3b5b741780aa9d417742771ce049a373a572 100644 (file)
@@ -42,7 +42,7 @@ Node AletheNoSubtypeNodeConverter::postConvert(Node n)
           << "\t\t..arg " << i << " is integer constant " << n[i]
           << " in real position.\n";
       childChanged = true;
-      children.push_back(nm->mkNode(kind::CAST_TO_REAL, n[i]));
+      children.push_back(nm->mkNode(kind::TO_REAL, n[i]));
     }
     if (childChanged)
     {
index aef461adc814849f9d652a8b79a3488028985f34..ee644e2d4dcd92bbbda25fdae701636c498e4355 100644 (file)
@@ -214,18 +214,8 @@ Node LfscNodeConverter::postConvert(Node n)
     Node hconstf = getSymbolInternal(k, tnh, "apply");
     return nm->mkNode(APPLY_UF, hconstf, n[0], n[1]);
   }
-  else if (k == CONST_RATIONAL || k == CONST_INTEGER || k == CAST_TO_REAL)
+  else if (k == CONST_RATIONAL || k == CONST_INTEGER)
   {
-    if (k == CAST_TO_REAL)
-    {
-      // already converted
-      do
-      {
-        n = n[0];
-        Assert(n.getKind() == APPLY_UF || n.getKind() == CONST_RATIONAL
-               || n.getKind() == CONST_INTEGER);
-      } while (n.getKind() != CONST_RATIONAL && n.getKind() != CONST_INTEGER);
-    }
     TypeNode tnv = nm->mkFunctionType(tn, tn);
     Node rconstf;
     Node arg;
index 77507b1b0dc525a1a28fe2b201f2bca1176c050c..cfd4498c10d1ea132cecff67bd4199ff0cf1a804 100644 (file)
@@ -250,7 +250,6 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
       case kind::IS_INTEGER:
       case kind::TO_INTEGER:
       case kind::TO_REAL:
-      case kind::CAST_TO_REAL:
       case kind::POW:
       case kind::PI: return RewriteResponse(REWRITE_DONE, t);
       default: Unhandled() << k;
@@ -297,8 +296,7 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
       case kind::INTS_DIVISION_TOTAL:
       case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
       case kind::ABS: return rewriteAbs(t);
-      case kind::TO_REAL:
-      case kind::CAST_TO_REAL: return rewriteToReal(t);
+      case kind::TO_REAL: return rewriteToReal(t);
       case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
       case kind::POW:
       {
@@ -592,7 +590,7 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
 
 RewriteResponse ArithRewriter::rewriteToReal(TNode t)
 {
-  Assert(t.getKind() == kind::CAST_TO_REAL || t.getKind() == kind::TO_REAL);
+  Assert(t.getKind() == kind::TO_REAL);
   if (!t[0].getType().isInteger())
   {
     // if it is already real type, then just return the argument
@@ -602,16 +600,9 @@ RewriteResponse ArithRewriter::rewriteToReal(TNode t)
   if (t[0].isConst())
   {
     // If the argument is constant, return a real constant.
-    // !!!! Note that this does not preserve the type of t, since rat is
-    // an integral rational. This will be corrected when the type rule for
-    // CONST_RATIONAL is changed to always return Real.
     const Rational& rat = t[0].getConst<Rational>();
     return RewriteResponse(REWRITE_DONE, nm->mkConstReal(rat));
   }
-  // CAST_TO_REAL is our way of marking integral constants coming from the
-  // user as Real. It should only be applied to constants, which is handled
-  // above.
-  Assert(t.getKind() != kind::CAST_TO_REAL);
   return RewriteResponse(REWRITE_DONE, t);
 }
 
index 106b858984d15f1863bf75784464d59fabefb3d0..a9ef66b03148c93841696e45c8d4d86886d61ad1 100644 (file)
@@ -111,13 +111,6 @@ operator IS_INTEGER 1 "term-is-integer predicate (parameter is a real-sorted ter
 operator TO_INTEGER 1 "convert term to integer by the floor function (parameter is a real-sorted term)"
 operator TO_REAL 1 "cast term to real (parameter is an integer-sorted term; this is a no-op in cvc5, as integer is a subtype of real)"
 
-# CAST_TO_REAL is added to distinguish between integers casted to reals internally, and
-# integers casted to reals or using the API \
-# Solver::mkReal(int val) would return an internal node (CAST_TO_REAL val), but in the api it appears as term (val) \
-# Solver::mkTerm(TO_REAL, Solver::mkInteger(int val)) would return both term and node (TO_REAL val) \
-# This way, we avoid having 2 nested TO_REAL nodess as a result of Solver::mkTerm(TO_REAL, Solver::mkReal(int val))
-operator CAST_TO_REAL 1 "cast term to real same as TO_REAL, but it is used internally, whereas TO_REAL is accessible in the API"
-
 typerule ADD ::cvc5::internal::theory::arith::ArithOperatorTypeRule
 typerule MULT ::cvc5::internal::theory::arith::ArithOperatorTypeRule
 typerule NONLINEAR_MULT ::cvc5::internal::theory::arith::ArithOperatorTypeRule
@@ -141,7 +134,6 @@ typerule INDEXED_ROOT_PREDICATE_OP "SimpleTypeRule<RBuiltinOperator>"
 typerule INDEXED_ROOT_PREDICATE ::cvc5::internal::theory::arith::IndexedRootPredicateTypeRule
 
 typerule TO_REAL "SimpleTypeRule<RReal, ARealOrInteger>"
-typerule CAST_TO_REAL "SimpleTypeRule<RReal, ARealOrInteger>"
 typerule TO_INTEGER "SimpleTypeRule<RInteger, ARealOrInteger>"
 typerule IS_INTEGER "SimpleTypeRule<RBool, ARealOrInteger>"
 
index b60a9b8b83001554381b8a4863e403d8ed4e0a77..8b66d092fa707dc1e3afda1962ae7882f272c128 100644 (file)
@@ -92,8 +92,7 @@ TypeNode ArithOperatorTypeRule::computeType(NodeManager* nodeManager,
   }
   switch (k)
   {
-    case kind::TO_REAL:
-    case kind::CAST_TO_REAL: return realType;
+    case kind::TO_REAL: return realType;
     case kind::TO_INTEGER: return integerType;
     default:
     {
index bd82153cd40be12227136fce35dbe7b00ab8467c..7caca8427345d64805b97fbb630b93119e064931 100644 (file)
@@ -526,7 +526,6 @@ EvalResult Evaluator::evalInternal(
           break;
         }
         case kind::TO_REAL:
-        case kind::CAST_TO_REAL:
         {
           // casting to real is a no-op
           const Rational& x = results[currNode[0]].d_rat;