Slightly refactor arithmetic rewriting for extended operators (#8169)
authorGereon Kremer <gkremer@cs.stanford.edu>
Fri, 25 Feb 2022 20:17:34 +0000 (21:17 +0100)
committerGitHub <noreply@github.com>
Fri, 25 Feb 2022 20:17:34 +0000 (20:17 +0000)
This PR mostly reorders the implementation to match the order in the header, and does a few very minor refactorings for the rewriters for transcendental functions, and the pow2 and iand operators.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h

index 86835f447184d1d95662821c50590acd4fee2489..f8d8594ebab585b03cf5a6b2f757ca5f0be33c11 100644 (file)
@@ -602,299 +602,15 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){
                          rewriter::mkMultTerm(ran, std::move(leafs)));
 }
 
-RewriteResponse ArithRewriter::postRewritePow2(TNode t)
-{
-  Assert(t.getKind() == kind::POW2);
-  NodeManager* nm = NodeManager::currentNM();
-  // if constant, we eliminate
-  if (t[0].isConst())
-  {
-    // pow2 is only supported for integers
-    Assert(t[0].getType().isInteger());
-    Integer i = t[0].getConst<Rational>().getNumerator();
-    if (i < 0)
-    {
-      return RewriteResponse(REWRITE_DONE, nm->mkConstInt(Rational(0)));
-    }
-    // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
-    Node two = nm->mkConstInt(Rational(Integer(2)));
-    Node ret = nm->mkNode(kind::POW, two, t[0]);
-    return RewriteResponse(REWRITE_AGAIN, ret);
-  }
-  return RewriteResponse(REWRITE_DONE, t);
-}
-
-RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
+Node ArithRewriter::makeUnaryMinusNode(TNode n)
 {
-  Assert(t.getKind() == kind::IAND);
-  size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
-  NodeManager* nm = NodeManager::currentNM();
-  // if constant, we eliminate
-  if (t[0].isConst() && t[1].isConst())
-  {
-    Node iToBvop = nm->mkConst(IntToBitVector(bsize));
-    Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
-    Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
-    Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
-    Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
-    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
-  }
-  else if (t[0] > t[1])
-  {
-    // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
-    Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
-    return RewriteResponse(REWRITE_AGAIN, ret);
-  }
-  else if (t[0] == t[1])
-  {
-    // ((_ iand k) x x) ---> (mod x 2^k)
-    Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
-    Node ret = nm->mkNode(kind::INTS_MODULUS, t[0],  twok);
-    return RewriteResponse(REWRITE_AGAIN, ret);
-  }
-  // simplifications involving constants
-  for (unsigned i = 0; i < 2; i++)
-  {
-    if (!t[i].isConst())
-    {
-      continue;
-    }
-    if (t[i].getConst<Rational>().sgn() == 0)
-    {
-      // ((_ iand k) 0 y) ---> 0
-      return RewriteResponse(REWRITE_DONE, t[i]);
-    }
-    if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
-    {
-      // ((_ iand k) 111...1 y) ---> (mod y 2^k)
-      Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
-      Node ret = nm->mkNode(kind::INTS_MODULUS, t[1-i],  twok);
-      return RewriteResponse(REWRITE_AGAIN, ret);
-    }
-  }
-  return RewriteResponse(REWRITE_DONE, t);
-}
-
-RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t) {
-  return RewriteResponse(REWRITE_DONE, t);
-}
-
-RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) { 
-  Trace("arith-tf-rewrite") << "Rewrite transcendental function : " << t << std::endl;
-  NodeManager* nm = NodeManager::currentNM();
-  switch( t.getKind() ){
-  case kind::EXPONENTIAL: {
-    if (t[0].isConst())
-    {
-      Node one = nm->mkConstReal(Rational(1));
-      if(t[0].getConst<Rational>().sgn()>=0 && t[0].getType().isInteger() && t[0]!=one){
-        return RewriteResponse(
-            REWRITE_AGAIN,
-            nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
-      }else{          
-        return RewriteResponse(REWRITE_DONE, t);
-      }
-    }
-    else if (t[0].getKind() == kind::ADD)
-    {
-      std::vector<Node> product;
-      for (const Node tc : t[0])
-      {
-        product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
-      }
-      // We need to do a full rewrite here, since we can get exponentials of
-      // constants, e.g. when we are rewriting exp(2 + x)
-      return RewriteResponse(REWRITE_AGAIN_FULL,
-                             nm->mkNode(kind::MULT, product));
-    }
-  }
-    break;
-  case kind::SINE:
-    if (t[0].isConst())
-    {
-      const Rational& rat = t[0].getConst<Rational>();
-      if(rat.sgn() == 0){
-        return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
-      }
-      else if (rat.sgn() == -1)
-      {
-        Node ret = nm->mkNode(kind::NEG,
-                              nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
-        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
-      }
-    }
-    else if ((t[0].getKind() == MULT || t[0].getKind() == NONLINEAR_MULT)
-             && t[0][0].isConst() && t[0][0].getConst<Rational>().sgn() == -1)
-    {
-      // sin(-n*x) ---> -sin(n*x)
-      std::vector<Node> mchildren(t[0].begin(), t[0].end());
-      mchildren[0] = nm->mkConstReal(-t[0][0].getConst<Rational>());
-      Node ret = nm->mkNode(
-          kind::NEG,
-          nm->mkNode(kind::SINE, nm->mkNode(t[0].getKind(), mchildren)));
-      return RewriteResponse(REWRITE_AGAIN_FULL, ret);
-    }
-    else
-    {
-      // get the factor of PI in the argument
-      Node pi_factor;
-      Node pi;
-      Node rem;
-      std::map<Node, Node> msum;
-      if (ArithMSum::getMonomialSum(t[0], msum))
-      {
-        pi = mkPi();
-        std::map<Node, Node>::iterator itm = msum.find(pi);
-        if (itm != msum.end())
-        {
-          if (itm->second.isNull())
-          {
-            pi_factor = nm->mkConstReal(Rational(1));
-          }
-          else
-          {
-            pi_factor = itm->second;
-          }
-          msum.erase(pi);
-          if (!msum.empty())
-          {
-            rem = ArithMSum::mkNode(t[0].getType(), msum);
-          }
-        }
-      }
-      else
-      {
-        Assert(false);
-      }
-
-      // if there is a factor of PI
-      if( !pi_factor.isNull() ){
-        Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
-        Rational r = pi_factor.getConst<Rational>();
-        Rational r_abs = r.abs();
-        Rational rone = Rational(1);
-        Rational rtwo = Rational(2);
-        if (r_abs > rone)
-        {
-          //add/substract 2*pi beyond scope
-          Rational ra_div_two = (r_abs + rone) / rtwo;
-          Node new_pi_factor;
-          if (r.sgn() == 1)
-          {
-            new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
-          }
-          else
-          {
-            Assert(r.sgn() == -1);
-            new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
-          }
-          Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
-          if (!rem.isNull())
-          {
-            new_arg = nm->mkNode(kind::ADD, new_arg, rem);
-          }
-          // sin( 2*n*PI + x ) = sin( x )
-          return RewriteResponse(REWRITE_AGAIN_FULL,
-                                 nm->mkNode(kind::SINE, new_arg));
-        }
-        else if (r_abs == rone)
-        {
-          // sin( PI + x ) = -sin( x )
-          if (rem.isNull())
-          {
-            return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
-          }
-          else
-          {
-            return RewriteResponse(
-                REWRITE_AGAIN_FULL,
-                nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
-          }
-        }
-        else if (rem.isNull())
-        {
-          // other rational cases based on Niven's theorem
-          // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
-          Integer one = Integer(1);
-          Integer two = Integer(2);
-          Integer six = Integer(6);
-          if (r_abs.getDenominator() == two)
-          {
-            Assert(r_abs.getNumerator() == one);
-            return RewriteResponse(REWRITE_DONE,
-                                   nm->mkConstReal(Rational(r.sgn())));
-          }
-          else if (r_abs.getDenominator() == six)
-          {
-            Integer five = Integer(5);
-            if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
-            {
-              return RewriteResponse(
-                  REWRITE_DONE,
-                  nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
-            }
-          }
-        }
-      }
-    }
-    break;
-  case kind::COSINE: {
-    return RewriteResponse(
-        REWRITE_AGAIN_FULL,
-        nm->mkNode(
-            kind::SINE,
-            nm->mkNode(kind::SUB,
-                       nm->mkNode(kind::MULT,
-                                  nm->mkConstReal(Rational(1) / Rational(2)),
-                                  mkPi()),
-                       t[0])));
-  }
-  break;
-  case kind::TANGENT:
-  {
-    return RewriteResponse(REWRITE_AGAIN_FULL,
-                           nm->mkNode(kind::DIVISION,
-                                      nm->mkNode(kind::SINE, t[0]),
-                                      nm->mkNode(kind::COSINE, t[0])));
-  }
-  break;
-  case kind::COSECANT:
-  {
-    return RewriteResponse(REWRITE_AGAIN_FULL,
-                           nm->mkNode(kind::DIVISION,
-                                      nm->mkConstReal(Rational(1)),
-                                      nm->mkNode(kind::SINE, t[0])));
-  }
-  break;
-  case kind::SECANT:
-  {
-    return RewriteResponse(REWRITE_AGAIN_FULL,
-                           nm->mkNode(kind::DIVISION,
-                                      nm->mkConstReal(Rational(1)),
-                                      nm->mkNode(kind::COSINE, t[0])));
-  }
-  break;
-  case kind::COTANGENT:
-  {
-    return RewriteResponse(REWRITE_AGAIN_FULL,
-                           nm->mkNode(kind::DIVISION,
-                                      nm->mkNode(kind::COSINE, t[0]),
-                                      nm->mkNode(kind::SINE, t[0])));
-  }
-  break;
-  default:
-    break;
-  }
-  return RewriteResponse(REWRITE_DONE, t);
-}
-
-Node ArithRewriter::makeUnaryMinusNode(TNode n){
   NodeManager* nm = NodeManager::currentNM();
   Rational qNegOne(-1);
   return nm->mkNode(kind::MULT, nm->mkConstRealOrInt(n.getType(), qNegOne), n);
 }
 
-RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
+RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
+{
   Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
   Assert(t.getNumChildren() == 2);
 
@@ -905,10 +621,14 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
     NodeManager* nm = NodeManager::currentNM();
     const Rational& den = right.getConst<Rational>();
 
-    if(den.isZero()){
-      if(t.getKind() == kind::DIVISION_TOTAL){
+    if (den.isZero())
+    {
+      if (t.getKind() == kind::DIVISION_TOTAL)
+      {
         return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
-      }else{
+      }
+      else
+      {
         // This is unsupported, but this is not a good place to complain
         return RewriteResponse(REWRITE_DONE, t);
       }
@@ -961,10 +681,13 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
     }
 
     Node result = nm->mkRealAlgebraicNumber(inverse(den));
-    Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
-    if(pre){
+    Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+    if (pre)
+    {
       return RewriteResponse(REWRITE_DONE, mult);
-    }else{
+    }
+    else
+    {
       return RewriteResponse(REWRITE_AGAIN, mult);
     }
   }
@@ -1070,10 +793,13 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
   TNode n = t[0];
   TNode d = t[1];
   bool dIsConstant = d.isConst();
-  if(dIsConstant && d.getConst<Rational>().isZero()){
+  if (dIsConstant && d.getConst<Rational>().isZero())
+  {
     // (div x 0) ---> 0 or (mod x 0) ---> 0
     return returnRewrite(t, nm->mkConstInt(0), Rewrite::DIV_MOD_BY_ZERO);
-  }else if(dIsConstant && d.getConst<Rational>().isOne()){
+  }
+  else if (dIsConstant && d.getConst<Rational>().isOne())
+  {
     if (k == kind::INTS_MODULUS_TOTAL)
     {
       // (mod x 1) --> 0
@@ -1104,7 +830,8 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
 
     bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL);
 
-    Integer result = isDiv ? ni.euclidianDivideQuotient(di) : ni.euclidianDivideRemainder(di);
+    Integer result = isDiv ? ni.euclidianDivideQuotient(di)
+                           : ni.euclidianDivideRemainder(di);
 
     // constant evaluation
     // (mod c1 c2) ---> c3 or (div c1 c2) ---> c3
@@ -1161,6 +888,305 @@ RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre)
   return RewriteResponse(REWRITE_DONE, t);
 }
 
+RewriteResponse ArithRewriter::postRewriteIAnd(TNode t)
+{
+  Assert(t.getKind() == kind::IAND);
+  size_t bsize = t.getOperator().getConst<IntAnd>().d_size;
+  NodeManager* nm = NodeManager::currentNM();
+  // if constant, we eliminate
+  if (t[0].isConst() && t[1].isConst())
+  {
+    Node iToBvop = nm->mkConst(IntToBitVector(bsize));
+    Node arg1 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[0]);
+    Node arg2 = nm->mkNode(kind::INT_TO_BITVECTOR, iToBvop, t[1]);
+    Node bvand = nm->mkNode(kind::BITVECTOR_AND, arg1, arg2);
+    Node ret = nm->mkNode(kind::BITVECTOR_TO_NAT, bvand);
+    return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+  }
+  else if (t[0] > t[1])
+  {
+    // ((_ iand k) x y) ---> ((_ iand k) y x) if x > y by node ordering
+    Node ret = nm->mkNode(kind::IAND, t.getOperator(), t[1], t[0]);
+    return RewriteResponse(REWRITE_AGAIN, ret);
+  }
+  else if (t[0] == t[1])
+  {
+    // ((_ iand k) x x) ---> (mod x 2^k)
+    Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+    Node ret = nm->mkNode(kind::INTS_MODULUS, t[0], twok);
+    return RewriteResponse(REWRITE_AGAIN, ret);
+  }
+  // simplifications involving constants
+  for (unsigned i = 0; i < 2; i++)
+  {
+    if (!t[i].isConst())
+    {
+      continue;
+    }
+    if (t[i].getConst<Rational>().sgn() == 0)
+    {
+      // ((_ iand k) 0 y) ---> 0
+      return RewriteResponse(REWRITE_DONE, t[i]);
+    }
+    if (t[i].getConst<Rational>().getNumerator() == Integer(2).pow(bsize) - 1)
+    {
+      // ((_ iand k) 111...1 y) ---> (mod y 2^k)
+      Node twok = nm->mkConstInt(Rational(Integer(2).pow(bsize)));
+      Node ret = nm->mkNode(kind::INTS_MODULUS, t[1 - i], twok);
+      return RewriteResponse(REWRITE_AGAIN, ret);
+    }
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewritePow2(TNode t)
+{
+  Assert(t.getKind() == kind::POW2);
+  NodeManager* nm = NodeManager::currentNM();
+  // if constant, we eliminate
+  if (t[0].isConst())
+  {
+    // pow2 is only supported for integers
+    Assert(t[0].getType().isInteger());
+    Integer i = t[0].getConst<Rational>().getNumerator();
+    if (i < 0)
+    {
+      return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+    }
+    // (pow2 t) ---> (pow 2 t) and continue rewriting to eliminate pow
+    Node two = rewriter::mkConst(Integer(2));
+    Node ret = nm->mkNode(kind::POW, two, t[0]);
+    return RewriteResponse(REWRITE_AGAIN, ret);
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::preRewriteTranscendental(TNode t)
+{
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
+RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t)
+{
+  Trace("arith-tf-rewrite")
+      << "Rewrite transcendental function : " << t << std::endl;
+  NodeManager* nm = NodeManager::currentNM();
+  switch (t.getKind())
+  {
+    case kind::EXPONENTIAL:
+    {
+      if (t[0].isConst())
+      {
+        Node one = rewriter::mkConst(Integer(1));
+        if (t[0].getConst<Rational>().sgn() >= 0 && t[0].getType().isInteger()
+            && t[0] != one)
+        {
+          return RewriteResponse(
+              REWRITE_AGAIN,
+              nm->mkNode(kind::POW, nm->mkNode(kind::EXPONENTIAL, one), t[0]));
+        }
+        else
+        {
+          return RewriteResponse(REWRITE_DONE, t);
+        }
+      }
+      else if (t[0].getKind() == kind::ADD)
+      {
+        std::vector<Node> product;
+        for (const Node tc : t[0])
+        {
+          product.push_back(nm->mkNode(kind::EXPONENTIAL, tc));
+        }
+        // We need to do a full rewrite here, since we can get exponentials of
+        // constants, e.g. when we are rewriting exp(2 + x)
+        return RewriteResponse(REWRITE_AGAIN_FULL,
+                               nm->mkNode(kind::MULT, product));
+      }
+    }
+    break;
+    case kind::SINE:
+      if (t[0].isConst())
+      {
+        const Rational& rat = t[0].getConst<Rational>();
+        if (rat.sgn() == 0)
+        {
+          return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+        }
+        else if (rat.sgn() == -1)
+        {
+          Node ret = nm->mkNode(
+              kind::NEG, nm->mkNode(kind::SINE, rewriter::mkConst(-rat)));
+          return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+        }
+      }
+      else if ((t[0].getKind() == MULT || t[0].getKind() == NONLINEAR_MULT)
+               && t[0][0].isConst() && t[0][0].getConst<Rational>().sgn() == -1)
+      {
+        // sin(-n*x) ---> -sin(n*x)
+        std::vector<Node> mchildren(t[0].begin(), t[0].end());
+        mchildren[0] = nm->mkConstReal(-t[0][0].getConst<Rational>());
+        Node ret = nm->mkNode(
+            kind::NEG,
+            nm->mkNode(kind::SINE, nm->mkNode(t[0].getKind(), mchildren)));
+        return RewriteResponse(REWRITE_AGAIN_FULL, ret);
+      }
+      else
+      {
+        // get the factor of PI in the argument
+        Node pi_factor;
+        Node pi;
+        Node rem;
+        std::map<Node, Node> msum;
+        if (ArithMSum::getMonomialSum(t[0], msum))
+        {
+          pi = mkPi();
+          std::map<Node, Node>::iterator itm = msum.find(pi);
+          if (itm != msum.end())
+          {
+            if (itm->second.isNull())
+            {
+              pi_factor = rewriter::mkConst(Integer(1));
+            }
+            else
+            {
+              pi_factor = itm->second;
+            }
+            msum.erase(pi);
+            if (!msum.empty())
+            {
+              rem = ArithMSum::mkNode(t[0].getType(), msum);
+            }
+          }
+        }
+        else
+        {
+          Assert(false);
+        }
+
+        // if there is a factor of PI
+        if (!pi_factor.isNull())
+        {
+          Trace("arith-tf-rewrite-debug")
+              << "Process pi factor = " << pi_factor << std::endl;
+          Rational r = pi_factor.getConst<Rational>();
+          Rational r_abs = r.abs();
+          Rational rone = Rational(1);
+          Rational rtwo = Rational(2);
+          if (r_abs > rone)
+          {
+            // add/substract 2*pi beyond scope
+            Rational ra_div_two = (r_abs + rone) / rtwo;
+            Node new_pi_factor;
+            if (r.sgn() == 1)
+            {
+              new_pi_factor = nm->mkConstReal(r - rtwo * ra_div_two.floor());
+            }
+            else
+            {
+              Assert(r.sgn() == -1);
+              new_pi_factor = nm->mkConstReal(r + rtwo * ra_div_two.floor());
+            }
+            Node new_arg = nm->mkNode(kind::MULT, new_pi_factor, pi);
+            if (!rem.isNull())
+            {
+              new_arg = nm->mkNode(kind::ADD, new_arg, rem);
+            }
+            // sin( 2*n*PI + x ) = sin( x )
+            return RewriteResponse(REWRITE_AGAIN_FULL,
+                                   nm->mkNode(kind::SINE, new_arg));
+          }
+          else if (r_abs == rone)
+          {
+            // sin( PI + x ) = -sin( x )
+            if (rem.isNull())
+            {
+              return RewriteResponse(REWRITE_DONE,
+                                     nm->mkConstReal(Rational(0)));
+            }
+            else
+            {
+              return RewriteResponse(
+                  REWRITE_AGAIN_FULL,
+                  nm->mkNode(kind::NEG, nm->mkNode(kind::SINE, rem)));
+            }
+          }
+          else if (rem.isNull())
+          {
+            // other rational cases based on Niven's theorem
+            // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
+            Integer one = Integer(1);
+            Integer two = Integer(2);
+            Integer six = Integer(6);
+            if (r_abs.getDenominator() == two)
+            {
+              Assert(r_abs.getNumerator() == one);
+              return RewriteResponse(REWRITE_DONE,
+                                     nm->mkConstReal(Rational(r.sgn())));
+            }
+            else if (r_abs.getDenominator() == six)
+            {
+              Integer five = Integer(5);
+              if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
+              {
+                return RewriteResponse(
+                    REWRITE_DONE,
+                    nm->mkConstReal(Rational(r.sgn()) / Rational(2)));
+              }
+            }
+          }
+        }
+      }
+      break;
+    case kind::COSINE:
+    {
+      return RewriteResponse(
+          REWRITE_AGAIN_FULL,
+          nm->mkNode(
+              kind::SINE,
+              nm->mkNode(kind::SUB,
+                         nm->mkNode(kind::MULT,
+                                    nm->mkConstReal(Rational(1) / Rational(2)),
+                                    mkPi()),
+                         t[0])));
+    }
+    break;
+    case kind::TANGENT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkNode(kind::SINE, t[0]),
+                                        nm->mkNode(kind::COSINE, t[0])));
+    }
+    break;
+    case kind::COSECANT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkConstReal(Rational(1)),
+                                        nm->mkNode(kind::SINE, t[0])));
+    }
+    break;
+    case kind::SECANT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkConstReal(Rational(1)),
+                                        nm->mkNode(kind::COSINE, t[0])));
+    }
+    break;
+    case kind::COTANGENT:
+    {
+      return RewriteResponse(REWRITE_AGAIN_FULL,
+                             nm->mkNode(kind::DIVISION,
+                                        nm->mkNode(kind::COSINE, t[0]),
+                                        nm->mkNode(kind::SINE, t[0])));
+    }
+    break;
+    default: break;
+  }
+  return RewriteResponse(REWRITE_DONE, t);
+}
+
 TrustNode ArithRewriter::expandDefinition(Node node)
 {
   // call eliminate operators, to eliminate partial operators only
@@ -1172,8 +1198,8 @@ TrustNode ArithRewriter::expandDefinition(Node node)
 
 RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
 {
-  Trace("arith-rewrite") << "ArithRewriter : " << t << " == " << ret << " by "
-                         << r << std::endl;
+  Trace("arith-rewriter") << "ArithRewriter : " << t << " == " << ret << " by "
+                          << r << std::endl;
   return RewriteResponse(REWRITE_AGAIN_FULL, ret);
 }
 
index 072894000fd2662a2ab0f301a4520fb0014151c0..9e2a15c77f9d2b2284af39758e1b8946460ab4cd 100644 (file)
@@ -71,10 +71,14 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse preRewriteMult(TNode t);
   static RewriteResponse postRewriteMult(TNode t);
 
+  /** postRewrite IAND */
   static RewriteResponse postRewriteIAnd(TNode t);
+  /** postRewrite POW2 */
   static RewriteResponse postRewritePow2(TNode t);
 
+  /** preRewrite transcendental functions */
   static RewriteResponse preRewriteTranscendental(TNode t);
+  /** postRewrite transcendental functions */
   static RewriteResponse postRewriteTranscendental(TNode t);
 
   static bool isAtom(TNode n);