Improve the rewriter for SINE. (#1221)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Wed, 29 Nov 2017 01:48:53 +0000 (19:48 -0600)
committerGitHub <noreply@github.com>
Wed, 29 Nov 2017 01:48:53 +0000 (19:48 -0600)
src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_utilities.h

index 57428d20963b0c01173e34d005f4d677f18dd739..72f9cdf4a171929032b0918e4068be6b0c44086e 100644 (file)
@@ -20,6 +20,7 @@
 #include <vector>
 
 #include "smt/logic_exception.h"
+#include "theory/arith/arith_msum.h"
 #include "theory/arith/arith_rewriter.h"
 #include "theory/arith/arith_utilities.h"
 #include "theory/arith/normal_form.h"
@@ -384,24 +385,49 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
         return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(0)));
       }
     }else{
+      // get the factor of PI in the argument
       Node pi_factor;
       Node pi;
-      if( t[0].getKind()==kind::PI ){
-        pi_factor = NodeManager::currentNM()->mkConst(Rational(1));
-        pi = t[0];
-      }else if( t[0].getKind()==kind::MULT && t[0][0].isConst() && t[0][1].getKind()==kind::PI ){
-        pi_factor = t[0][0];
-        pi = t[0][1];
+      Node rem;
+      std::map<Node, Node> msum;
+      if (ArithMSum::getMonomialSum(t[0], msum))
+      {
+        pi = mkPi();
+        std::map<Node, Node>::iterator itm = msum.find(pi);
+        if (itm != msum.end())
+        {
+          if (itm->second.isNull())
+          {
+            pi_factor = mkRationalNode(Rational(1));
+          }
+          else
+          {
+            pi_factor = itm->second;
+          }
+          msum.erase(pi);
+          if (!msum.empty())
+          {
+            rem = ArithMSum::mkNode(msum);
+          }
+        }
+      }
+      else
+      {
+        Assert(false);
       }
+
+      // if there is a factor of PI
       if( !pi_factor.isNull() ){
         Trace("arith-tf-rewrite-debug") << "Process pi factor = " << pi_factor << std::endl;
         Rational r = pi_factor.getConst<Rational>();
-        Rational ra = r.abs();
+        Rational r_abs = r.abs();
         Rational rone = Rational(1);
-        Node ntwo = NodeManager::currentNM()->mkConst( Rational(2) );
-        if( ra > rone ){
+        Node ntwo = mkRationalNode(Rational(2));
+        if (r_abs > rone)
+        {
           //add/substract 2*pi beyond scope
-          Node ra_div_two = NodeManager::currentNM()->mkNode( kind::INTS_DIVISION, NodeManager::currentNM()->mkConst( ra + rone ), ntwo );
+          Node ra_div_two = NodeManager::currentNM()->mkNode(
+              kind::INTS_DIVISION, mkRationalNode(r_abs + rone), ntwo);
           Node new_pi_factor;
           if( r.sgn()==1 ){
             new_pi_factor = NodeManager::currentNM()->mkNode( kind::MINUS, pi_factor, NodeManager::currentNM()->mkNode( kind::MULT, ntwo, ra_div_two ) );
@@ -409,20 +435,55 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
             Assert( r.sgn()==-1 );
             new_pi_factor = NodeManager::currentNM()->mkNode( kind::PLUS, pi_factor, NodeManager::currentNM()->mkNode( kind::MULT, ntwo, ra_div_two ) );
           }
-          return RewriteResponse(REWRITE_AGAIN_FULL, NodeManager::currentNM()->mkNode( kind::SINE,
-                                                       NodeManager::currentNM()->mkNode( kind::MULT, new_pi_factor, pi ) ) );
-        }else if( ra == rone ){
-          return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(Rational(0)));
-        }else{
+          Node new_arg =
+              NodeManager::currentNM()->mkNode(kind::MULT, new_pi_factor, pi);
+          if (!rem.isNull())
+          {
+            new_arg =
+                NodeManager::currentNM()->mkNode(kind::PLUS, new_arg, rem);
+          }
+          // sin( 2*n*PI + x ) = sin( x )
+          return RewriteResponse(
+              REWRITE_AGAIN_FULL,
+              NodeManager::currentNM()->mkNode(kind::SINE, new_arg));
+        }
+        else if (r_abs == rone)
+        {
+          // sin( PI + x ) = -sin( x )
+          if (rem.isNull())
+          {
+            return RewriteResponse(REWRITE_DONE, mkRationalNode(Rational(0)));
+          }
+          else
+          {
+            return RewriteResponse(
+                REWRITE_AGAIN_FULL,
+                NodeManager::currentNM()->mkNode(
+                    kind::UMINUS,
+                    NodeManager::currentNM()->mkNode(kind::SINE, rem)));
+          }
+        }
+        else if (rem.isNull())
+        {
+          // other rational cases based on Niven's theorem
+          // (https://en.wikipedia.org/wiki/Niven%27s_theorem)
           Integer one = Integer(1);
           Integer two = Integer(2);
           Integer six = Integer(6);
-          if( ra.getDenominator()==two ){
-            return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst( Rational( r.sgn() ) ) );
-          }else if( ra.getDenominator()==six ){
+          if (r_abs.getDenominator() == two)
+          {
+            Assert(r_abs.getNumerator() == one);
+            return RewriteResponse(REWRITE_DONE,
+                                   mkRationalNode(Rational(r.sgn())));
+          }
+          else if (r_abs.getDenominator() == six)
+          {
             Integer five = Integer(5);
-            if( ra.getNumerator()==one || ra.getNumerator()==five ){
-              return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst( Rational( r.sgn() )/Rational(2) ) );
+            if (r_abs.getNumerator() == one || r_abs.getNumerator() == five)
+            {
+              return RewriteResponse(
+                  REWRITE_DONE,
+                  mkRationalNode(Rational(r.sgn()) / Rational(2)));
             }
           }
         }
index cfaf6ac03161dcaba13da5d351df351d4fcd1532..30db4ec42bd0910e0e40d2d0d199f2d13615cc14 100644 (file)
@@ -297,6 +297,12 @@ inline Node mkOnZeroIte(Node n, Node q, Node if_zero, Node not_zero) {
   return n.eqNode(zero).iteNode(q.eqNode(if_zero), q.eqNode(not_zero));
 }
 
+inline Node mkPi()
+{
+  return NodeManager::currentNM()->mkNullaryOperator(
+      NodeManager::currentNM()->realType(), kind::PI);
+}
+
 }/* CVC4::theory::arith namespace */
 }/* CVC4::theory namespace */
 }/* CVC4 namespace */