#include "theory/arith/rewriter/addition.h"
#include "theory/arith/rewriter/node_utils.h"
#include "theory/arith/rewriter/ordering.h"
+#include "theory/arith/rewriter/rewrite_atom.h"
#include "theory/theory.h"
#include "util/bitvector.h"
#include "util/divisible.h"
namespace {
-template <typename L, typename R>
-bool evaluateRelation(Kind rel, const L& l, const R& r)
-{
- switch (rel)
- {
- case Kind::LT: return l < r;
- case Kind::LEQ: return l <= r;
- case Kind::EQUAL: return l == r;
- case Kind::GEQ: return l >= r;
- case Kind::GT: return l > r;
- default: Unreachable(); return false;
- }
-}
-
/**
* Check whether the parent has a child that is a constant zero.
* If so, return this child. Otherwise, return std::nullopt.
RewriteResponse ArithRewriter::preRewrite(TNode t)
{
Trace("arith-rewriter") << "preRewrite(" << t << ")" << std::endl;
- if (isAtom(t))
+ if (rewriter::isAtom(t))
{
auto res = preRewriteAtom(t);
Trace("arith-rewriter")
RewriteResponse ArithRewriter::postRewrite(TNode t)
{
Trace("arith-rewriter") << "postRewrite(" << t << ")" << std::endl;
- if (isAtom(t))
+ if (rewriter::isAtom(t))
{
auto res = postRewriteAtom(t);
Trace("arith-rewriter")
RewriteResponse ArithRewriter::preRewriteAtom(TNode atom)
{
- Assert(isAtom(atom));
-
- NodeManager* nm = NodeManager::currentNM();
+ Assert(rewriter::isAtom(atom));
- if (isRelationOperator(atom.getKind()) && atom[0] == atom[1])
+ if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response)
{
- switch (atom.getKind())
- {
- case Kind::LT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
- case Kind::LEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
- case Kind::EQUAL: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
- case Kind::GEQ: return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
- case Kind::GT: return RewriteResponse(REWRITE_DONE, nm->mkConst(false));
- default:;
- }
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
}
switch (atom.getKind())
{
case Kind::GT:
- return RewriteResponse(REWRITE_DONE,
- nm->mkNode(kind::LEQ, atom[0], atom[1]).notNode());
+ return RewriteResponse(
+ REWRITE_DONE,
+ rewriter::buildRelation(kind::LEQ, atom[0], atom[1], true));
case Kind::LT:
- return RewriteResponse(REWRITE_DONE,
- nm->mkNode(kind::GEQ, atom[0], atom[1]).notNode());
+ return RewriteResponse(
+ REWRITE_DONE,
+ rewriter::buildRelation(kind::GEQ, atom[0], atom[1], true));
case Kind::IS_INTEGER:
if (atom[0].getType().isInteger())
{
- return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true));
}
break;
case Kind::DIVISIBLE:
if (atom.getOperator().getConst<Divisible>().k.isOne())
{
- return RewriteResponse(REWRITE_DONE, nm->mkConst(true));
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true));
}
break;
default:;
RewriteResponse ArithRewriter::postRewriteAtom(TNode atom)
{
- Assert(isAtom(atom));
+ Assert(rewriter::isAtom(atom));
+
if (atom.getKind() == kind::IS_INTEGER)
{
return rewriteExtIntegerOp(atom);
}
else if (atom.getKind() == kind::DIVISIBLE)
{
+ const Integer& k = atom.getOperator().getConst<Divisible>().k;
if (atom[0].isConst())
{
+ const Rational& num = atom[0].getConst<Rational>();
return RewriteResponse(REWRITE_DONE,
- NodeManager::currentNM()->mkConst(bool(
- (atom[0].getConst<Rational>()
- / atom.getOperator().getConst<Divisible>().k)
- .isIntegral())));
+ rewriter::mkConst((num / k).isIntegral()));
}
- if (atom.getOperator().getConst<Divisible>().k.isOne())
+ if (k.isOne())
{
- return RewriteResponse(REWRITE_DONE,
- NodeManager::currentNM()->mkConst(true));
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(true));
}
NodeManager* nm = NodeManager::currentNM();
return RewriteResponse(
REWRITE_AGAIN,
- nm->mkNode(kind::EQUAL,
- nm->mkNode(kind::INTS_MODULUS_TOTAL,
- atom[0],
- nm->mkConstInt(Rational(
- atom.getOperator().getConst<Divisible>().k))),
- nm->mkConstInt(Rational(0))));
+ nm->mkNode(
+ kind::EQUAL,
+ nm->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], rewriter::mkConst(k)),
+ rewriter::mkConst(Integer(0))));
+ }
+
+ if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response)
+ {
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
}
// left |><| right
+ Kind kind = atom.getKind();
TNode left = atom[0];
TNode right = atom[1];
+ Assert(isRelationOperator(kind));
- auto* nm = NodeManager::currentNM();
- if (left.isConst())
+ if (auto response = rewriter::tryEvaluateRelation(kind, left, right);
+ response)
{
- const Rational& l = left.getConst<Rational>();
- if (right.isConst())
- {
- const Rational& r = right.getConst<Rational>();
- return RewriteResponse(
- REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
- }
- else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
- {
- const RealAlgebraicNumber& r =
- right.getOperator().getConst<RealAlgebraicNumber>();
- return RewriteResponse(
- REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
- }
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
}
- else if (left.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+
+ bool negate = false;
+
+ switch (atom.getKind())
{
- const RealAlgebraicNumber& l =
- left.getOperator().getConst<RealAlgebraicNumber>();
- if (right.isConst())
+ case Kind::LEQ:
+ kind = Kind::GEQ;
+ negate = true;
+ break;
+ case Kind::LT:
+ kind = Kind::GT;
+ negate = true;
+ break;
+ default: break;
+ }
+
+ rewriter::Sum sum;
+ rewriter::addToSum(sum, left, negate);
+ rewriter::addToSum(sum, right, !negate);
+
+ // Now we have (rsum <kind> 0)
+ if (rewriter::isIntegral(atom))
+ {
+ if (kind == Kind::EQUAL)
{
- const Rational& r = right.getConst<Rational>();
- return RewriteResponse(
- REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::buildIntegerEquality(std::move(sum)));
}
- else if (right.getKind() == Kind::REAL_ALGEBRAIC_NUMBER)
+ return RewriteResponse(
+ REWRITE_DONE, rewriter::buildIntegerInequality(std::move(sum), kind));
+ }
+ else
+ {
+ if (kind == Kind::EQUAL)
{
- const RealAlgebraicNumber& r =
- right.getOperator().getConst<RealAlgebraicNumber>();
- return RewriteResponse(
- REWRITE_DONE, nm->mkConst(evaluateRelation(atom.getKind(), l, r)));
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::buildRealEquality(std::move(sum)));
}
+ return RewriteResponse(REWRITE_DONE,
+ rewriter::buildRealInequality(std::move(sum), kind));
}
-
- Polynomial pleft = Polynomial::parsePolynomial(left);
- Polynomial pright = Polynomial::parsePolynomial(right);
-
- Debug("arith::rewriter") << "pleft " << pleft.getNode() << std::endl;
- Debug("arith::rewriter") << "pright " << pright.getNode() << std::endl;
-
- Comparison cmp = Comparison::mkComparison(atom.getKind(), pleft, pright);
- Assert(cmp.isNormalForm());
- return RewriteResponse(REWRITE_DONE, cmp.getNode());
-}
-
-bool ArithRewriter::isAtom(TNode n) {
- Kind k = n.getKind();
- return arith::isRelationOperator(k) || k == kind::IS_INTEGER
- || k == kind::DIVISIBLE;
}
RewriteResponse ArithRewriter::rewriteNeg(TNode t, bool pre)