From 341794b1cbd5693010c78b9f5bfe232ee90404b0 Mon Sep 17 00:00:00 2001 From: Tim King Date: Sun, 11 Nov 2012 00:28:05 +0000 Subject: [PATCH] Fixes for the arithmetic normal form and rewriter to handle arbitrary constants for total functions. --- src/theory/arith/arith_rewriter.cpp | 232 +++++++++++++++++----------- src/theory/arith/arith_rewriter.h | 2 + src/theory/arith/normal_form.cpp | 6 +- 3 files changed, 143 insertions(+), 97 deletions(-) diff --git a/src/theory/arith/arith_rewriter.cpp b/src/theory/arith/arith_rewriter.cpp index b6275ba24..689f231e6 100644 --- a/src/theory/arith/arith_rewriter.cpp +++ b/src/theory/arith/arith_rewriter.cpp @@ -80,58 +80,28 @@ RewriteResponse ArithRewriter::preRewriteTerm(TNode t){ return rewriteConstant(t); }else if(t.isVar()){ return rewriteVariable(t); - }else if(t.getKind() == kind::MINUS){ - return rewriteMinus(t, true); - }else if(t.getKind() == kind::UMINUS){ - return rewriteUMinus(t, true); - }else if(t.getKind() == kind::DIVISION){ - return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten - }else if(t.getKind() == kind::DIVISION_TOTAL){ - if(t[1].getKind()== kind::CONST_RATIONAL && - t[1].getConst().isZero()){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); - }else{ - return RewriteResponse(REWRITE_DONE, t); // wait until t[1] is rewritten - } - }else if(t.getKind() == kind::PLUS){ - return preRewritePlus(t); - }else if(t.getKind() == kind::MULT){ - return preRewriteMult(t); - }else if(t.getKind() == kind::INTS_DIVISION){ - Rational intOne(1); - if(t[1].getKind()== kind::CONST_RATIONAL && - t[1].getConst().isOne()){ - return RewriteResponse(REWRITE_AGAIN, t[0]); - }else{ - return RewriteResponse(REWRITE_DONE, t); - } - }else if(t.getKind() == kind::INTS_DIVISION_TOTAL){ - if(t[1].getKind()== kind::CONST_RATIONAL){ - Rational intOne(1), intZero(0); - if(t[1].getConst().isOne()){ - return RewriteResponse(REWRITE_AGAIN, t[0]); - } else if(t[1].getConst().isZero()){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); - } - } - return RewriteResponse(REWRITE_DONE, t); - }else if(t.getKind() == kind::INTS_MODULUS){ - Rational intOne(1); - if(t[1].getKind()== kind::CONST_RATIONAL && - t[1].getConst().isOne()){ - return RewriteResponse(REWRITE_AGAIN, mkRationalNode(0)); - }else{ - return RewriteResponse(REWRITE_DONE, t); - } - }else if(t.getKind() == kind::INTS_MODULUS_TOTAL){ - if(t[1].getKind()== kind::CONST_RATIONAL){ - if(t[1].getConst().isOne() || t[1].getConst().isZero()){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); - } - } - return RewriteResponse(REWRITE_DONE, t); }else{ - Unreachable(); + switch(t.getKind()){ + case kind::MINUS: + return rewriteMinus(t, true); + case kind::UMINUS: + return rewriteUMinus(t, true); + case kind::DIVISION: + return rewriteDiv(t,true); + case kind::DIVISION_TOTAL: + return rewriteDivTotal(t,true); + case kind::PLUS: + return preRewritePlus(t); + case kind::MULT: + return preRewriteMult(t); + //case kind::INTS_DIVISION: + //case kind::INTS_MODULUS: + case kind::INTS_DIVISION_TOTAL: + case kind::INTS_MODULUS_TOTAL: + return rewriteIntsDivModTotal(t,true); + default: + Unreachable(); + } } } RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ @@ -139,33 +109,32 @@ RewriteResponse ArithRewriter::postRewriteTerm(TNode t){ return rewriteConstant(t); }else if(t.isVar()){ return rewriteVariable(t); - }else if(t.getKind() == kind::MINUS){ - return rewriteMinus(t, false); - }else if(t.getKind() == kind::UMINUS){ - return rewriteUMinus(t, false); - }else if(t.getKind() == kind::DIVISION || - t.getKind() == kind::DIVISION_TOTAL){ - return rewriteDiv(t, false); - }else if(t.getKind() == kind::PLUS){ - return postRewritePlus(t); - }else if(t.getKind() == kind::MULT){ - return postRewriteMult(t); - }else if(t.getKind() == kind::INTS_DIVISION || - t.getKind() == kind::INTS_MODULUS){ - return RewriteResponse(REWRITE_DONE, t); - }else if(t.getKind() == kind::INTS_DIVISION_TOTAL || - t.getKind() == kind::INTS_MODULUS_TOTAL){ - if(t[1].getKind() == kind::CONST_RATIONAL && - t[1].getConst().isZero()){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); - }else{ - return RewriteResponse(REWRITE_DONE, t); - } }else{ - Unreachable(); + switch(t.getKind()){ + case kind::MINUS: + return rewriteMinus(t, false); + case kind::UMINUS: + return rewriteUMinus(t, false); + case kind::DIVISION: + return rewriteDiv(t, false); + case kind::DIVISION_TOTAL: + return rewriteDivTotal(t, false); + case kind::PLUS: + return postRewritePlus(t); + case kind::MULT: + return postRewriteMult(t); + //case kind::INTS_DIVISION: + //case kind::INTS_MODULUS: + case kind::INTS_DIVISION_TOTAL: + case kind::INTS_MODULUS_TOTAL: + return rewriteIntsDivModTotal(t, false); + default: + Unreachable(); + } } } + RewriteResponse ArithRewriter::preRewriteMult(TNode t){ Assert(t.getKind()== kind::MULT); @@ -217,19 +186,6 @@ RewriteResponse ArithRewriter::postRewriteMult(TNode t){ return RewriteResponse(REWRITE_DONE, res.getNode()); } -// RewriteResponse ArithRewriter::postRewriteAtomConstantRHS(TNode t){ -// TNode left = t[0]; -// TNode right = t[1]; - -// Polynomial pLeft = Polynomial::parsePolynomial(left); - - -// Comparison cmp = Comparison::mkComparison(t.getKind(), Polynomial::parsePolynomial(left), Constant(right)); - -// Assert(cmp.isNormalForm()); -// return RewriteResponse(REWRITE_DONE, cmp.getNode()); -// } - RewriteResponse ArithRewriter::postRewriteAtom(TNode atom){ // left |><| right TNode left = atom[0]; @@ -304,9 +260,52 @@ Node ArithRewriter::makeSubtractionNode(TNode l, TNode r){ return diff; } - RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ - Assert(t.getKind()== kind::DIVISION || t.getKind() == kind::DIVISION_TOTAL); + Assert(t.getKind()== kind::DIVISION); + + Node left = t[0]; + Node right = t[1]; + + if(right.getKind() == kind::CONST_RATIONAL && + left.getKind() != kind::CONST_RATIONAL){ + + const Rational& den = right.getConst(); + + Assert(!den.isZero()); + + Rational div = den.inverse(); + Node result = mkRationalNode(div); + Node mult = NodeManager::currentNM()->mkNode(kind::MULT,left,result); + if(pre){ + return RewriteResponse(REWRITE_DONE, mult); + }else{ + return RewriteResponse(REWRITE_AGAIN, mult); + } + } + + if(pre){ + if(right.getKind() != kind::CONST_RATIONAL || + left.getKind() != kind::CONST_RATIONAL){ + return RewriteResponse(REWRITE_DONE, t); + } + } + + Assert(right.getKind() == kind::CONST_RATIONAL); + Assert(left.getKind() == kind::CONST_RATIONAL); + + const Rational& den = right.getConst(); + + Assert(!den.isZero()); + + const Rational& num = left.getConst(); + Rational div = num / den; + Node result = mkRationalNode(div); + return RewriteResponse(REWRITE_DONE, result); +} + +RewriteResponse ArithRewriter::rewriteDivTotal(TNode t, bool pre){ + Assert(t.getKind() == kind::DIVISION_TOTAL); + Node left = t[0]; Node right = t[1]; @@ -314,14 +313,17 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ const Rational& den = right.getConst(); if(den.isZero()){ - if(t.getKind() == kind::DIVISION_TOTAL){ - return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); - }else{ - return RewriteResponse(REWRITE_DONE, t); - } + return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); } Assert(den != Rational(0)); + if(left.getKind() == kind::CONST_RATIONAL){ + const Rational& num = left.getConst(); + Rational div = num / den; + Node result = mkRationalNode(div); + return RewriteResponse(REWRITE_DONE, result); + } + Rational div = den.inverse(); Node result = mkRationalNode(div); @@ -337,6 +339,48 @@ RewriteResponse ArithRewriter::rewriteDiv(TNode t, bool pre){ } } +RewriteResponse ArithRewriter::rewriteIntsDivModTotal(TNode t, bool pre){ + Kind k = t.getKind(); + // Assert(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL || + // k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL); + + //Leaving the function as before (INTS_MODULUS can be handled), + // but restricting its use here + Assert(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL); + TNode n = t[0], d = t[1]; + bool dIsConstant = d.getKind() == kind::CONST_RATIONAL; + if(dIsConstant && d.getConst().isZero()){ + if(k == kind::INTS_MODULUS_TOTAL || k == kind::INTS_DIVISION_TOTAL){ + return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); + }else{ + // Do nothing for k == INTS_MODULUS + return RewriteResponse(REWRITE_DONE, t); + } + }else if(dIsConstant && d.getConst().isOne()){ + if(k == kind::INTS_MODULUS || k == kind::INTS_MODULUS_TOTAL){ + return RewriteResponse(REWRITE_DONE, mkRationalNode(0)); + }else{ + Assert(k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL); + return RewriteResponse(REWRITE_AGAIN, n); + } + }else if(dIsConstant && n.getKind() == kind::CONST_RATIONAL){ + Assert(d.getConst().isIntegral()); + Assert(n.getConst().isIntegral()); + Assert(!d.getConst().isZero()); + Integer di = d.getConst().getNumerator(); + Integer ni = n.getConst().getNumerator(); + + bool isDiv = (k == kind::INTS_DIVISION || k == kind::INTS_DIVISION_TOTAL); + + Integer result = isDiv ? ni.floorDivideQuotient(di) : ni.floorDivideRemainder(di); + + Node resultNode = mkRationalNode(Rational(result)); + return RewriteResponse(REWRITE_DONE, resultNode); + }else{ + return RewriteResponse(REWRITE_DONE, t); + } +} + }/* CVC4::theory::arith namespace */ }/* CVC4::theory namespace */ }/* CVC4 namespace */ diff --git a/src/theory/arith/arith_rewriter.h b/src/theory/arith/arith_rewriter.h index 986ff369d..10e255535 100644 --- a/src/theory/arith/arith_rewriter.h +++ b/src/theory/arith/arith_rewriter.h @@ -50,6 +50,8 @@ private: static RewriteResponse rewriteMinus(TNode t, bool pre); static RewriteResponse rewriteUMinus(TNode t, bool pre); static RewriteResponse rewriteDiv(TNode t, bool pre); + static RewriteResponse rewriteDivTotal(TNode t, bool pre); + static RewriteResponse rewriteIntsDivModTotal(TNode t, bool pre); static RewriteResponse preRewritePlus(TNode t); static RewriteResponse postRewritePlus(TNode t); diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp index 9bd0a3b6c..c863bf3c5 100644 --- a/src/theory/arith/normal_form.cpp +++ b/src/theory/arith/normal_form.cpp @@ -27,9 +27,9 @@ namespace arith { bool Variable::isDivMember(Node n){ switch(n.getKind()){ - case kind::DIVISION: - case kind::INTS_DIVISION: - case kind::INTS_MODULUS: + //case kind::DIVISION: + //case kind::INTS_DIVISION: + //case kind::INTS_MODULUS: case kind::DIVISION_TOTAL: case kind::INTS_DIVISION_TOTAL: case kind::INTS_MODULUS_TOTAL: -- 2.30.2