Refactor transcendental solver (#5514)
[cvc5.git] / src / theory / arith / arith_utilities.h
1 /********************* */
2 /*! \file arith_utilities.h
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Tim King, Andrew Reynolds, Mathias Preiner
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Common functions for dealing with nodes.
13 **/
14
15 #include "cvc4_private.h"
16
17 #ifndef CVC4__THEORY__ARITH__ARITH_UTILITIES_H
18 #define CVC4__THEORY__ARITH__ARITH_UTILITIES_H
19
20 #include <math.h>
21 #include <unordered_map>
22 #include <unordered_set>
23 #include <vector>
24
25 #include "context/cdhashset.h"
26 #include "expr/node.h"
27 #include "theory/arith/arithvar.h"
28 #include "theory/arith/delta_rational.h"
29 #include "util/dense_map.h"
30 #include "util/integer.h"
31 #include "util/rational.h"
32
33 namespace CVC4 {
34 namespace theory {
35 namespace arith {
36
37 //Sets of Nodes
38 typedef std::unordered_set<Node, NodeHashFunction> NodeSet;
39 typedef std::unordered_set<TNode, TNodeHashFunction> TNodeSet;
40 typedef context::CDHashSet<Node, NodeHashFunction> CDNodeSet;
41
42 //Maps from Nodes -> ArithVars, and vice versa
43 typedef std::unordered_map<Node, ArithVar, NodeHashFunction> NodeToArithVarMap;
44 typedef DenseMap<Node> ArithVarToNodeMap;
45
46 inline Node mkRationalNode(const Rational& q){
47 return NodeManager::currentNM()->mkConst<Rational>(q);
48 }
49
50 inline Node mkBoolNode(bool b){
51 return NodeManager::currentNM()->mkConst<bool>(b);
52 }
53
54 inline Node mkIntSkolem(const std::string& name){
55 return NodeManager::currentNM()->mkSkolem(name, NodeManager::currentNM()->integerType());
56 }
57
58 inline Node mkRealSkolem(const std::string& name){
59 return NodeManager::currentNM()->mkSkolem(name, NodeManager::currentNM()->realType());
60 }
61
62 inline Node skolemFunction(const std::string& name, TypeNode dom, TypeNode range){
63 NodeManager* currNM = NodeManager::currentNM();
64 TypeNode functionType = currNM->mkFunctionType(dom, range);
65 return currNM->mkSkolem(name, functionType);
66 }
67
68 /** \f$ k \in {LT, LEQ, EQ, GEQ, GT} \f$ */
69 inline bool isRelationOperator(Kind k){
70 using namespace kind;
71
72 switch(k){
73 case LT:
74 case LEQ:
75 case EQUAL:
76 case GEQ:
77 case GT:
78 return true;
79 default:
80 return false;
81 }
82 }
83
84 /**
85 * Given a relational kind, k, return the kind k' s.t.
86 * swapping the lefthand and righthand side is equivalent.
87 *
88 * The following equivalence should hold,
89 * (k l r) <=> (k' r l)
90 */
91 inline Kind reverseRelationKind(Kind k){
92 using namespace kind;
93
94 switch(k){
95 case LT: return GT;
96 case LEQ: return GEQ;
97 case EQUAL: return EQUAL;
98 case GEQ: return LEQ;
99 case GT: return LT;
100
101 default:
102 Unreachable();
103 }
104 }
105
106 inline bool evaluateConstantPredicate(Kind k, const Rational& left, const Rational& right){
107 using namespace kind;
108
109 switch(k){
110 case LT: return left < right;
111 case LEQ: return left <= right;
112 case EQUAL: return left == right;
113 case GEQ: return left >= right;
114 case GT: return left > right;
115 default:
116 Unreachable();
117 return true;
118 }
119 }
120
121 /**
122 * Returns the appropriate coefficient for the infinitesimal given the kind
123 * for an arithmetic atom inorder to represent strict inequalities as inequalities.
124 * x < c becomes x <= c + (-1) * \f$ \delta \f$
125 * x > c becomes x >= x + ( 1) * \f$ \delta \f$
126 * Non-strict inequalities have a coefficient of zero.
127 */
128 inline int deltaCoeff(Kind k){
129 switch(k){
130 case kind::LT:
131 return -1;
132 case kind::GT:
133 return 1;
134 default:
135 return 0;
136 }
137 }
138
139 /**
140 * Given a literal to TheoryArith return a single kind to
141 * to indicate its underlying structure.
142 * The function returns the following in each case:
143 * - (K left right) -> K where is a wildcard for EQUAL, LT, GT, LEQ, or GEQ:
144 * - (NOT (EQUAL left right)) -> DISTINCT
145 * - (NOT (LEQ left right)) -> GT
146 * - (NOT (GEQ left right)) -> LT
147 * - (NOT (LT left right)) -> GEQ
148 * - (NOT (GT left right)) -> LEQ
149 * If none of these match, it returns UNDEFINED_KIND.
150 */
151 inline Kind oldSimplifiedKind(TNode literal){
152 switch(literal.getKind()){
153 case kind::LT:
154 case kind::GT:
155 case kind::LEQ:
156 case kind::GEQ:
157 case kind::EQUAL:
158 return literal.getKind();
159 case kind::NOT:
160 {
161 TNode atom = literal[0];
162 switch(atom.getKind()){
163 case kind::LEQ: //(not (LEQ x c)) <=> (GT x c)
164 return kind::GT;
165 case kind::GEQ: //(not (GEQ x c)) <=> (LT x c)
166 return kind::LT;
167 case kind::LT: //(not (LT x c)) <=> (GEQ x c)
168 return kind::GEQ;
169 case kind::GT: //(not (GT x c) <=> (LEQ x c)
170 return kind::LEQ;
171 case kind::EQUAL:
172 return kind::DISTINCT;
173 default:
174 Unreachable();
175 return kind::UNDEFINED_KIND;
176 }
177 }
178 default:
179 Unreachable();
180 return kind::UNDEFINED_KIND;
181 }
182 }
183
184 inline Kind negateKind(Kind k){
185 switch(k){
186 case kind::LT: return kind::GEQ;
187 case kind::GT: return kind::LEQ;
188 case kind::LEQ: return kind::GT;
189 case kind::GEQ: return kind::LT;
190 case kind::EQUAL: return kind::DISTINCT;
191 case kind::DISTINCT: return kind::EQUAL;
192 default:
193 return kind::UNDEFINED_KIND;
194 }
195 }
196
197 inline Node negateConjunctionAsClause(TNode conjunction){
198 Assert(conjunction.getKind() == kind::AND);
199 NodeBuilder<> orBuilder(kind::OR);
200
201 for(TNode::iterator i = conjunction.begin(), end=conjunction.end(); i != end; ++i){
202 TNode child = *i;
203 Node negatedChild = NodeBuilder<1>(kind::NOT)<<(child);
204 orBuilder << negatedChild;
205 }
206 return orBuilder;
207 }
208
209 inline Node maybeUnaryConvert(NodeBuilder<>& builder){
210 Assert(builder.getKind() == kind::OR || builder.getKind() == kind::AND
211 || builder.getKind() == kind::PLUS || builder.getKind() == kind::MULT);
212 Assert(builder.getNumChildren() >= 1);
213 if(builder.getNumChildren() == 1){
214 return builder[0];
215 }else{
216 return builder;
217 }
218 }
219
220 inline void flattenAnd(Node n, std::vector<TNode>& out){
221 Assert(n.getKind() == kind::AND);
222 for(Node::iterator i=n.begin(), i_end=n.end(); i != i_end; ++i){
223 Node curr = *i;
224 if(curr.getKind() == kind::AND){
225 flattenAnd(curr, out);
226 }else{
227 out.push_back(curr);
228 }
229 }
230 }
231
232 inline Node flattenAnd(Node n){
233 std::vector<TNode> out;
234 flattenAnd(n, out);
235 return NodeManager::currentNM()->mkNode(kind::AND, out);
236 }
237
238 // Returns an node that is the identity of a select few kinds.
239 inline Node getIdentity(Kind k){
240 switch(k){
241 case kind::AND:
242 return mkBoolNode(true);
243 case kind::PLUS:
244 return mkRationalNode(0);
245 case kind::MULT:
246 case kind::NONLINEAR_MULT:
247 return mkRationalNode(1);
248 default: Unreachable(); return {}; // silence warning
249 }
250 }
251
252 inline Node safeConstructNary(NodeBuilder<>& nb) {
253 switch (nb.getNumChildren()) {
254 case 0:
255 return getIdentity(nb.getKind());
256 case 1:
257 return nb[0];
258 default:
259 return (Node)nb;
260 }
261 }
262
263 inline Node safeConstructNary(Kind k, const std::vector<Node>& children) {
264 switch (children.size()) {
265 case 0:
266 return getIdentity(k);
267 case 1:
268 return children[0];
269 default:
270 return NodeManager::currentNM()->mkNode(k, children);
271 }
272 }
273
274 // Returns the multiplication of a and b.
275 inline Node mkMult(Node a, Node b) {
276 return NodeManager::currentNM()->mkNode(kind::MULT, a, b);
277 }
278
279 // Return a constraint that is equivalent to term being is in the range
280 // [start, end). This includes start and excludes end.
281 inline Node mkInRange(Node term, Node start, Node end) {
282 NodeManager* nm = NodeManager::currentNM();
283 Node above_start = nm->mkNode(kind::LEQ, start, term);
284 Node below_end = nm->mkNode(kind::LT, term, end);
285 return nm->mkNode(kind::AND, above_start, below_end);
286 }
287
288 // Creates an expression that constrains q to be equal to one of two expressions
289 // when n is 0 or not. Useful for division by 0 logic.
290 // (ite (= n 0) (= q if_zero) (= q not_zero))
291 inline Node mkOnZeroIte(Node n, Node q, Node if_zero, Node not_zero) {
292 Node zero = mkRationalNode(0);
293 return n.eqNode(zero).iteNode(q.eqNode(if_zero), q.eqNode(not_zero));
294 }
295
296 inline Node mkPi()
297 {
298 return NodeManager::currentNM()->mkNullaryOperator(
299 NodeManager::currentNM()->realType(), kind::PI);
300 }
301 /** Join kinds: where k1 and k2 are arithmetic relations, returns an
302 * arithmetic relation ret such that
303 * if (a <k1> b) and (a <k2> b), then (a <ret> b).
304 */
305 Kind joinKinds(Kind k1, Kind k2);
306
307 /** Transitive kinds: where k1 and k2 are arithmetic relations, returns an
308 * arithmetic relation ret such that
309 * if (a <k1> b) and (b <k2> c) then (a <ret> c).
310 */
311 Kind transKinds(Kind k1, Kind k2);
312
313 /** Is k a transcendental function kind? */
314 bool isTranscendentalKind(Kind k);
315 /**
316 * Get a lower/upper approximation of the constant c within the given
317 * level of precision. In other words, this returns a constant c' such that
318 * c' <= c <= c' + 1/(10^prec) if isLower is true, or
319 * c' - 1/(10^prec) <= c <= c' if isLower is false.
320 * where c' is a rational of the form n/d for some n and d <= 10^prec.
321 */
322 Node getApproximateConstant(Node c, bool isLower, unsigned prec);
323
324 /** print rational approximation of cr with precision prec on trace c */
325 void printRationalApprox(const char* c, Node cr, unsigned prec = 5);
326
327 /** Arithmetic substitute
328 *
329 * This computes the substitution n { vars -> subs }, but with the caveat
330 * that subterms of n that belong to a theory other than arithmetic are
331 * not traversed. In other words, terms that belong to other theories are
332 * treated as atomic variables. For example:
333 * (5*f(x) + 7*x ){ x -> 3 } returns 5*f(x) + 7*3.
334 */
335 Node arithSubstitute(Node n, std::vector<Node>& vars, std::vector<Node>& subs);
336
337 /** Make the node u >= a ^ a >= l */
338 Node mkBounded(Node l, Node a, Node u);
339
340 Rational leastIntGreaterThan(const Rational&);
341
342 Rational greatestIntLessThan(const Rational&);
343
344 /** Negates a node in arithmetic proof normal form. */
345 Node negateProofLiteral(TNode n);
346
347 }/* CVC4::theory::arith namespace */
348 }/* CVC4::theory namespace */
349 }/* CVC4 namespace */
350
351 #endif /* CVC4__THEORY__ARITH__ARITH_UTILITIES_H */