Refactor transcendental solver (#5539)
[cvc5.git] / src / theory / arith / nl / transcendental / sine_solver.cpp
1 /********************* */
/*! \file sine_solver.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Tim King, Gereon Kremer
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
** \brief Implementation of the solver for handling the sine function.
13 **/
14
15 #include "theory/arith/nl/transcendental/sine_solver.h"
16
17 #include <cmath>
18 #include <set>
19
20 #include "expr/node_algorithm.h"
21 #include "expr/node_builder.h"
22 #include "options/arith_options.h"
23 #include "theory/arith/arith_msum.h"
24 #include "theory/arith/arith_utilities.h"
25 #include "theory/rewriter.h"
26
27 namespace CVC4 {
28 namespace theory {
29 namespace arith {
30 namespace nl {
31 namespace transcendental {
32
33 SineSolver::SineSolver(TranscendentalState* tstate) : d_data(tstate) {}
34
35 SineSolver::~SineSolver() {}
36
37 void SineSolver::checkInitialRefine()
38 {
39 NodeManager* nm = NodeManager::currentNM();
40 Trace("nl-ext")
41 << "Get initial refinement lemmas for transcendental functions..."
42 << std::endl;
43 for (std::pair<const Kind, std::vector<Node> >& tfl : d_data->d_funcMap)
44 {
45 if (tfl.first != Kind::SINE)
46 {
47 continue;
48 }
49 for (const Node& t : tfl.second)
50 {
51 // initial refinements
52 if (d_tf_initial_refine.find(t) == d_tf_initial_refine.end())
53 {
54 d_tf_initial_refine[t] = true;
55 Node lem;
56 Node symn = nm->mkNode(Kind::SINE,
57 nm->mkNode(Kind::MULT, d_data->d_neg_one, t[0]));
58 symn = Rewriter::rewrite(symn);
59 // Can assume it is its own master since phase is split over 0,
60 // hence -pi <= t[0] <= pi implies -pi <= -t[0] <= pi.
61 d_data->d_trMaster[symn] = symn;
62 d_data->d_trSlaves[symn].insert(symn);
63 Assert(d_data->d_trSlaves.find(t) != d_data->d_trSlaves.end());
64 std::vector<Node> children;
65
66 lem =
67 nm->mkNode(Kind::AND,
68 // bounds
69 nm->mkNode(Kind::AND,
70 nm->mkNode(Kind::LEQ, t, d_data->d_one),
71 nm->mkNode(Kind::GEQ, t, d_data->d_neg_one)),
72 // symmetry
73 nm->mkNode(Kind::PLUS, t, symn).eqNode(d_data->d_zero),
74 // sign
75 nm->mkNode(Kind::EQUAL,
76 nm->mkNode(Kind::LT, t[0], d_data->d_zero),
77 nm->mkNode(Kind::LT, t, d_data->d_zero)),
78 // zero val
79 nm->mkNode(Kind::EQUAL,
80 nm->mkNode(Kind::GT, t[0], d_data->d_zero),
81 nm->mkNode(Kind::GT, t, d_data->d_zero)));
82 lem = nm->mkNode(
83 Kind::AND,
84 lem,
85 // zero tangent
86 nm->mkNode(Kind::AND,
87 nm->mkNode(Kind::IMPLIES,
88 nm->mkNode(Kind::GT, t[0], d_data->d_zero),
89 nm->mkNode(Kind::LT, t, t[0])),
90 nm->mkNode(Kind::IMPLIES,
91 nm->mkNode(Kind::LT, t[0], d_data->d_zero),
92 nm->mkNode(Kind::GT, t, t[0]))),
93 // pi tangent
94 nm->mkNode(
95 Kind::AND,
96 nm->mkNode(
97 Kind::IMPLIES,
98 nm->mkNode(Kind::LT, t[0], d_data->d_pi),
99 nm->mkNode(Kind::LT,
100 t,
101 nm->mkNode(Kind::MINUS, d_data->d_pi, t[0]))),
102 nm->mkNode(
103 Kind::IMPLIES,
104 nm->mkNode(Kind::GT, t[0], d_data->d_pi_neg),
105 nm->mkNode(
106 Kind::GT,
107 t,
108 nm->mkNode(Kind::MINUS, d_data->d_pi_neg, t[0])))));
109 if (!lem.isNull())
110 {
111 d_data->d_im.addPendingArithLemma(lem, InferenceId::NL_T_INIT_REFINE);
112 }
113 }
114 }
115 }
116 }
117
118 void SineSolver::checkMonotonic()
119 {
120 Trace("nl-ext") << "Get monotonicity lemmas for transcendental functions..."
121 << std::endl;
122
123 auto it = d_data->d_funcMap.find(Kind::SINE);
124 if (it == d_data->d_funcMap.end())
125 {
126 Trace("nl-ext-exp") << "No sine terms" << std::endl;
127 return;
128 }
129
130 // sort arguments of all transcendentals
131 std::vector<Node> tf_args;
132 std::map<Node, Node> tf_arg_to_term;
133
134 for (const Node& tf : it->second)
135 {
136 Node a = tf[0];
137 Node mvaa = d_data->d_model.computeAbstractModelValue(a);
138 if (mvaa.isConst())
139 {
140 Trace("nl-ext-tf-mono-debug") << "...tf term : " << a << std::endl;
141 tf_args.push_back(a);
142 tf_arg_to_term[a] = tf;
143 }
144 }
145
146 if (tf_args.empty())
147 {
148 return;
149 }
150
151 sortByNlModel(
152 tf_args.begin(), tf_args.end(), &d_data->d_model, true, false, true);
153
154 std::vector<Node> mpoints = {d_data->d_pi,
155 d_data->d_pi_2,
156 d_data->d_zero,
157 d_data->d_pi_neg_2,
158 d_data->d_pi_neg};
159 std::vector<Node> mpoints_vals;
160
161 // get model values for points
162 for (const auto& point : mpoints)
163 {
164 mpoints_vals.emplace_back(d_data->d_model.computeAbstractModelValue(point));
165 Assert(mpoints_vals.back().isConst());
166 }
167
168 unsigned mdir_index = 0;
169 int monotonic_dir = -1;
170 Node mono_bounds[2];
171 Node targ, targval, t, tval;
172 for (const auto& sarg : tf_args)
173 {
174 Node sargval = d_data->d_model.computeAbstractModelValue(sarg);
175 Assert(sargval.isConst());
176 Node s = tf_arg_to_term[sarg];
177 Node sval = d_data->d_model.computeAbstractModelValue(s);
178 Assert(sval.isConst());
179
180 // increment to the proper monotonicity region
181 bool increment = true;
182 while (increment && mdir_index < mpoints.size())
183 {
184 increment = false;
185 Node pval = mpoints_vals[mdir_index];
186 Assert(pval.isConst());
187 if (sargval.getConst<Rational>() < pval.getConst<Rational>())
188 {
189 increment = true;
190 Trace("nl-ext-tf-mono")
191 << "...increment at " << sarg << " since model value is less than "
192 << mpoints[mdir_index] << std::endl;
193 }
194 if (increment)
195 {
196 tval = Node::null();
197 mono_bounds[1] = mpoints[mdir_index];
198 mdir_index++;
199 monotonic_dir = regionToMonotonicityDir(mdir_index);
200 if (mdir_index < mpoints.size())
201 {
202 mono_bounds[0] = mpoints[mdir_index];
203 }
204 else
205 {
206 mono_bounds[0] = Node::null();
207 }
208 }
209 }
210 // store the concavity region
211 d_data->d_tf_region[s] = mdir_index;
212 Trace("nl-ext-concavity")
213 << "Transcendental function " << s << " is in region #" << mdir_index;
214 Trace("nl-ext-concavity") << ", arg model value = " << sargval << std::endl;
215
216 if (!tval.isNull())
217 {
218 NodeManager* nm = NodeManager::currentNM();
219 Node mono_lem;
220 if (monotonic_dir == 1
221 && sval.getConst<Rational>() > tval.getConst<Rational>())
222 {
223 mono_lem = nm->mkNode(Kind::IMPLIES,
224 nm->mkNode(Kind::GEQ, targ, sarg),
225 nm->mkNode(Kind::GEQ, t, s));
226 }
227 else if (monotonic_dir == -1
228 && sval.getConst<Rational>() < tval.getConst<Rational>())
229 {
230 mono_lem = nm->mkNode(Kind::IMPLIES,
231 nm->mkNode(Kind::LEQ, targ, sarg),
232 nm->mkNode(Kind::LEQ, t, s));
233 }
234 if (!mono_lem.isNull())
235 {
236 if (!mono_bounds[0].isNull())
237 {
238 Assert(!mono_bounds[1].isNull());
239 mono_lem = nm->mkNode(
240 Kind::IMPLIES,
241 nm->mkNode(Kind::AND,
242 mkBounded(mono_bounds[0], targ, mono_bounds[1]),
243 mkBounded(mono_bounds[0], sarg, mono_bounds[1])),
244 mono_lem);
245 }
246 Trace("nl-ext-tf-mono")
247 << "Monotonicity lemma : " << mono_lem << std::endl;
248
249 d_data->d_im.addPendingArithLemma(mono_lem,
250 InferenceId::NL_T_MONOTONICITY);
251 }
252 }
253 // store the previous values
254 targ = sarg;
255 targval = sargval;
256 t = s;
257 tval = sval;
258 }
259 }
260
261 void SineSolver::doTangentLemma(TNode e, TNode c, TNode poly_approx, int region)
262 {
263 NodeManager* nm = NodeManager::currentNM();
264
265 // compute tangent plane
266 // Figure 3: T( x )
267 // We use zero slope tangent planes, since the concavity of the Taylor
268 // approximation cannot be easily established.
269 Convexity convexity = regionToConvexity(region);
270 int mdir = regionToMonotonicityDir(region);
271 bool usec = (mdir == 1) == (convexity == Convexity::CONCAVE);
272 Node lem = nm->mkNode(
273 Kind::IMPLIES,
274 nm->mkNode(
275 Kind::AND,
276 nm->mkNode(
277 Kind::GEQ, e[0], usec ? Node(c) : regionToLowerBound(region)),
278 nm->mkNode(
279 Kind::LEQ, e[0], usec ? Node(c) : regionToUpperBound(region))),
280 nm->mkNode(convexity == Convexity::CONVEX ? Kind::GEQ : Kind::LEQ,
281 e,
282 poly_approx));
283
284 Trace("nl-ext-sine") << "*** Tangent plane lemma (pre-rewrite): " << lem
285 << std::endl;
286 lem = Rewriter::rewrite(lem);
287 Trace("nl-ext-sine") << "*** Tangent plane lemma : " << lem << std::endl;
288 Assert(d_data->d_model.computeAbstractModelValue(lem) == d_data->d_false);
289 // Figure 3 : line 9
290 d_data->d_im.addPendingArithLemma(lem, InferenceId::NL_T_TANGENT, nullptr, true);
291 }
292
293 void SineSolver::doSecantLemmas(TNode e,
294 TNode poly_approx,
295 TNode c,
296 TNode poly_approx_c,
297 unsigned d,
298 unsigned actual_d,
299 int region)
300 {
301 d_data->doSecantLemmas(getSecantBounds(e, c, d, region),
302 poly_approx,
303 c,
304 poly_approx_c,
305 e,
306 regionToConvexity(region),
307 d,
308 actual_d);
309 }
310
311 std::pair<Node, Node> SineSolver::getSecantBounds(TNode e,
312 TNode c,
313 unsigned d,
314 int region)
315 {
316 std::pair<Node, Node> bounds = d_data->getClosestSecantPoints(e, c, d);
317
318 // Check if we already have neighboring secant points
319 if (bounds.first.isNull())
320 {
321 // lower boundary point for this concavity region
322 bounds.first = regionToLowerBound(region);
323 }
324 if (bounds.second.isNull())
325 {
326 // upper boundary point for this concavity region
327 bounds.second = regionToUpperBound(region);
328 }
329 return bounds;
330 }
331
332 } // namespace transcendental
333 } // namespace nl
334 } // namespace arith
335 } // namespace theory
336 } // namespace CVC4