{
Assert(rewriter::isAtom(atom));
- if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response)
+ Kind kind = atom.getKind();
+ if (atom.getNumChildren() == 2)
{
- return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
+ if (auto response =
+ rewriter::tryEvaluateRelationReflexive(kind, atom[0], atom[1]);
+ response)
+ {
+ return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
+ }
}
- switch (atom.getKind())
+ switch (kind)
{
case Kind::GT:
return RewriteResponse(
nm->mkNode(kind::INTS_MODULUS_TOTAL, atom[0], rewriter::mkConst(k)),
rewriter::mkConst(Integer(0))));
}
+ // left |><| right
+ Kind kind = atom.getKind();
+ Node left = removeToReal(atom[0]);
+ Node right = removeToReal(atom[1]);
- if (auto response = rewriter::tryEvaluateRelationReflexive(atom); response)
+ if (auto response = rewriter::tryEvaluateRelationReflexive(kind, left, right);
+ response)
{
return RewriteResponse(REWRITE_DONE, rewriter::mkConst(*response));
}
- // left |><| right
- Kind kind = atom.getKind();
- TNode left = atom[0];
- TNode right = atom[1];
Assert(isRelationOperator(kind));
if (auto response = rewriter::tryEvaluateRelation(kind, left, right);
case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, true);
case kind::ABS: return rewriteAbs(t);
case kind::IS_INTEGER:
- case kind::TO_INTEGER: return RewriteResponse(REWRITE_DONE, t);
+ case kind::TO_INTEGER:
case kind::TO_REAL:
- case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
- case kind::POW: return RewriteResponse(REWRITE_DONE, t);
+ case kind::CAST_TO_REAL:
+ case kind::POW:
case kind::PI: return RewriteResponse(REWRITE_DONE, t);
default: Unhandled() << k;
}
case kind::INTS_MODULUS_TOTAL: return rewriteIntsDivModTotal(t, false);
case kind::ABS: return rewriteAbs(t);
case kind::TO_REAL:
- case kind::CAST_TO_REAL: return RewriteResponse(REWRITE_DONE, t[0]);
+ case kind::CAST_TO_REAL: return rewriteToReal(t);
case kind::TO_INTEGER: return rewriteExtIntegerOp(t);
case kind::POW:
{
RewriteResponse ArithRewriter::rewriteRAN(TNode t)
{
Assert(rewriter::isRAN(t));
+ Assert(t.getType().isReal());
const RealAlgebraicNumber& r = rewriter::getRAN(t);
if (r.isRational())
{
{
Assert(t.getKind() == kind::NEG);
+ NodeManager* nm = NodeManager::currentNM();
if (t[0].isConst())
{
Rational neg = -(t[0].getConst<Rational>());
- return RewriteResponse(REWRITE_DONE, rewriter::mkConst(neg));
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkConstRealOrInt(t.getType(), neg));
}
if (rewriter::isRAN(t[0]))
{
rewriter::mkConst(-rewriter::getRAN(t[0])));
}
- auto* nm = NodeManager::currentNM();
- Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[0]);
+ Node noUminus = nm->mkNode(kind::MULT, rewriter::mkConst(Rational(-1)), t[0]);
if (pre)
{
return RewriteResponse(REWRITE_DONE, noUminus);
Assert(t.getKind() == kind::SUB);
Assert(t.getNumChildren() == 2);
+ NodeManager* nm = NodeManager::currentNM();
if (t[0] == t[1])
{
- return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ return RewriteResponse(REWRITE_DONE,
+ nm->mkConstRealOrInt(t.getType(), Rational(0)));
}
- auto* nm = NodeManager::currentNM();
return RewriteResponse(
REWRITE_AGAIN_FULL,
nm->mkNode(Kind::ADD,
t[0],
- nm->mkNode(kind::MULT, rewriter::mkConst(Integer(-1)), t[1])));
+ nm->mkNode(kind::MULT,
+ nm->mkConstRealOrInt(t[1].getType(), Rational(-1)),
+ t[1])));
}
RewriteResponse ArithRewriter::preRewritePlus(TNode t)
rewriter::Sum sum;
for (const auto& child : children)
{
- rewriter::addToSum(sum, child);
+ rewriter::addToSum(sum, removeToReal(child));
}
- return RewriteResponse(REWRITE_DONE, rewriter::collectSum(sum));
+ Node retSum = rewriter::collectSum(sum);
+ retSum = maybeEnsureReal(t.getType(), retSum);
+ return RewriteResponse(REWRITE_DONE, retSum);
}
RewriteResponse ArithRewriter::preRewriteMult(TNode node)
if (auto res = rewriter::getZeroChild(node); res)
{
- return RewriteResponse(REWRITE_DONE, *res);
+ return RewriteResponse(REWRITE_DONE, maybeEnsureReal(node.getType(), *res));
}
return RewriteResponse(REWRITE_DONE, node);
}
if (auto res = rewriter::getZeroChild(children); res)
{
- return RewriteResponse(REWRITE_DONE, *res);
+ return RewriteResponse(REWRITE_DONE, maybeEnsureReal(t.getType(), *res));
}
+ // remove TO_REAL
+ for (TNode& tc : children)
+ {
+ tc = removeToReal(tc);
+ }
+
+ Node ret;
// Distribute over addition
if (std::any_of(children.begin(), children.end(), [](TNode child) {
return child.getKind() == Kind::ADD;
}))
{
- return RewriteResponse(REWRITE_DONE,
- rewriter::distributeMultiplication(children));
+ ret = rewriter::distributeMultiplication(children);
}
-
- RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
- std::vector<Node> leafs;
-
- for (const auto& child : children)
+ else
{
- if (child.isConst())
+ RealAlgebraicNumber ran = RealAlgebraicNumber(Integer(1));
+ std::vector<Node> leafs;
+
+ for (const auto& child : children)
{
- if (child.getConst<Rational>().isZero())
+ if (child.isConst())
{
- return RewriteResponse(REWRITE_DONE, child);
+ if (child.getConst<Rational>().isZero())
+ {
+ return RewriteResponse(REWRITE_DONE,
+ maybeEnsureReal(t.getType(), child));
+ }
+ ran *= child.getConst<Rational>();
+ }
+ else if (rewriter::isRAN(child))
+ {
+ ran *= rewriter::getRAN(child);
+ }
+ else
+ {
+ leafs.emplace_back(child);
}
- ran *= child.getConst<Rational>();
- }
- else if (rewriter::isRAN(child))
- {
- ran *= rewriter::getRAN(child);
- }
- else
- {
- leafs.emplace_back(child);
}
+ ret = rewriter::mkMultTerm(ran, std::move(leafs));
}
-
- return RewriteResponse(REWRITE_DONE,
- rewriter::mkMultTerm(ran, std::move(leafs)));
+ ret = maybeEnsureReal(t.getType(), ret);
+ return RewriteResponse(REWRITE_DONE, ret);
}
RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre)
Assert(t.getKind() == kind::DIVISION_TOTAL || t.getKind() == kind::DIVISION);
Assert(t.getNumChildren() == 2);
- Node left = t[0];
- Node right = t[1];
+ Node left = removeToReal(t[0]);
+ Node right = removeToReal(t[1]);
+ NodeManager* nm = NodeManager::currentNM();
if (right.isConst())
{
- NodeManager* nm = NodeManager::currentNM();
const Rational& den = right.getConst<Rational>();
if (den.isZero())
{
if (t.getKind() == kind::DIVISION_TOTAL)
{
- return RewriteResponse(REWRITE_DONE, nm->mkConstReal(0));
+ Node ret = nm->mkConstReal(0);
+ return RewriteResponse(REWRITE_DONE, ret);
}
else
{
- // This is unsupported, but this is not a good place to complain
return RewriteResponse(REWRITE_DONE, t);
}
}
}
Node result = nm->mkConstReal(den.inverse());
- Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+ Node mult =
+ ensureReal(NodeManager::currentNM()->mkNode(kind::MULT, left, result));
if (pre)
{
return RewriteResponse(REWRITE_DONE, mult);
}
- else
- {
- return RewriteResponse(REWRITE_AGAIN, mult);
- }
+ // requires again full since ensureReal may have added a to_real
+ return RewriteResponse(REWRITE_AGAIN_FULL, mult);
}
if (rewriter::isRAN(right))
{
const RealAlgebraicNumber& den = rewriter::getRAN(right);
-
+ // mkConst is applied to RAN in this block, which are always Real
if (left.isConst())
{
return RewriteResponse(
- REWRITE_DONE, rewriter::mkConst(left.getConst<Rational>() / den));
+ REWRITE_DONE,
+ ensureReal(rewriter::mkConst(left.getConst<Rational>() / den)));
}
if (rewriter::isRAN(left))
{
- return RewriteResponse(REWRITE_DONE,
- rewriter::mkConst(rewriter::getRAN(left) / den));
+ return RewriteResponse(
+ REWRITE_DONE,
+ ensureReal(rewriter::mkConst(rewriter::getRAN(left) / den)));
}
Node result = rewriter::mkConst(inverse(den));
- Node mult = NodeManager::currentNM()->mkNode(kind::MULT, left, result);
+ Node mult =
+ ensureReal(NodeManager::currentNM()->mkNode(kind::MULT, left, result));
if (pre)
{
return RewriteResponse(REWRITE_DONE, mult);
}
- else
- {
- return RewriteResponse(REWRITE_AGAIN, mult);
- }
+ // requires again full since ensureReal may have added a to_real
+ return RewriteResponse(REWRITE_AGAIN_FULL, mult);
+ }
+ Node ret = nm->mkNode(t.getKind(), left, right);
+ return RewriteResponse(REWRITE_DONE, ret);
+}
+
+RewriteResponse ArithRewriter::rewriteToReal(TNode t)
+{
+ Assert(t.getKind() == kind::CAST_TO_REAL || t.getKind() == kind::TO_REAL);
+ if (!t[0].getType().isInteger())
+ {
+ // if it is already real type, then just return the argument
+ return RewriteResponse(REWRITE_DONE, t[0]);
+ }
+ NodeManager* nm = NodeManager::currentNM();
+ if (t[0].isConst())
+ {
+ // If the argument is constant, return a real constant.
+ // !!!! Note that this does not preserve the type of t, since rat is
+ // an integral rational. This will be corrected when the type rule for
+ // CONST_RATIONAL is changed to always return Real.
+ const Rational& rat = t[0].getConst<Rational>();
+ return RewriteResponse(REWRITE_DONE, nm->mkConstReal(rat));
}
+ // CAST_TO_REAL is our way of marking integral constants coming from the
+ // user as Real. It should only be applied to constants, which is handled
+ // above.
+ Assert(t.getKind() != kind::CAST_TO_REAL);
return RewriteResponse(REWRITE_DONE, t);
}
Node ret = isPred ? nm->mkConst(false) : nm->mkConstReal(Rational(3));
return returnRewrite(t, ret, Rewrite::INT_EXT_PI);
}
+ else if (t[0].getKind() == kind::TO_REAL)
+ {
+ Node ret = nm->mkNode(t.getKind(), t[0][0]);
+ return returnRewrite(t, ret, Rewrite::INT_EXT_TO_REAL);
+ }
return RewriteResponse(REWRITE_DONE, t);
}
const Rational& rat = t[0].getConst<Rational>();
if (rat.sgn() == 0)
{
- return RewriteResponse(REWRITE_DONE, rewriter::mkConst(Integer(0)));
+ return RewriteResponse(REWRITE_DONE, nm->mkConstReal(Rational(0)));
}
else if (rat.sgn() == -1)
{
- Node ret = nm->mkNode(
- kind::NEG, nm->mkNode(kind::SINE, rewriter::mkConst(-rat)));
+ Node ret = nm->mkNode(kind::NEG,
+ nm->mkNode(kind::SINE, nm->mkConstReal(-rat)));
return RewriteResponse(REWRITE_AGAIN_FULL, ret);
}
}
{
new_arg = nm->mkNode(kind::ADD, new_arg, rem);
}
+ new_arg = ensureReal(new_arg);
// sin( 2*n*PI + x ) = sin( x )
return RewriteResponse(REWRITE_AGAIN_FULL,
nm->mkNode(kind::SINE, new_arg));
if (r_abs.getDenominator() == two)
{
Assert(r_abs.getNumerator() == one);
- return RewriteResponse(REWRITE_DONE,
- nm->mkConstReal(Rational(r.sgn())));
+ return RewriteResponse(
+ REWRITE_DONE, ensureReal(nm->mkConstReal(Rational(r.sgn()))));
}
else if (r_abs.getDenominator() == six)
{
return ret;
}
+TNode ArithRewriter::removeToReal(TNode t)
+{
+ return t.getKind() == kind::TO_REAL ? t[0] : t;
+}
+
+Node ArithRewriter::maybeEnsureReal(TypeNode tn, TNode t)
+{
+ // if we require being a real
+ if (!tn.isInteger())
+ {
+ // ensure that t has type real
+ Assert(tn.isReal());
+ return ensureReal(t);
+ }
+ return t;
+}
+
+Node ArithRewriter::ensureReal(TNode t)
+{
+ if (t.getType().isInteger())
+ {
+ if (t.isConst())
+ {
+ // short-circuit
+ Node ret = NodeManager::currentNM()->mkConstReal(t.getConst<Rational>());
+ Assert(ret.getType().isReal());
+ return ret;
+ }
+ Trace("arith-rewriter-debug") << "maybeEnsureReal: " << t << std::endl;
+ return NodeManager::currentNM()->mkNode(kind::TO_REAL, t);
+ }
+ return t;
+}
+
RewriteResponse ArithRewriter::returnRewrite(TNode t, Node ret, Rewrite r)
{
Trace("arith-rewriter") << "ArithRewriter : " << t << " == " << ret << " by "