Fix getModelValue for arithmetic (#8316)
[cvc5.git] / src / theory / arith / arith_poly_norm.cpp
1 /******************************************************************************
2 * Top contributors (to current version):
3 * Andrew Reynolds
4 *
5 * This file is part of the cvc5 project.
6 *
7 * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
8 * in the top-level source directory and their institutional affiliations.
9 * All rights reserved. See the file COPYING in the top-level source
10 * directory for licensing information.
11 * ****************************************************************************
12 *
13 * Arithmetic utility for polynomial normalization
14 */
15
16 #include "theory/arith/arith_poly_norm.h"
17
18 using namespace cvc5::kind;
19
20 namespace cvc5 {
21 namespace theory {
22 namespace arith {
23
24 void PolyNorm::addMonomial(TNode x, const Rational& c, bool isNeg)
25 {
26 Assert(c.sgn() != 0);
27 std::unordered_map<Node, Rational>::iterator it = d_polyNorm.find(x);
28 if (it == d_polyNorm.end())
29 {
30 d_polyNorm[x] = isNeg ? -c : c;
31 return;
32 }
33 Rational res(it->second + (isNeg ? -c : c));
34 if (res.sgn() == 0)
35 {
36 // cancels
37 d_polyNorm.erase(it);
38 }
39 else
40 {
41 d_polyNorm[x] = res;
42 }
43 }
44
45 void PolyNorm::multiplyMonomial(TNode x, const Rational& c)
46 {
47 Assert(c.sgn() != 0);
48 if (x.isNull())
49 {
50 // multiply by constant
51 for (std::pair<const Node, Rational>& m : d_polyNorm)
52 {
53 // c1*x*c2 = (c1*c2)*x
54 m.second *= c;
55 }
56 }
57 else
58 {
59 std::unordered_map<Node, Rational> ptmp = d_polyNorm;
60 d_polyNorm.clear();
61 for (const std::pair<const Node, Rational>& m : ptmp)
62 {
63 // c1*x1*c2*x2 = (c1*c2)*(x1*x2)
64 Node newM = multMonoVar(m.first, x);
65 d_polyNorm[newM] = m.second * c;
66 }
67 }
68 }
69
70 void PolyNorm::add(const PolyNorm& p)
71 {
72 for (const std::pair<const Node, Rational>& m : p.d_polyNorm)
73 {
74 addMonomial(m.first, m.second);
75 }
76 }
77
78 void PolyNorm::subtract(const PolyNorm& p)
79 {
80 for (const std::pair<const Node, Rational>& m : p.d_polyNorm)
81 {
82 addMonomial(m.first, m.second, true);
83 }
84 }
85
86 void PolyNorm::multiply(const PolyNorm& p)
87 {
88 if (p.d_polyNorm.size() == 1)
89 {
90 for (const std::pair<const Node, Rational>& m : p.d_polyNorm)
91 {
92 multiplyMonomial(m.first, m.second);
93 }
94 }
95 else
96 {
97 // If multiplying by sum, must distribute; if multiplying by zero, clear.
98 // First, remember the current state and clear.
99 std::unordered_map<Node, Rational> ptmp = d_polyNorm;
100 d_polyNorm.clear();
101 for (const std::pair<const Node, Rational>& m : p.d_polyNorm)
102 {
103 PolyNorm pbase;
104 pbase.d_polyNorm = ptmp;
105 pbase.multiplyMonomial(m.first, m.second);
106 // add this to current
107 add(pbase);
108 }
109 }
110 }
111
112 void PolyNorm::clear() { d_polyNorm.clear(); }
113
114 bool PolyNorm::empty() const { return d_polyNorm.empty(); }
115
116 bool PolyNorm::isEqual(const PolyNorm& p) const
117 {
118 if (d_polyNorm.size() != p.d_polyNorm.size())
119 {
120 return false;
121 }
122 std::unordered_map<Node, Rational>::const_iterator it;
123 for (const std::pair<const Node, Rational>& m : d_polyNorm)
124 {
125 Assert(m.second.sgn() != 0);
126 it = p.d_polyNorm.find(m.first);
127 if (it == p.d_polyNorm.end() || m.second != it->second)
128 {
129 return false;
130 }
131 }
132 return true;
133 }
134
135 Node PolyNorm::multMonoVar(TNode m1, TNode m2)
136 {
137 std::vector<TNode> vars = getMonoVars(m1);
138 std::vector<TNode> vars2 = getMonoVars(m2);
139 vars.insert(vars.end(), vars2.begin(), vars2.end());
140 if (vars.empty())
141 {
142 // constants
143 return Node::null();
144 }
145 else if (vars.size() == 1)
146 {
147 return vars[0];
148 }
149 // use default sorting
150 std::sort(vars.begin(), vars.end());
151 return NodeManager::currentNM()->mkNode(NONLINEAR_MULT, vars);
152 }
153
154 std::vector<TNode> PolyNorm::getMonoVars(TNode m)
155 {
156 std::vector<TNode> vars;
157 // m is null if this is the empty variable (for constant monomials)
158 if (!m.isNull())
159 {
160 Kind k = m.getKind();
161 Assert(k != CONST_RATIONAL);
162 if (k == MULT || k == NONLINEAR_MULT)
163 {
164 vars.insert(vars.end(), m.begin(), m.end());
165 }
166 else
167 {
168 vars.push_back(m);
169 }
170 }
171 return vars;
172 }
173
/**
 * Compute the normalized polynomial of the arithmetic term n.
 *
 * The term is traversed with an explicit stack so that children are processed
 * before their parent:
 * - rational constants become constant monomials, keyed by the null node;
 * - ADD, SUB, NEG, MULT and NONLINEAR_MULT are combined from the polynomials
 *   of their children;
 * - every other term is treated as an opaque variable with coefficient one.
 *
 * @param n A term of Real or Int type.
 * @return The normalized polynomial of n.
 */
PolyNorm PolyNorm::mkPolyNorm(TNode n)
{
  Assert(n.getType().isRealOrInt());
  Rational one(1);
  // the null node is the key of the constant monomial
  Node null;
  // Maps each term seen so far to its polynomial. For a compound term, an
  // empty entry is created when its children are scheduled; the entry is
  // filled in when the term reaches the top of the stack again.
  std::unordered_map<TNode, PolyNorm> visited;
  std::unordered_map<TNode, PolyNorm>::iterator it;
  std::vector<TNode> visit;
  TNode cur;
  visit.push_back(n);
  do
  {
    cur = visit.back();
    it = visited.find(cur);
    Kind k = cur.getKind();
    if (it == visited.end())
    {
      // first time cur is seen
      if (k == CONST_RATIONAL)
      {
        Rational r = cur.getConst<Rational>();
        if (r.sgn() == 0)
        {
          // zero is not an entry
          visited[cur] = PolyNorm();
        }
        else
        {
          visited[cur].addMonomial(null, r);
        }
        // cur is not popped here; the next iteration finds it in visited,
        // pops it, and the CONST_RATIONAL case below leaves it unchanged
      }
      else if (k == ADD || k == SUB || k == NEG || k == MULT
               || k == NONLINEAR_MULT)
      {
        // placeholder entry; children are pushed above cur so they are all
        // processed before cur is reached again
        visited[cur] = PolyNorm();
        for (const Node& cn : cur)
        {
          visit.push_back(cn);
        }
      }
      else
      {
        // it is a leaf
        visited[cur].addMonomial(cur, one);
        visit.pop_back();
      }
      continue;
    }
    visit.pop_back();
    // An empty entry means cur's polynomial has not been computed yet.
    // NOTE(review): empty is also the value of a compound term that normalizes
    // to zero, so such a term is recombined from its (already computed)
    // children each time it is reached again. The result is the same zero
    // polynomial; only the work is repeated.
    if (it->second.empty())
    {
      PolyNorm& ret = visited[cur];
      switch (k)
      {
        case ADD:
        case SUB:
        case NEG:
        case MULT:
        case NONLINEAR_MULT:
          for (size_t i = 0, nchild = cur.getNumChildren(); i < nchild; i++)
          {
            it = visited.find(cur[i]);
            Assert(it != visited.end());
            if ((k == SUB && i == 1) || k == NEG)
            {
              // second operand of SUB, or the sole operand of NEG
              ret.subtract(it->second);
            }
            else if (i > 0 && (k == MULT || k == NONLINEAR_MULT))
            {
              // later factors multiply the product accumulated so far
              ret.multiply(it->second);
            }
            else
            {
              // summands, and the first factor of a product (added to the
              // still-empty ret)
              ret.add(it->second);
            }
          }
          break;
        // a zero constant: its empty entry is already the right value
        case CONST_RATIONAL: break;
        // expected unreachable: leaf entries are never empty
        default: Unhandled() << "Unhandled polynomial operation " << cur; break;
      }
    }
  } while (!visit.empty());
  Assert(visited.find(n) != visited.end());
  return visited[n];
}
258
259 bool PolyNorm::isArithPolyNorm(TNode a, TNode b)
260 {
261 PolyNorm pa = PolyNorm::mkPolyNorm(a);
262 PolyNorm pb = PolyNorm::mkPolyNorm(b);
263 return pa.isEqual(pb);
264 }
265
266 } // namespace arith
267 } // namespace theory
268 } // namespace cvc5