Refactor atom rewriting to be RAN-aware (#7928)
authorGereon Kremer <gkremer@stanford.edu>
Wed, 12 Jan 2022 22:51:29 +0000 (14:51 -0800)
committerGitHub <noreply@github.com>
Wed, 12 Jan 2022 22:51:29 +0000 (22:51 +0000)
This PR starts refactoring the arithmetic rewriter by making rewriting of atoms aware of real algebraic numbers.
It also is slightly more aggressive now, directly rewriting relational operators where lhs = rhs in the preRewrite stage.

src/theory/arith/arith_rewriter.cpp
src/theory/arith/arith_rewriter.h

index 2c3fcdf48b48e49a904dab3d9a0289e993c39a45..d865eebe99a172f4bb96ee30fedc800769afc495 100644 (file)
@@ -32,6 +32,7 @@
 #include "util/bitvector.h"
 #include "util/divisible.h"
 #include "util/iand.h"
+#include "util/real_algebraic_number.h"
 
 using namespace cvc5::kind;
 
@@ -39,8 +40,184 @@ namespace cvc5 {
 namespace theory {
 namespace arith {
 
+namespace {
+
+template <typename L, typename R>
+bool evaluateRelation(Kind rel, const L& l, const R& r)
+{
+  switch (rel)
+  {
+    case Kind::LT: return l < r;
+    case Kind::LEQ: return l <= r;
+    case Kind::EQUAL: return l == r;
+    case Kind::GEQ: return l >= r;
+    case Kind::GT: return l > r;
+    default: Unreachable(); return false;
+  }
+}
+
+}  // namespace
+
 ArithRewriter::ArithRewriter(OperatorElim& oe) : d_opElim(oe) {}
 
+RewriteResponse ArithRewriter::preRewrite(TNode t)
+{
+  if (isAtom(t))
+  {
+    return preRewriteAtom(t);
+  }
+  return preRewriteTerm(t);
+}
+
+RewriteResponse ArithRewriter::postRewrite(TNode t)
+{
+  if (isAtom(t))
+  {
+    RewriteResponse response = postRewriteAtom(t);
+    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
+    {
+      Comparison::parseNormalForm(response.d_node);
+    }
+    return response;
+  }
+  RewriteResponse response = postRewriteTerm(t);
+  if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
+  {
+    Polynomial::parsePolynomial(response.d_node);
+  }
+  return response;
+}
+
+RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
+{
+  Assert(isAtom(atom));
+
+  NodeManager* nm = NodeManager::currentNM();
+
+  if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
+  {
+    switch (atom.getKind())
+    {
+      case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
+      case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
+      default:;
+    }
+  }
+
+  switch (atom.getKind())
+  {
+    case Kind::GT:
+      return RewriteResponse(REWRITE_DONE,
+                             nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
+    case Kind::LT:
+      return RewriteResponse(REWRITE_DONE,
+                             nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
+    case Kind::IS_INTEGER:
+      if (atom[0].getType().isInteger())
+      {
+        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      }
+      break;
+    case Kind::DIVISIBLE:
+      if (atom.getOperator().getConst<Divisible>().k.isOne())
+      {
+        return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+      }
+      break;
+    default:;
+  }
+
+  return RewriteResponse(REWRITE_DONE, atom);
+}
+
+RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
+{
+  Assert(isAtom(atom));
+  if (atom.getKind() == kind::IS_INTEGER)
+  {
+    return rewriteExtIntegerOp(atom);
+  }
+  else if (atom.getKind() == kind::DIVISIBLE)
+  {
+    if (atom[0].isConst())
+    {
+      return RewriteResponse(REWRITE_DONE,
+                             NodeManager::currentNM()->mkConst(bool(
+                                 (atom[0].getConst<Rational>()
+                                  / atom.getOperator().getConst<Divisible>().k)
+                                     .isIntegral())));
+    }
+    if (atom.getOperator().getConst<Divisible>().k.isOne())
+    {
+      return RewriteResponse(REWRITE_DONE,
+                             NodeManager::currentNM()->mkConst(true));
+    }
+    NodeManager* nm = NodeManager::currentNM();
+    return RewriteResponse(
+        REWRITE_AGAIN,
+        nm->mkNode(kind::EQUAL,
+                   nm->mkNode(kind::INTS_MODULUS_TOTAL,
+                              atom[0],
+                              nm->mkConstInt(Rational(
+                                  atom.getOperator().getConst<Divisible>().k))),
+                   nm->mkConstInt(Rational(0))));
+  }
+
+  // left |><| right
+  TNode left = atom[0];
+  TNode right = atom[1];
+
+  auto* nm = NodeManager::currentNM();
+  if (left.isConst())
+  {
+    const Rational& l = left.getConst<Rational>();
+    if (right.isConst())
+    {
+      const Rational& r = right.getConst<Rational>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+    else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& r =
+          right.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+  }
+  else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+  {
+    const RealAlgebraicNumber& l =
+        left.getOperator().getConst<RealAlgebraicNumber>();
+    if (right.isConst())
+    {
+      const Rational& r = right.getConst<Rational>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+    else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+    {
+      const RealAlgebraicNumber& r =
+          right.getOperator().getConst<RealAlgebraicNumber>();
+      return RewriteResponse(
+          REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+    }
+  }
+
+  Polynomial pleft = Polynomial::parsePolynomial(left);
+  Polynomial pright = Polynomial::parsePolynomial(right);
+
+  Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
+  Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
+
+  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
+  Assert(cmp.isNormalForm());
+  return RewriteResponse(REWRITE_DONE, cmp.getNode());
+}
+
 bool ArithRewriter::isAtom(TNode n) {
   Kind k = n.getKind();
   return arith::isRelationOperator(k) || k == kind::IS_INTEGER
@@ -648,100 +825,6 @@ RewriteResponse ArithRewriter::postRewriteTranscendental(TNode t) {
   return RewriteResponse(REWRITE_DONE, t);
 }
 
-RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){
-  if(atom.getKind() == kind::IS_INTEGER) {
-    return rewriteExtIntegerOp(atom);
-  } else if(atom.getKind() == kind::DIVISIBLE) {
-    if(atom[0].isConst()) {
-      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(bool((atom[0].getConst<Rational>() / atom.getOperator().getConst<Divisible>().k).isIntegral())));
-    }
-    if(atom.getOperator().getConst<Divisible>().k.isOne()) {
-      return RewriteResponse(REWRITE_DONE, NodeManager::currentNM()->mkConst(true));
-    }
-    NodeManager* nm = NodeManager::currentNM();
-    return RewriteResponse(
-        REWRITE_AGAIN,
-        nm->mkNode(kind::EQUAL,
-                   nm->mkNode(kind::INTS_MODULUS_TOTAL,
-                              atom[0],
-                              nm->mkConstInt(Rational(
-                                  atom.getOperator().getConst<Divisible>().k))),
-                   nm->mkConstInt(Rational(0))));
-  }
-
-  // left |><| right
-  TNode left = atom[0];
-  TNode right = atom[1];
-
-  Polynomial pleft = Polynomial::parsePolynomial(left);
-  Polynomial pright = Polynomial::parsePolynomial(right);
-
-  Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
-  Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
-
-  Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
-  Assert(cmp.isNormalForm());
-  return RewriteResponse(REWRITE_DONE, cmp.getNode());
-}
-
-RewriteResponse ArithRewriter::preRewriteAtom(TNode atom){
-  Assert(isAtom(atom));
-
-  NodeManager* currNM = NodeManager::currentNM();
-
-  if(atom.getKind() == kind::EQUAL) {
-    if(atom[0] == atom[1]) {
-      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
-    }
-  }else if(atom.getKind() == kind::GT){
-    Node leq = currNM->mkNode(kind::LEQ, atom[0], atom[1]);
-    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, leq));
-  }else if(atom.getKind() == kind::LT){
-    Node geq = currNM->mkNode(kind::GEQ, atom[0], atom[1]);
-    return RewriteResponse(REWRITE_DONE, currNM->mkNode(kind::NOT, geq));
-  }else if(atom.getKind() == kind::IS_INTEGER){
-    if(atom[0].getType().isInteger()){
-      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
-    }
-  }else if(atom.getKind() == kind::DIVISIBLE){
-    if(atom.getOperator().getConst<Divisible>().k.isOne()){
-      return RewriteResponse(REWRITE_DONE, currNM->mkConst(true));
-    }
-  }
-
-  return RewriteResponse(REWRITE_DONE, atom);
-}
-
-RewriteResponse ArithRewriter::postRewrite(TNode t){
-  if(isTerm(t)){
-    RewriteResponse response = postRewriteTerm(t);
-    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
-    {
-      Polynomial::parsePolynomial(response.d_node);
-    }
-    return response;
-  }else if(isAtom(t)){
-    RewriteResponse response = postRewriteAtom(t);
-    if (Debug.isOn("arith::rewriter") && response.d_status == REWRITE_DONE)
-    {
-      Comparison::parseNormalForm(response.d_node);
-    }
-    return response;
-  }else{
-    Unreachable();
-  }
-}
-
-RewriteResponse ArithRewriter::preRewrite(TNode t){
-  if(isTerm(t)){
-    return preRewriteTerm(t);
-  }else if(isAtom(t)){
-    return preRewriteAtom(t);
-  }else{
-    Unreachable();
-  }
-}
-
 Node ArithRewriter::makeUnaryMinusNode(TNode n){
   NodeManager* nm = NodeManager::currentNM();
   Rational qNegOne(-1);
index 5fcb628b957d22872c5f5610e6dc5bf097dd4941..ad9e171451aee732b91e73a0f5cd6c7783145182 100644 (file)
@@ -43,6 +43,9 @@ class ArithRewriter : public TheoryRewriter
   TrustNode expandDefinition(Node node) override;
 
  private:
+  static RewriteResponse preRewriteAtom(TNode t);
+  static RewriteResponse postRewriteAtom(TNode t);
+
   static Node makeSubtractionNode(TNode l, TNode r);
   static Node makeUnaryMinusNode(TNode n);
 
@@ -71,9 +74,6 @@ class ArithRewriter : public TheoryRewriter
   static RewriteResponse preRewriteTranscendental(TNode t);
   static RewriteResponse postRewriteTranscendental(TNode t);
 
-  static RewriteResponse preRewriteAtom(TNode t);
-  static RewriteResponse postRewriteAtom(TNode t);
-
   static bool isAtom(TNode n);
 
   static inline bool isTerm(TNode n) {