14c52526729b53c0ff4b48ce344ca3c8d9b1b954
[cvc5.git] / src / theory / arith / normal_form.cpp
1 /********************* */
2 /*! \file normal_form.cpp
3 ** \verbatim
4 ** Original author: taking
5 ** Major contributors: none
6 ** Minor contributors (to current version): none
7 ** This file is part of the CVC4 prototype.
8 ** Copyright (c) 2009, 2010 The Analysis of Computer Systems Group (ACSys)
9 ** Courant Institute of Mathematical Sciences
10 ** New York University
11 ** See the file COPYING in the top-level source directory for licensing
12 ** information.\endverbatim
13 **
14 ** \brief [[ Add one-line brief description here ]]
15 **
16 ** [[ Add lengthier description here ]]
17 ** \todo document this file
18 **/
19
20 #include "theory/arith/normal_form.h"
21 #include <list>
22
23 using namespace std;
24 using namespace CVC4;
25 using namespace CVC4::theory;
26 using namespace CVC4::theory::arith;
27
28 bool VarList::isSorted(iterator start, iterator end) {
29 return __gnu_cxx::is_sorted(start, end);
30 }
31
32 bool VarList::isMember(Node n) {
33 if(n.getNumChildren() == 0) {
34 return Variable::isMember(n);
35 } else if(n.getKind() == kind::MULT) {
36 Node::iterator curr = n.begin(), end = n.end();
37 Node prev = *curr;
38 if(!Variable::isMember(prev)) return false;
39
40 while( (++curr) != end) {
41 if(!Variable::isMember(*curr)) return false;
42 if(!(prev <= *curr)) return false;
43 prev = *curr;
44 }
45 return true;
46 } else {
47 return false;
48 }
49 }
50 int VarList::cmp(const VarList& vl) const {
51 int dif = this->size() - vl.size();
52 if (dif == 0) {
53 return this->getNode().getId() - vl.getNode().getId();
54 } else if(dif < 0) {
55 return -1;
56 } else {
57 return 1;
58 }
59 }
60
61 VarList VarList::parseVarList(Node n) {
62 if(n.getNumChildren() == 0) {
63 return VarList(Variable(n));
64 } else {
65 Assert(n.getKind() == kind::MULT);
66 for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) {
67 Assert(Variable::isMember(*i));
68 }
69 return VarList(n);
70 }
71 }
72
73 VarList VarList::operator*(const VarList& vl) const {
74 if(this->empty()) {
75 return vl;
76 } else if(vl.empty()) {
77 return *this;
78 } else {
79 vector<Node> result;
80 back_insert_iterator< vector<Node> > bii(result);
81
82 Node::iterator
83 thisBegin = this->backingNode.begin(),
84 thisEnd = this->backingNode.end(),
85 v1Begin = vl.backingNode.begin(),
86 v1End = vl.backingNode.end();
87
88 merge(thisBegin, thisEnd, v1Begin, v1End, bii);
89 Assert(result.size() >= 2);
90 Node mult = NodeManager::currentNM()->mkNode(kind::MULT, result);
91 return VarList::parseVarList(mult);
92 }
93 }
94
95 Monomial Monomial::mkMonomial(const Constant& c, const VarList& vl) {
96 if(c.isZero() || vl.empty() ) {
97 return Monomial(c);
98 } else if(c.isOne()) {
99 return Monomial(vl);
100 } else {
101 return Monomial(c, vl);
102 }
103 }
104 Monomial Monomial::parseMonomial(Node n) {
105 if(n.getKind() == kind::CONST_RATIONAL) {
106 return Monomial(Constant(n));
107 } else if(multStructured(n)) {
108 return Monomial::mkMonomial(Constant(n[0]),VarList::parseVarList(n[1]));
109 } else {
110 return Monomial(VarList::parseVarList(n));
111 }
112 }
113
114 Monomial Monomial::operator*(const Monomial& mono) const {
115 Constant newConstant = this->getConstant() * mono.getConstant();
116 VarList newVL = this->getVarList() * mono.getVarList();
117
118 return Monomial::mkMonomial(newConstant, newVL);
119 }
120
121 vector<Monomial> Monomial::sumLikeTerms(const vector<Monomial> & monos) {
122 Assert(isSorted(monos));
123
124 Debug("blah") << "start sumLikeTerms" << std::endl;
125 printList(monos);
126 vector<Monomial> outMonomials;
127 typedef vector<Monomial>::const_iterator iterator;
128 for(iterator rangeIter = monos.begin(), end=monos.end(); rangeIter != end;) {
129 Rational constant = (*rangeIter).getConstant().getValue();
130 VarList varList = (*rangeIter).getVarList();
131 ++rangeIter;
132 while(rangeIter != end && varList == (*rangeIter).getVarList()) {
133 constant += (*rangeIter).getConstant().getValue();
134 ++rangeIter;
135 }
136 if(constant != 0) {
137 Constant asConstant = Constant::mkConstant(constant);
138 Monomial nonZero = Monomial::mkMonomial(asConstant, varList);
139 outMonomials.push_back(nonZero);
140 }
141 }
142 Debug("blah") << "outmonomials" << std::endl;
143 printList(monos);
144 Debug("blah") << "end sumLikeTerms" << std::endl;
145
146 Assert(isStrictlySorted(outMonomials));
147 return outMonomials;
148 }
149
150 void Monomial::printList(const std::vector<Monomial>& monos) {
151 typedef std::vector<Monomial>::const_iterator iterator;
152 for(iterator i = monos.begin(), end = monos.end(); i != end; ++i) {
153 Debug("blah") << ((*i).getNode()) << std::endl;
154 }
155 }
156
157 Polynomial Polynomial::operator+(const Polynomial& vl) const {
158 this->printList();
159 vl.printList();
160
161 std::vector<Monomial> sortedMonos;
162 std::back_insert_iterator<std::vector<Monomial> > bii(sortedMonos);
163 std::merge(begin(), end(), vl.begin(), vl.end(), bii);
164
165 std::vector<Monomial> combined = Monomial::sumLikeTerms(sortedMonos);
166
167 Polynomial result = mkPolynomial(combined);
168 result.printList();
169 return result;
170 }
171
172 Polynomial Polynomial::operator*(const Monomial& mono) const {
173 if(mono.isZero()) {
174 return Polynomial(mono); //Don't multiply by zero
175 } else {
176 std::vector<Monomial> newMonos;
177 for(iterator i = this->begin(), end = this->end(); i != end; ++i) {
178 newMonos.push_back(mono * (*i));
179 }
180
181 // We may need to sort newMonos.
182 // Suppose this = (+ x y), mono = x, (* x y).getId() < (* x x).getId()
183 // newMonos = <(* x x), (* x y)> after this loop.
184 // This is not sorted according to the current VarList order.
185 std::sort(newMonos.begin(), newMonos.end());
186 return Polynomial::mkPolynomial(newMonos);
187 }
188 }
189
190 Polynomial Polynomial::operator*(const Polynomial& poly) const {
191 Polynomial res = Polynomial::mkZero();
192 for(iterator i = this->begin(), end = this->end(); i != end; ++i) {
193 Monomial curr = *i;
194 Polynomial prod = poly * curr;
195 Polynomial sum = res + prod;
196 res = sum;
197 }
198 return res;
199 }
200
201
202 Node Comparison::toNode(Kind k, const Polynomial& l, const Constant& r) {
203 Assert(!l.isConstant());
204 Assert(isRelationOperator(k));
205 switch(k) {
206 case kind::GEQ:
207 case kind::EQUAL:
208 case kind::LEQ:
209 return NodeManager::currentNM()->mkNode(k, l.getNode(),r.getNode());
210 case kind::LT:
211 return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::GEQ,l,r));
212 case kind::GT:
213 return NodeManager::currentNM()->mkNode(kind::NOT, toNode(kind::LEQ,l,r));
214 default:
215 Unreachable();
216 }
217 }
218
219 Comparison Comparison::parseNormalForm(TNode n) {
220 if(n.getKind() == kind::CONST_BOOLEAN) {
221 return Comparison(n.getConst<bool>());
222 } else {
223 bool negated = n.getKind() == kind::NOT;
224 Node relation = negated ? n[0] : n;
225 Assert( !negated ||
226 relation.getKind() == kind::LEQ ||
227 relation.getKind() == kind::GEQ);
228
229 Polynomial left = Polynomial::parsePolynomial(relation[0]);
230 Constant right(relation[1]);
231
232 Kind newOperator = relation.getKind();
233 if(negated) {
234 if(newOperator == kind::LEQ) {
235 newOperator = kind::GT;
236 } else {
237 newOperator = kind::LT;
238 }
239 }
240 return Comparison(n, newOperator, left, right);
241 }
242 }
243
244 Comparison Comparison::mkComparison(Kind k, const Polynomial& left, const Constant& right) {
245 Assert(isRelationOperator(k));
246 if(left.isConstant()) {
247 const Rational& lConst = left.getNode().getConst<Rational>();
248 const Rational& rConst = right.getNode().getConst<Rational>();
249 bool res = evaluateConstantPredicate(k, lConst, rConst);
250 return Comparison(res);
251 } else {
252 return Comparison(toNode(k, left, right), k, left, right);
253 }
254 }
255
256 Comparison Comparison::addConstant(const Constant& constant) const {
257 Assert(!isBoolean());
258 Monomial mono(constant);
259 Polynomial constAsPoly( mono );
260 Polynomial newLeft = getLeft() + constAsPoly;
261 Constant newRight = getRight() + constant;
262 return mkComparison(oper, newLeft, newRight);
263 }
264
265 Comparison Comparison::multiplyConstant(const Constant& constant) const {
266 Assert(!isBoolean());
267 Kind newOper = (constant.getValue() < 0) ? negateRelationKind(oper) : oper;
268
269 return mkComparison(newOper, left*Monomial(constant), right*constant);
270 }