Refactor rewriting of arithmetic negation and subtraction (#8170)
[cvc5.git] / src / theory / arith / arith_rewriter.cpp
index 5aa904aed204579bcf1d725f1e6c55ea91e227c3..aae34abb62fcf18f95e6862a9fe697999814e638 100644 (file)
-/*********************                                                        */
-/*! \file arith_rewriter.cpp
- ** \verbatim
- ** Original author: Tim King
- ** Major contributors: Morgan Deters
- ** Minor contributors (to current version): Dejan Jovanovic
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2013  New York University and The University of Iowa
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Andrew Reynolds, Tim King, Morgan Deters
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * [[ Add one-line brief description here ]]
+ *
+ * [[ Add lengthier description here ]]
+ * \todo document this file
+ */
 
-#include "theory/theory.h"
-#include "theory/arith/normal_form.h"
 #include "theory/arith/arith_rewriter.h"
-#include "theory/arith/arith_utilities.h"
 
-#include <vector>
+#include <optional>
 #include <set>
+#include <sstream>
 #include <stack>
+#include <vector>
+
+#include "expr/algorithm/flatten.h"
+#include "smt/logic_exception.h"
+#include "theory/arith/arith_msum.h"
+#include "theory/arith/arith_utilities.h"
+#include "theory/arith/normal_form.h"
+#include "theory/arith/operator_elim.h"
+#include "theory/arith/rewriter/addition.h"
+#include "theory/arith/rewriter/node_utils.h"
+#include "theory/arith/rewriter/ordering.h"
+#include "theory/theory.h"
+#include "util/bitvector.h"
+#include "util/divisible.h"
+#include "util/iand.h"
+#include "util/real_algebraic_number.h"
+
+using namespace cvc5::kind;
 
-namespace CVC4 {
+namespace cvc5 {
 namespace theory {
 namespace arith {
 
+namespace {
+
+template <typename L, typename R>
+bool evaluateRelation(Kind rel, const L& l, const R& r)
+{
+  switch (rel)
+  {
+    case Kind::LT: return l < r;
+    case Kind::LEQ: return l <= r;
+    case Kind::EQUAL: return l == r;
+    case Kind::GEQ: return l >= r;
+    case Kind::GT: return l > r;
+    default: Unreachable(); return false;
+  }
+}
+
+/**
+ * Check whether the parent has a child that is a constant zero.
+ * If so, return this child. Otherwise, return std::nullopt.
+ */
+template <typename Iterable>
+std::optional<TNode> getZeroChild(const Iterable& parent)
+{
+  for (const auto& node : parent)
+  {
+    if (node.isConst() && node.template getConst<Rational>().isZero())
+    {
+      return node;
+    }
+  }
+  return std::nullopt;
+}
+
+}  // namespace
+
+ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
+
+RewriteResponse ArithRewriter::preRewrite(TNode t)
+{
+  Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
+  if (isAtom(t))
+  {
+    auto res = preRewriteAtom(t);
+    Trace("arith-rewriter")
+        << res.d_status << " -> " << res.d_node << std::endl;
+    return res;
+  }
+  auto res = preRewriteTerm(t);
+  Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
+  return res;
+}
+
+RewriteResponse ArithRewriter::postRewrite(TNode t)
+{
+  Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
+  if (isAtom(t))
+  {
+    auto res = postRewriteAtom(t);
+    Trace("arith-rewriter")
+        << res.d_status << " -> " << res.d_node << std::endl;
+    return res;
+  }
+  auto res = postRewriteTerm(t);
+  Trace("arith-rewriter") << res.d_status << " -> " << res.d_node << std::endl;
+  return res;
+}
+
+RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
+{
+  Assert(isAtom(atom));
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
+  {
+    switch (atom.getKind())
+    {
+      case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
+      case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
+      default:;
+    }
+  }
+
+  switch (atom.getKind())
+  {
+    case Kind::GT:
+      return RewriteResponse(REWRITE_DONE,
+                             nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
+    case Kind::LT:
+      return RewriteResponse(REWRITE_DONE,
+                             nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
+    case Kind::IS_INTEGER:
+      if (atom[0].getType().isInteger())
+      {
+        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      }
+      break;
+    case Kind::DIVISIBLE:
+      if (atom.getOperator().getConst<Divisible>().k.isOne())
+      {
+        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      }
+      break;
+    default:;
+  }
+
+  return RewriteResponse(REWRITE_DONE, atom);
+}
+
+RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
+{
+  Assert(isAtom(atom));
+  if (atom.getKind() == kind::IS_INTEGER)
+  {
+    return rewriteExtIntegerOp(atom);
+  }
+  else if (atom.getKind() == kind::DIVISIBLE)
+  {
+    if (atom[0].isConst())
+    {
+      return RewriteResponse(REWRITE_DONE,
+                             NodeManager::currentNM()->mkConst(bool(
+                                 (atom[0].getConst<Rational>()
+                                  / atom.getOperator().getConst<Divisible>().k)
+                                     .isIntegral())));
+    }
+    if (atom.getOperator().getConst<Divisible>().k.isOne())
+    {
+      return RewriteResponse(REWRITE_DONE,
+                             NodeManager::currentNM()->mkConst(true));
+    }
+    NodeManager* nm = NodeManager::currentNM();
+    return RewriteResponse(
+        REWRITE_AGAIN,
+        nm->mkNode(kind::EQUAL,
+                   nm->mkNode(kind::INTS_MODULUS_TOTAL,
+                              atom[0],
+                              nm->mkConstInt(Rational(
+                                  atom.getOperator().getConst<Divisible>().k))),
+                   nm->mkConstInt(Rational(0))));
+  }
+
+  // left |><| right
+  TNode left = atom[0];
+  TNode right = atom[1];
+
+  auto* nm = NodeManager::currentNM();
+  if (left.isConst())
+  {
+    const Rational& l = left.getConst<Rational>();
+    if (right.isConst())
+    {
+      const Rational& r = right.getConst<Rational>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+    else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& r =
+          right.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+  }
+  else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  {
+    const RealAlgebraicNumber& l =
+        left.getOperator().getConst<RealAlgebraicNumber>();
+    if (right.isConst())
+    {
+      const Rational& r = right.getConst<Rational>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+    else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& r =
+          right.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+  }
+
+  Polynomial pleft = Polynomial::parsePolynomial(left);
+  Polynomial pright = Polynomial::parsePolynomial(right);
+
+  Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
+  Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
+
+  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
+  Assert(cmp.isNormalForm());
+  return RewriteResponse(REWRITE_DONE, cmp.getNode());
+}
+
 bool ArithRewriter::isAtom(TNode n) {
   Kind k = n.getKind();
-  return arith::isRelationOperator(k) || k == kind::IS_INTEGER || k == kind::DIVISIBLE;
+  return arith::isRelationOperator(k) || k == kind::IS_INTEGER
+      || k == kind::DIVISIBLE;
 }
 
 RewriteResponse ArithRewriter::rewriteConstant(TNode t){
   Assert(t.isConst());
-  Assert(t.getKind() == kind::CONST_RATIONAL);
+  Assert(t.getKind() == CONST_RATIONAL || t.getKind() == CONST_INTEGER);
 
   return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteVariable(TNode t){
-  Assert(t.isVar());
+RewriteResponse ArithRewriter::rewriteRAN(TNode t)
+{
+  Assert(t.getKind() == REAL_ALGEBRAIC_NUMBER);
+
+  const RealAlgebraicNumber& r =
+      t.getOperator().getConst<RealAlgebraicNumber>();
+  if (r.isRational())
+  {
+    return RewriteResponse(
+        REWRITE_DONE, NodeManager::currentNM()->mkConstReal(r.toRational()));
+  }
 
   return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteMinus(TNode t, bool pre){
-  Assert(t.getKind()== kind::MINUS);
+RewriteResponse ArithRewriter::rewriteVariable(TNode t)
+{
+  Assert(t.isVar());
 
-  if(pre){
-    if(t[0] == t[1]){
-      Rational zero(0);
-      Node zeroNode  = mkRationalNode(zero);
-      return RewriteResponse(REWRITE_DONE, zeroNode);
-    }else{
-      Node noMinus = makeSubtractionNode(t[0],t[1]);
-      return RewriteResponse(REWRITE_DONE, noMinus);
-    }
-  }else{
-    Polynomial minuend = Polynomial::parsePolynomial(t[0]);
-    Polynomial subtrahend = Polynomial::parsePolynomial(t[1]);
-    Polynomial diff = minuend - subtrahend;
-    return RewriteResponse(REWRITE_DONE, diff.getNode());
-  }
+  return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteUMinus(TNode t, bool pre){
-  Assert(t.getKind()== kind::UMINUS);
+RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)
+{
+  Assert(t.getKind() == kind::NEG);
 
-  if(t[0].getKind() == kind::CONST_RATIONAL){
+  if (t[0].isConst())
+  {
     Rational neg = -(t[0].getConst<Rational>());
-    return RewriteResponse(REWRITE_DONE, mkRationalNode(neg));
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg));
+  }
+  if (rewriter::isRAN(t[0]))
+  {
+    return RewriteResponse(REWRITE_DONE,
+                           rewriter::mkConst(-rewriter::getRAN(t[0])));
   }
 
-  Node noUminus = makeUnaryMinusNode(t[0]);
-  if(pre)
+  auto* nm = NodeManager::currentNM();
+  Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]);
+  if (pre)
     return RewriteResponse(REWRITE_DONE, noUminus);
   else
     return RewriteResponse(REWRITE_AGAIN, noUminus);
 }
 
+RewriteResponse ArithRewriter::rewriteSub(TNode t)
+{
+  Assert(t.getKind() == kind::SUB);
+  Assert(t.getNumChildren() == 2);
+
+  if (t[0] == t[1])
+  {
+    return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+  }
+  auto* nm = NodeManager::currentNM();
+  return RewriteResponse(
+      REWRITE_AGAIN_FULL,
+      nm->mkNode(Kind::ADD,
+                 t[0],
+                 nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1])));
+}
+
 RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
   if(t.isConst()){
     return rewriteConstant(t);
@@ -88,43 +324,42 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){
     return rewriteVariable(t);
   }else{
     switch(Kind k = t.getKind()){
-    case kind::MINUS:
-      return rewriteMinus(t, true);
-    case kind::UMINUS:
-      return rewriteUMinus(t, true);
-    case kind::DIVISION:
-    case kind::DIVISION_TOTAL:
-      return rewriteDiv(t,true);
-    case kind::PLUS:
-      return preRewritePlus(t);
-    case kind::MULT:
-      return preRewriteMult(t);
-    case kind::INTS_DIVISION:
-    case kind::INTS_MODULUS:
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::INTS_DIVISION_TOTAL:
-    case kind::INTS_MODULUS_TOTAL:
-      return rewriteIntsDivModTotal(t,true);
-    case kind::ABS:
-      if(t[0].isConst()) {
-        const Rational& rat = t[0].getConst<Rational>();
-        if(rat >= 0) {
-          return RewriteResponse(REWRITE_DONE, t[0]);
-        } else {
-          return RewriteResponse(REWRITE_DONE,
-                                 NodeManager::currentNM()->mkConst(-rat));
-        }
-      }
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::IS_INTEGER:
-    case kind::TO_INTEGER:
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::TO_REAL:
-      return RewriteResponse(REWRITE_DONE, t[0]);
-    case kind::POW:
-      return RewriteResponse(REWRITE_DONE, t);
-    default:
-      Unhandled(k);
+      case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
+      case kind::SUB: return rewriteSub(t);
+      case kind::NEG: return rewriteNeg(t, true);
+      case kind::DIVISION:
+      case kind::DIVISION_TOTAL: return rewriteDiv(t, true);
+      case kind::ADD: return preRewritePlus(t);
+      case kind::MULT:
+      case kind::NONLINEAR_MULT: return preRewriteMult(t);
+      case kind::IAND: return RewriteResponse(REWRITE_DONE, t);
+      case kind::POW2: return RewriteResponse(REWRITE_DONE, t);
+      case kind::EXPONENTIAL:
+      case kind::SINE:
+      case kind::COSINE:
+      case kind::TANGENT:
+      case kind::COSECANT:
+      case kind::SECANT:
+      case kind::COTANGENT:
+      case kind::ARCSINE:
+      case kind::ARCCOSINE:
+      case kind::ARCTANGENT:
+      case kind::ARCCOSECANT:
+      case kind::ARCSECANT:
+      case kind::ARCCOTANGENT:
+      case kind::SQRT: return preRewriteTranscendental(t);
+      case kind::INTS_DIVISION:
+      case kind::INTS_MODULUS: return rewriteIntsDivMod(t, true);
+      case kind::INTS_DIVISION_TOTAL:
+      case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
+      case kind::ABS: return rewriteAbs(t);
+      case kind::IS_INTEGER:
+      case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
+      case kind::TO_REAL:
+      case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
+      case kind::POW: return RewriteResponse(REWRITE_DONE, t);
+      case kind::PI: return RewriteResponse(REWRITE_DONE, t);
+      default: Unhandled() << k;
     }
   }
 }
@@ -135,83 +370,88 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
   }else if(t.isVar()){
     return rewriteVariable(t);
   }else{
+    Trace("arith-rewriter") << "postRewriteTerm: " << t << std::endl;
     switch(t.getKind()){
-    case kind::MINUS:
-      return rewriteMinus(t, false);
-    case kind::UMINUS:
-      return rewriteUMinus(t, false);
-    case kind::DIVISION:
-    case kind::DIVISION_TOTAL:
-      return rewriteDiv(t, false);
-    case kind::PLUS:
-      return postRewritePlus(t);
-    case kind::MULT:
-      return postRewriteMult(t);
-    case kind::INTS_DIVISION:
-    case kind::INTS_MODULUS:
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::INTS_DIVISION_TOTAL:
-    case kind::INTS_MODULUS_TOTAL:
-      return rewriteIntsDivModTotal(t, false);
-    case kind::ABS:
-      if(t[0].isConst()) {
-        const Rational& rat = t[0].getConst<Rational>();
-        if(rat >= 0) {
-          return RewriteResponse(REWRITE_DONE, t[0]);
-        } else {
-          return RewriteResponse(REWRITE_DONE,
-                                 NodeManager::currentNM()->mkConst(-rat));
-        }
-      }
-    case kind::TO_REAL:
-      return RewriteResponse(REWRITE_DONE, t[0]);
-    case kind::TO_INTEGER:
-      if(t[0].isConst()) {
-        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(t[0].getConst<Rational>().floor())));
-      }
-      if(t[0].getType().isInteger()) {
-        return RewriteResponse(REWRITE_DONE, t[0]);
-      }
-      //Unimplemented("TO_INTEGER, nonconstant");
-      //return rewriteToInteger(t);
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::IS_INTEGER:
-      if(t[0].isConst()) {
-        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(t[0].getConst<Rational>().getDenominator() == 1));
-      }
-      if(t[0].getType().isInteger()) {
-        return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
-      }
-      //Unimplemented("IS_INTEGER, nonconstant");
-      //return rewriteIsInteger(t);
-      return RewriteResponse(REWRITE_DONE, t);
-    case kind::POW:
+      case kind::REAL_ALGEBRAIC_NUMBER: return rewriteRAN(t);
+      case kind::SUB: return rewriteSub(t);
+      case kind::NEG: return rewriteNeg(t, false);
+      case kind::DIVISION:
+      case kind::DIVISION_TOTAL: return rewriteDiv(t, false);
+      case kind::ADD: return postRewritePlus(t);
+      case kind::MULT:
+      case kind::NONLINEAR_MULT: return postRewriteMult(t);
+      case kind::IAND: return postRewriteIAnd(t);
+      case kind::POW2: return postRewritePow2(t);
+      case kind::EXPONENTIAL:
+      case kind::SINE:
+      case kind::COSINE:
+      case kind::TANGENT:
+      case kind::COSECANT:
+      case kind::SECANT:
+      case kind::COTANGENT:
+      case kind::ARCSINE:
+      case kind::ARCCOSINE:
+      case kind::ARCTANGENT:
+      case kind::ARCCOSECANT:
+      case kind::ARCSECANT:
+      case kind::ARCCOTANGENT:
+      case kind::SQRT: return postRewriteTranscendental(t);
+      case kind::INTS_DIVISION:
+      case kind::INTS_MODULUS: return rewriteIntsDivMod(t, false);
+      case kind::INTS_DIVISION_TOTAL:
+      case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
+      case kind::ABS: return rewriteAbs(t);
+      case kind::TO_REAL:
+      case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
+      case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
+      case kind::POW:
       {
-        if(t[1].getKind() == kind::CONST_RATIONAL){
+        if (t[1].isConst())
+        {
           const Rational& exp = t[1].getConst<Rational>();
           TNode base = t[0];
           if(exp.sgn() == 0){
-            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(1)));
+            return RewriteResponse(REWRITE_DONE,
+                                   NodeManager::currentNM()->mkConstRealOrInt(
+                                       t.getType(), Rational(1)));
           }else if(exp.sgn() > 0 && exp.isIntegral()){
-            Integer num = exp.getNumerator();
-            NodeBuilder<> nb(kind::MULT);
-            Integer one(1);
-            for(Integer i(0); i < num; i = i + one){
-              nb << base;
+            cvc5::Rational r(expr::NodeValue::MAX_CHILDREN);
+            if (exp <= r)
+            {
+              unsigned num = exp.getNumerator().toUnsignedInt();
+              if( num==1 ){
+                return RewriteResponse(REWRITE_AGAIN, base);
+              }else{
+                NodeBuilder nb(kind::MULT);
+                for(unsigned i=0; i < num; ++i){
+                  nb << base;
+                }
+                Assert(nb.getNumChildren() > 0);
+                Node mult = nb;
+                return RewriteResponse(REWRITE_AGAIN, mult);
+              }
             }
-            Assert(nb.getNumChildren() > 0);
-            Node mult = nb;
-            return RewriteResponse(REWRITE_AGAIN, mult);
           }
         }
+        else if (t[0].isConst()
+                 && t[0].getConst<Rational>().getNumerator().toUnsignedInt()
+                        == 2)
+        {
+          return RewriteResponse(
+              REWRITE_DONE, NodeManager::currentNM()->mkNode(kind::POW2, t[1]));
+        }
 
         // Todo improve the exception thrown
         std::stringstream ss;
-        ss << "The POW(^) operator can only be used with a natural number ";
-        ss << "in the exponent.  Exception occured in:" << std::endl;
+        ss << "The exponent of the POW(^) operator can only be a positive "
+              "integral constant below "
+           << (expr::NodeValue::MAX_CHILDREN + 1) << ". ";
+        ss << "Exception occurred in:" << std::endl;
         ss << "  " << t;
-        throw Exception(ss.str());
+        throw LogicException(ss.str());
       }
+    case kind::PI:
+      return RewriteResponse(REWRITE_DONE, t);
     default:
       Unreachable();
     }
@@ -219,290 +459,361 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){
 }
 
 
-RewriteResponse ArithRewriter::preRewriteMult(TNode t){
-  Assert(t.getKind()== kind::MULT);
-
-  if(t.getNumChildren() == 2){
-    if(t[0].getKind() == kind::CONST_RATIONAL
-       && t[0].getConst<Rational>().isOne()){
-      return RewriteResponse(REWRITE_DONE, t[1]);
-    }
-    if(t[1].getKind() == kind::CONST_RATIONAL
-       && t[1].getConst<Rational>().isOne()){
-      return RewriteResponse(REWRITE_DONE, t[0]);
-    }
-  }
-
-  // Rewrite multiplications with a 0 argument and to 0
-  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
-    if((*i).getKind() == kind::CONST_RATIONAL) {
-      if((*i).getConst<Rational>().isZero()) {
-        TNode zero = (*i);
-        return RewriteResponse(REWRITE_DONE, zero);
-      }
-    }
-  }
-  return RewriteResponse(REWRITE_DONE, t);
+RewriteResponse ArithRewriter::preRewritePlus(TNode t){
+  Assert(t.getKind() == kind::ADD);
+  return RewriteResponse(REWRITE_DONE, expr::algorithm::flatten(t));
 }
 
-static bool canFlatten(Kind k, TNode t){
-  for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
-    TNode child = *i;
-    if(child.getKind() == k){
-      return true;
-    }
-  }
-  return false;
-}
+RewriteResponse ArithRewriter::postRewritePlus(TNode t){
+  Assert(t.getKind() == kind::ADD);
+  Assert(t.getNumChildren() > 1);
 
-static void flatten(std::vector<TNode>& pb, Kind k, TNode t){
-  if(t.getKind() == k){
-    for(TNode::iterator i = t.begin(); i != t.end(); ++i) {
-      TNode child = *i;
-      if(child.getKind() == k){
-        flatten(pb, k, child);
-      }else{
-        pb.push_back(child);
-      }
+  {
+    Node flat = expr::algorithm::flatten(t);
+    if (flat != t)
+    {
+      return RewriteResponse(REWRITE_AGAIN, flat);
     }
-  }else{
-    pb.push_back(t);
   }
-}
-
-static Node flatten(Kind k, TNode t){
-  std::vector<TNode> pb;
-  flatten(pb, k, t);
-  Assert(pb.size() >= 2);
-  return NodeManager::currentNM()->mkNode(k, pb);
-}
-
-RewriteResponse ArithRewriter::preRewritePlus(TNode t){
-  Assert(t.getKind()== kind::PLUS);
-
-  if(canFlatten(kind::PLUS, t)){
-    return RewriteResponse(REWRITE_DONE, flatten(kind::PLUS, t));
-  }else{
-    return RewriteResponse(REWRITE_DONE, t);
-  }
-}
-
-RewriteResponse ArithRewriter::postRewritePlus(TNode t){
-  Assert(t.getKind()== kind::PLUS);
 
+  Rational rational;
+  RealAlgebraicNumber ran;
   std::vector<Monomial> monomials;
   std::vector<Polynomial> polynomials;
 
-  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
-    TNode curr = *i;
-    if(Monomial::isMember(curr)){
-      monomials.push_back(Monomial::parseMonomial(curr));
-    }else{
-      polynomials.push_back(Polynomial::parsePolynomial(curr));
+  for (const auto& child : t)
+  {
+    if (child.isConst())
+    {
+      if (child.getConst<Rational>().isZero())
+      {
+        continue;
+      }
+      rational += child.getConst<Rational>();
+    }
+    else if (child.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      ran += child.getOperator().getConst<RealAlgebraicNumber>();
+    }
+    else if (Monomial::isMember(child))
+    {
+      monomials.emplace_back(Monomial::parseMonomial(child));
+    }
+    else
+    {
+      polynomials.emplace_back(Polynomial::parsePolynomial(child));
     }
   }
 
   if(!monomials.empty()){
     Monomial::sort(monomials);
     Monomial::combineAdjacentMonomials(monomials);
-    polynomials.push_back(Polynomial::mkPolynomial(monomials));
+    polynomials.emplace_back(Polynomial::mkPolynomial(monomials));
   }
-
-  Polynomial res = Polynomial::sumPolynomials(polynomials);
-
-  return RewriteResponse(REWRITE_DONE, res.getNode());
-}
-
-RewriteResponse ArithRewriter::postRewriteMult(TNode t){
-  Assert(t.getKind()== kind::MULT);
-
-  Polynomial res = Polynomial::mkOne();
-
-  for(TNode::iterator i = t.begin(), end = t.end(); i != end; ++i){
-    Node curr = *i;
-    Polynomial currPoly = Polynomial::parsePolynomial(curr);
-
-    res = res * currPoly;
+  if (!rational.isZero())
+  {
+    polynomials.emplace_back(
+        Polynomial::mkPolynomial(Constant::mkConstant(rational)));
   }
 
-  return RewriteResponse(REWRITE_DONE, res.getNode());
-}
+  Polynomial poly = Polynomial::sumPolynomials(polynomials);
 
-RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
-  if(atom.getKind() == kind::IS_INTEGER) {
-    if(atom[0].isConst()) {
-      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(atom[0].getConst<Rational>().isIntegral()));
-    }
-    if(atom[0].getType().isInteger()) {
-      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
-    }
-    // not supported, but this isn't the right place to complain
-    return RewriteResponse(REWRITE_DONE, atom);
-  } else if(atom.getKind() == kind::DIVISIBLE) {
-    if(atom[0].isConst()) {
-      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
-    }
-    if(atom.getOperator().getConst<Divisible>().k.isOne()) {
-      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
+  if (isZero(ran))
+  {
+    return RewriteResponse(REWRITE_DONE, poly.getNode());
+  }
+  if (poly.containsConstant())
+  {
+    ran += RealAlgebraicNumber(poly.getHead().getConstant().getValue());
+    if (!poly.isConstant())
+    {
+      poly = poly.getTail();
     }
-    return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::EQUAL, NodeManager::currentNM()->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], NodeManager::currentNM()->mkConst(Rational(atom.getOperator().getConst<Divisible>().k))), NodeManager::currentNM()->mkConst(Rational(0))));
   }
 
-  // left |><| right
-  TNode left = atom[0];
-  TNode right = atom[1];
+  auto* nm = NodeManager::currentNM();
+  if (poly.isConstant())
+  {
+    return RewriteResponse(REWRITE_DONE, nm->mkRealAlgebraicNumber(ran));
+  }
+  return RewriteResponse(
+      REWRITE_DONE,
+      nm->mkNode(Kind::ADD, nm->mkRealAlgebraicNumber(ran), poly.getNode()));
+}
 
-  Polynomial pleft = Polynomial::parsePolynomial(left);
-  Polynomial pright = Polynomial::parsePolynomial(right);
+RewriteResponse ArithRewriter::preRewriteMult(TNode node)
+{
+  Assert(node.getKind() == kind::MULT
+         || node.getKind() == kind::NONLINEAR_MULT);
 
-  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
-  Assert(cmp.isNormalForm());
-  return RewriteResponse(REWRITE_DONE, cmp.getNode());
+  if (auto res = rewriter::getZeroChild(node); res)
+  {
+    return RewriteResponse(REWRITE_DONE, *res);
+  }
+  return RewriteResponse(REWRITE_DONE, node);
 }
 
-RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
-  Assert(isAtom(atom));
+RewriteResponse ArithRewriter::postRewriteMult(TNode t){
+  Assert(t.getKind() == kind::MULT || t.getKind() == kind::NONLINEAR_MULT);
+  Assert(t.getNumChildren() >= 2);
 
-  NodeManager* currNM = NodeManager::currentNM();
+  std::vector<TNode> children;
+  expr::algorithm::flatten(t, children, Kind::MULT, Kind::NONLINEAR_MULT);
 
-  if(atom.getKind() == kind::EQUAL) {
-    if(atom[0] == atom[1]) {
-      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
-    }
-  }else if(atom.getKind() == kind::GT){
-    Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
-    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
-  }else if(atom.getKind() == kind::LT){
-    Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
-    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
-  }else if(atom.getKind() == kind::IS_INTEGER){
-    if(atom[0].getType().isInteger()){
-      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
-    }
-  }else if(atom.getKind() == kind::DIVISIBLE){
-    if(atom.getOperator().getConst<Divisible>().k.isOne()){
-      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
-    }
+  if (auto res = rewriter::getZeroChild(children); res)
+  {
+    return RewriteResponse(REWRITE_DONE, *res);
   }
 
-  return RewriteResponse(REWRITE_DONE, atom);
-}
+  // Distribute over addition
+  if (std::any_of(children.begin(), children.end(), [](TNode child) {
+        return child.getKind() == Kind::ADD;
+      }))
+  {
+    return RewriteResponse(REWRITE_DONE,
+                           rewriter::distributeMultiplication(children));
+  }
 
-RewriteResponse ArithRewriter::postRewrite(TNode t){
-  if(isTerm(t)){
-    RewriteResponse response = postRewriteTerm(t);
-    if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
-      Polynomial::parsePolynomial(response.node);
+  RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
+  std::vector<Node> leafs;
+
+  for (const auto& child : children)
+  {
+    if (child.isConst())
+    {
+      if (child.getConst<Rational>().isZero())
+      {
+        return RewriteResponse(REWRITE_DONE, child);
+      }
+      ran *= child.getConst<Rational>();
     }
-    return response;
-  }else if(isAtom(t)){
-    RewriteResponse response = postRewriteAtom(t);
-    if(Debug.isOn("arith::rewriter") && response.status == REWRITE_DONE) {
-      Comparison::parseNormalForm(response.node);
+    else if (rewriter::isRAN(child))
+    {
+      ran *= rewriter::getRAN(child);
+    }
+    else
+    {
+      leafs.emplace_back(child);
     }
-    return response;
-  }else{
-    Unreachable();
-    return RewriteResponse(REWRITE_DONE, Node::null());
-  }
-}
-
-RewriteResponse ArithRewriter::preRewrite(TNode t){
-  if(isTerm(t)){
-    return preRewriteTerm(t);
-  }else if(isAtom(t)){
-    return preRewriteAtom(t);
-  }else{
-    Unreachable();
-    return RewriteResponse(REWRITE_DONE, Node::null());
   }
-}
-
-Node ArithRewriter::makeUnaryMinusNode(TNode n){
-  Rational qNegOne(-1);
-  return NodeManager::currentNM()->mkNode(kind::MULT, mkRationalNode(qNegOne),n);
-}
 
-Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){
-  Node negR = makeUnaryMinusNode(r);
-  Node diff = NodeManager::currentNM()->mkNode(kind::PLUS, l, negR);
-
-  return diff;
+  return RewriteResponse(REWRITE_DONE,
+                         rewriter::mkMultTerm(ran, std::move(leafs)));
 }
 
-RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
-  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind()== kind::DIVISION);
-
+RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
+{
+  Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
+  Assert(t.getNumChildren() == 2);
 
   Node left = t[0];
   Node right = t[1];
-  if(right.getKind() == kind::CONST_RATIONAL){
+  if (right.isConst())
+  {
+    NodeManager* nm = NodeManager::currentNM();
     const Rational& den = right.getConst<Rational>();
 
-    if(den.isZero()){
-      if(t.getKind() == kind::DIVISION_TOTAL){
-        return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-      }else{
+    if (den.isZero())
+    {
+      if (t.getKind() == kind::DIVISION_TOTAL)
+      {
+        return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
+      }
+      else
+      {
         // This is unsupported, but this is not a good place to complain
         return RewriteResponse(REWRITE_DONE, t);
       }
     }
     Assert(den != Rational(0));
 
-    if(left.getKind() == kind::CONST_RATIONAL){
+    if (left.isConst())
+    {
       const Rational& num = left.getConst<Rational>();
-      Rational div = num / den;
-      Node result =  mkRationalNode(div);
-      return RewriteResponse(REWRITE_DONE, result);
+      return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
+    }
+    if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& num =
+          left.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(
+          REWRITE_DONE,
+          nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
     }
 
-    Rational div = den.inverse();
-
-    Node result = mkRationalNode(div);
+    Node result = nm->mkConstReal(den.inverse());
+    Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+    if (pre)
+    {
+      return RewriteResponse(REWRITE_DONE, mult);
+    }
+    else
+    {
+      return RewriteResponse(REWRITE_AGAIN, mult);
+    }
+  }
+  if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    const RealAlgebraicNumber& den =
+        right.getOperator().getConst<RealAlgebraicNumber>();
+    if (left.isConst())
+    {
+      const Rational& num = left.getConst<Rational>();
+      return RewriteResponse(
+          REWRITE_DONE,
+          nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
+    }
+    if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& num =
+          left.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(REWRITE_DONE,
+                             nm->mkRealAlgebraicNumber(num / den));
+    }
 
-    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
-    if(pre){
+    Node result = nm->mkRealAlgebraicNumber(inverse(den));
+    Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+    if (pre)
+    {
       return RewriteResponse(REWRITE_DONE, mult);
-    }else{
+    }
+    else
+    {
       return RewriteResponse(REWRITE_AGAIN, mult);
     }
-  }else{
-    return RewriteResponse(REWRITE_DONE, t);
   }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::rewriteAbs(TNode t)
+{
+  Assert(t.getKind() == Kind::ABS);
+  Assert(t.getNumChildren() == 1);
+
+  if (t[0].isConst())
+  {
+    const Rational& rat = t[0].getConst<Rational>();
+    if (rat >= 0)
+    {
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    }
+    return RewriteResponse(
+        REWRITE_DONE,
+        NodeManager::currentNM()->mkConstRealOrInt(t[0].getType(), -rat));
+  }
+  if (t[0].getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  {
+    const RealAlgebraicNumber& ran =
+        t[0].getOperator().getConst<RealAlgebraicNumber>();
+    if (ran >= RealAlgebraicNumber())
+    {
+      return RewriteResponse(REWRITE_DONE, t[0]);
+    }
+    return RewriteResponse(
+        REWRITE_DONE, NodeManager::currentNM()->mkRealAlgebraicNumber(-ran));
+  }
+  return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
+RewriteResponse ArithRewriter::rewriteIntsDivMod(TNode t, bool pre)
+{
+  NodeManager* nm = NodeManager::currentNM();
   Kind k = t.getKind();
-  // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL ||
-  //        k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
+  if (k == kind::INTS_MODULUS)
+  {
+    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+    {
+      // can immediately replace by INTS_MODULUS_TOTAL
+      Node ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, t[0], t[1]);
+      return returnRewrite(t, ret, Rewrite::MOD_TOTAL_BY_CONST);
+    }
+  }
+  if (k == kind::INTS_DIVISION)
+  {
+    if (t[1].isConst() && !t[1].getConst<Rational>().isZero())
+    {
+      // can immediately replace by INTS_DIVISION_TOTAL
+      Node ret = nm->mkNode(kind::INTS_DIVISION_TOTAL, t[0], t[1]);
+      return returnRewrite(t, ret, Rewrite::DIV_TOTAL_BY_CONST);
+    }
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::rewriteExtIntegerOp(TNode t)
+{
+  Assert(t.getKind() == kind::TO_INTEGER || t.getKind() == kind::IS_INTEGER);
+  bool isPred = t.getKind() == kind::IS_INTEGER;
+  NodeManager* nm = NodeManager::currentNM();
+  if (t[0].isConst())
+  {
+    Node ret;
+    if (isPred)
+    {
+      ret = nm->mkConst(t[0].getConst<Rational>().isIntegral());
+    }
+    else
+    {
+      ret = nm->mkConstInt(Rational(t[0].getConst<Rational>().floor()));
+    }
+    return returnRewrite(t, ret, Rewrite::INT_EXT_CONST);
+  }
+  if (t[0].getType().isInteger())
+  {
+    Node ret = isPred ? nm->mkConst(true) : Node(t[0]);
+    return returnRewrite(t, ret, Rewrite::INT_EXT_INT);
+  }
+  if (t[0].getKind() == kind::PI)
+  {
+    Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
+    return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
 
-  //Leaving the function as before (INTS_MODULUS can be handled),
-  // but restricting its use here
+RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
+{
+  if (pre)
+  {
+    // do not rewrite at prewrite.
+    return RewriteResponse(REWRITE_DONE, t);
+  }
+  NodeManager* nm = NodeManager::currentNM();
+  Kind k = t.getKind();
   Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL);
-  TNode n = t[0], d = t[1];
-  bool dIsConstant = d.getKind() == kind::CONST_RATIONAL;
-  if(dIsConstant && d.getConst<Rational>().isZero()){
-    if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){
-      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-    }else{
-      // Do nothing for k == INTS_MODULUS
-      return RewriteResponse(REWRITE_DONE, t);
+  TNode n = t[0];
+  TNode d = t[1];
+  bool dIsConstant = d.isConst();
+  if (dIsConstant && d.getConst<Rational>().isZero())
+  {
+    // (div x 0) ---> 0 or (mod x 0) ---> 0
+    return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
+  }
+  else if (dIsConstant && d.getConst<Rational>().isOne())
+  {
+    if (k == kind::INTS_MODULUS_TOTAL)
+    {
+      // (mod x 1) --> 0
+      return returnRewrite(t, nm->mkConstInt(0), Rewrite::MOD_BY_ONE);
     }
-  }else if(dIsConstant && d.getConst<Rational>().isOne()){
-    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
-      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-    }else{
-      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
-      return RewriteResponse(REWRITE_AGAIN, n);
-    }
-  }else if(dIsConstant && d.getConst<Rational>().isNegativeOne()){
-    if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){
-      return RewriteResponse(REWRITE_DONE, mkRationalNode(0));
-    }else{
-      Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
-      return RewriteResponse(REWRITE_AGAIN, NodeManager::currentNM()->mkNode(kind::UMINUS, n));
-    }
-  }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){
+    Assert(k == kind::INTS_DIVISION_TOTAL);
+    // (div x 1) --> x
+    return returnRewrite(t, n, Rewrite::DIV_BY_ONE);
+  }
+  else if (dIsConstant && d.getConst<Rational>().sgn() < 0)
+  {
+    // pull negation
+    // (div x (- c)) ---> (- (div x c))
+    // (mod x (- c)) ---> (mod x c)
+    Node nn = nm->mkNode(k, t[0], nm->mkConstInt(-t[1].getConst<Rational>()));
+    Node ret = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL)
+                   ? nm->mkNode(kind::NEG, nn)
+                   : nn;
+    return returnRewrite(t, ret, Rewrite::DIV_MOD_PULL_NEG_DEN);
+  }
+  else if (dIsConstant && n.isConst())
+  {
     Assert(d.getConst<Rational>().isIntegral());
     Assert(n.getConst<Rational>().isIntegral());
     Assert(!d.getConst<Rational>().isZero());
@@ -511,15 +822,379 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){
 
     bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
 
-    Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
+    Integer result = isDiv ? ni.euclidianDivideQuotient(di)
+                           : ni.euclidianDivideRemainder(di);
 
-    Node resultNode = mkRationalNode(Rational(result));
-    return RewriteResponse(REWRITE_DONE, resultNode);
-  }else{
-    return RewriteResponse(REWRITE_DONE, t);
+    // constant evaluation
+    // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
+    Node resultNode = nm->mkConstInt(Rational(result));
+    return returnRewrite(t, resultNode, Rewrite::CONST_EVAL);
   }
+  if (k == kind::INTS_MODULUS_TOTAL)
+  {
+    // Note these rewrites do not need to account for modulus by zero as being
+    // a UF, which is handled by the reduction of INTS_MODULUS.
+    Kind k0 = t[0].getKind();
+    if (k0 == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+    {
+      // (mod (mod x c) c) --> (mod x c)
+      return returnRewrite(t, t[0], Rewrite::MOD_OVER_MOD);
+    }
+    else if (k0 == kind::NONLINEAR_MULT || k0 == kind::MULT || k0 == kind::ADD)
+    {
+      // can drop all
+      std::vector<Node> newChildren;
+      bool childChanged = false;
+      for (const Node& tc : t[0])
+      {
+        if (tc.getKind() == kind::INTS_MODULUS_TOTAL && tc[1] == t[1])
+        {
+          newChildren.push_back(tc[0]);
+          childChanged = true;
+          continue;
+        }
+        newChildren.push_back(tc);
+      }
+      if (childChanged)
+      {
+        // (mod (op ... (mod x c) ...) c) ---> (mod (op ... x ...) c) where
+        // op is one of { NONLINEAR_MULT, MULT, ADD }.
+        Node ret = nm->mkNode(k0, newChildren);
+        ret = nm->mkNode(kind::INTS_MODULUS_TOTAL, ret, t[1]);
+        return returnRewrite(t, ret, Rewrite::MOD_CHILD_MOD);
+      }
+    }
+  }
+  else
+  {
+    Assert(k == kind::INTS_DIVISION_TOTAL);
+    // Note these rewrites do not need to account for division by zero as being
+    // a UF, which is handled by the reduction of INTS_DIVISION.
+    if (t[0].getKind() == kind::INTS_MODULUS_TOTAL && t[0][1] == t[1])
+    {
+      // (div (mod x c) c) --> 0
+      Node ret = nm->mkConstInt(0);
+      return returnRewrite(t, ret, Rewrite::DIV_OVER_MOD);
+    }
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
+{
+  Assert(t.getKind() == kind::IAND);
+  size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
+  NodeManager* nm = NodeManager::currentNM();
+  // if constant, we eliminate
+  if (t[0].isConst() && t[1].isConst())
+  {
+    Node iToBvop = nm->mkConst(IntToBitVector(bsize));
+    Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
+    Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
+    Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
+    Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
+    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+  }
+  else if (t[0] > t[1])
+  {
+    // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
+    Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
+    return RewriteResponse(REWRITE_AGAIN, ret);
+  }
+  else if (t[0] == t[1])
+  {
+    // ((_ iand k) x x) ---> (mod x 2^k)
+    Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+    Node ret = nm->mkNode(kind::INTS_MODULUS, t[0], twok);
+    return RewriteResponse(REWRITE_AGAIN, ret);
+  }
+  // simplifications involving constants
+  for (unsigned i = 0; i < 2; i++)
+  {
+    if (!t[i].isConst())
+    {
+      continue;
+    }
+    if (t[i].getConst<Rational>().sgn() == 0)
+    {
+      // ((_ iand k) 0 y) ---> 0
+      return RewriteResponse(REWRITE_DONE, t[i]);
+    }
+    if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
+    {
+      // ((_ iand k) 111...1 y) ---> (mod y 2^k)
+      Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+      Node ret = nm->mkNode(kind::INTS_MODULUS, t[1 - i], twok);
+      return RewriteResponse(REWRITE_AGAIN, ret);
+    }
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewritePow2(TNode t)
+{
+  Assert(t.getKind() == kind::POW2);
+  NodeManager* nm = NodeManager::currentNM();
+  // if constant, we eliminate
+  if (t[0].isConst())
+  {
+    // pow2 is only supported for integers
+    Assert(t[0].getType().isInteger());
+    Integer i = t[0].getConst<Rational>().getNumerator();
+    if (i < 0)
+    {
+      return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+    }
+    // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
+    Node two = rewriter::mkConst(Integer(2));
+    Node ret = nm->mkNode(kind::POW, two, t[0]);
+    return RewriteResponse(REWRITE_AGAIN, ret);
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t)
+{
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
+{
+  Trace("arith-tf-rewrite")
+      << "Rewrite transcendental function : " << t << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
+  switch (t.getKind())
+  {
+    case kind::EXPONENTIAL:
+    {
+      if (t[0].isConst())
+      {
+        Node one = rewriter::mkConst(Integer(1));
+        if (t[0].getConst<Rational>().sgn() >= 0 && t[0].getType().isInteger()
+            && t[0] != one)
+        {
+          return RewriteResponse(
+              REWRITE_AGAIN,
+              nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
+        }
+        else
+        {
+          return RewriteResponse(REWRITE_DONE, t);
+        }
+      }
+      else if (t[0].getKind() == kind::ADD)
+      {
+        std::vector<Node> product;
+        for (const Node tc : t[0])
+        {
+          product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
+        }
+        // We need to do a full rewrite here, since we can get exponentials of
+        // constants, e.g. when we are rewriting exp(2 + x)
+        return RewriteResponse(REWRITE_AGAIN_FULL,
+                               nm->mkNode(kind::MULT, product));
+      }
+    }
+    break;
+    case kind::SINE:
+      if (t[0].isConst())
+      {
+        const Rational& rat = t[0].getConst<Rational>();
+        if (rat.sgn() == 0)
+        {
+          return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+        }
+        else if (rat.sgn() == -1)
+        {
+          Node ret = nm->mkNode(
+              kind::NEG, nm->mkNode(kind::SINE, rewriter::mkConst(-rat)));
+          return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+        }
+      }
+      else if ((t[0].getKind() == MULT || t[0].getKind() == NONLINEAR_MULT)
+               && t[0][0].isConst() && t[0][0].getConst<Rational>().sgn() == -1)
+      {
+        // sin(-n*x) ---> -sin(n*x)
+        std::vector<Node> mchildren(t[0].begin(), t[0].end());
+        mchildren[0] = nm->mkConstReal(-t[0][0].getConst<Rational>());
+        Node ret = nm->mkNode(
+            kind::NEG,
+            nm->mkNode(kind::SINE, nm->mkNode(t[0].getKind(), mchildren)));
+        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+      }
+      else
+      {
+        // get the factor of PI in the argument
+        Node pi_factor;
+        Node pi;
+        Node rem;
+        std::map<Node, Node> msum;
+        if (ArithMSum::getMonomialSum(t[0], msum))
+        {
+          pi = mkPi();
+          std::map<Node, Node>::iterator itm = msum.find(pi);
+          if (itm != msum.end())
+          {
+            if (itm->second.isNull())
+            {
+              pi_factor = rewriter::mkConst(Integer(1));
+            }
+            else
+            {
+              pi_factor = itm->second;
+            }
+            msum.erase(pi);
+            if (!msum.empty())
+            {
+              rem = ArithMSum::mkNode(t[0].getType(), msum);
+            }
+          }
+        }
+        else
+        {
+          Assert(false);
+        }
+
+        // if there is a factor of PI
+        if (!pi_factor.isNull())
+        {
+          Trace("arith-tf-rewrite-debug")
+              << "Process pi factor = " << pi_factor << std::endl;
+          Rational r = pi_factor.getConst<Rational>();
+          Rational r_abs = r.abs();
+          Rational rone = Rational(1);
+          Rational rtwo = Rational(2);
+          if (r_abs > rone)
+          {
+            // add/substract 2*pi beyond scope
+            Rational ra_div_two = (r_abs + rone) / rtwo;
+            Node new_pi_factor;
+            if (r.sgn() == 1)
+            {
+              new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
+            }
+            else
+            {
+              Assert(r.sgn() == -1);
+              new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
+            }
+            Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
+            if (!rem.isNull())
+            {
+              new_arg = nm->mkNode(kind::ADD, new_arg, rem);
+            }
+            // sin( 2*n*PI + x ) = sin( x )
+            return RewriteResponse(REWRITE_AGAIN_FULL,
+                                   nm->mkNode(kind::SINE, new_arg));
+          }
+          else if (r_abs == rone)
+          {
+            // sin( PI + x ) = -sin( x )
+            if (rem.isNull())
+            {
+              return RewriteResponse(REWRITE_DONE,
+                                     nm->mkConstReal(Rational(0)));
+            }
+            else
+            {
+              return RewriteResponse(
+                  REWRITE_AGAIN_FULL,
+                  nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
+            }
+          }
+          else if (rem.isNull())
+          {
+            // other rational cases based on Niven's theorem
+            // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
+            Integer one = Integer(1);
+            Integer two = Integer(2);
+            Integer six = Integer(6);
+            if (r_abs.getDenominator() == two)
+            {
+              Assert(r_abs.getNumerator() == one);
+              return RewriteResponse(REWRITE_DONE,
+                                     nm->mkConstReal(Rational(r.sgn())));
+            }
+            else if (r_abs.getDenominator() == six)
+            {
+              Integer five = Integer(5);
+              if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
+              {
+                return RewriteResponse(
+                    REWRITE_DONE,
+                    nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
+              }
+            }
+          }
+        }
+      }
+      break;
+    case kind::COSINE:
+    {
+      return RewriteResponse(
+          REWRITE_AGAIN_FULL,
+          nm->mkNode(
+              kind::SINE,
+              nm->mkNode(kind::SUB,
+                         nm->mkNode(kind::MULT,
+                                    nm->mkConstReal(Rational(1) / Rational(2)),
+                                    mkPi()),
+                         t[0])));
+    }
+    break;
+    case kind::TANGENT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkNode(kind::SINE, t[0]),
+                                        nm->mkNode(kind::COSINE, t[0])));
+    }
+    break;
+    case kind::COSECANT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkConstReal(Rational(1)),
+                                        nm->mkNode(kind::SINE, t[0])));
+    }
+    break;
+    case kind::SECANT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkConstReal(Rational(1)),
+                                        nm->mkNode(kind::COSINE, t[0])));
+    }
+    break;
+    case kind::COTANGENT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkNode(kind::COSINE, t[0]),
+                                        nm->mkNode(kind::SINE, t[0])));
+    }
+    break;
+    default: break;
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+TrustNode ArithRewriter::expandDefinition(Node node)
+{
+  // call eliminate operators, to eliminate partial operators only
+  std::vector<SkolemLemma> lems;
+  TrustNode ret = d_opElim.eliminate(node, lems, true);
+  Assert(lems.empty());
+  return ret;
+}
+
+RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
+{
+  Trace("arith-rewriter") << "ArithRewriter : " << t << " == " << ret << " by "
+                          << r << std::endl;
+  return RewriteResponse(REWRITE_AGAIN_FULL, ret);
 }
 
-}/* CVC4::theory::arith namespace */
-}/* CVC4::theory namespace */
-}/* CVC4 namespace */
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5