refactor div rewriter, add support for ran (#7941)
authorGereon Kremer <gkremer@stanford.edu>
Fri, 14 Jan 2022 20:27:50 +0000 (12:27 -0800)
committerGitHub <noreply@github.com>
Fri, 14 Jan 2022 20:27:50 +0000 (20:27 +0000)
This extends the rewriter for division to also support real algebraic numbers.

src/theory/arith/arith_rewriter.cpp

index 09a340f82df3588b65fcd9da404a46789442683d..1fcb71643284349550954921e3814ee8d26de080 100644 (file)
@@ -920,6 +920,7 @@ Node ArithRewriter::makeUnaryMinusNode(TNode n){
 
 RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
   Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
+  Assert(t.getNumChildren() == 2);
 
   Node left = t[0];
   Node right = t[1];
@@ -941,15 +942,49 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){
     if (left.isConst())
     {
       const Rational& num = left.getConst<Rational>();
-      Rational div = num / den;
-      Node result = nm->mkConstReal(div);
-      return RewriteResponse(REWRITE_DONE, result);
+      return RewriteResponse(REWRITE_DONE, nm->mkConstReal(num / den));
+    }
+    if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& num =
+          left.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(
+          REWRITE_DONE,
+          nm->mkRealAlgebraicNumber(num / RealAlgebraicNumber(den)));
     }
 
-    Rational div = den.inverse();
-
-    Node result = nm->mkConstReal(div);
+    Node result = nm->mkConstReal(den.inverse());
+    Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+    if (pre)
+    {
+      return RewriteResponse(REWRITE_DONE, mult);
+    }
+    else
+    {
+      return RewriteResponse(REWRITE_AGAIN, mult);
+    }
+  }
+  if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  {
+    NodeManager* nm = NodeManager::currentNM();
+    const RealAlgebraicNumber& den =
+        right.getOperator().getConst<RealAlgebraicNumber>();
+    if (left.isConst())
+    {
+      const Rational& num = left.getConst<Rational>();
+      return RewriteResponse(
+          REWRITE_DONE,
+          nm->mkRealAlgebraicNumber(RealAlgebraicNumber(num) / den));
+    }
+    if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& num =
+          left.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(REWRITE_DONE,
+                             nm->mkRealAlgebraicNumber(num / den));
+    }
 
+    Node result = nm->mkRealAlgebraicNumber(inverse(den));
     Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result);
     if(pre){
       return RewriteResponse(REWRITE_DONE, mult);