Update copyright headers.
[cvc5.git] / src / theory / quantifiers / sygus / transition_inference.cpp
1 /********************* */
2 /*! \file transition_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of utility for inferring whether a synthesis conjecture
13 ** encodes a transition system.
14 **
15 **/
16 #include "theory/quantifiers/sygus/transition_inference.h"
17
18 #include "expr/node_algorithm.h"
19 #include "theory/arith/arith_msum.h"
20 #include "theory/quantifiers/term_util.h"
21
22 using namespace CVC4::kind;
23
24 namespace CVC4 {
25 namespace theory {
26 namespace quantifiers {
27
28 bool DetTrace::DetTraceTrie::add(Node loc, const std::vector<Node>& val)
29 {
30 DetTraceTrie* curr = this;
31 for (const Node& v : val)
32 {
33 curr = &(curr->d_children[v]);
34 }
35 if (curr->d_children.empty())
36 {
37 curr->d_children[loc].clear();
38 return true;
39 }
40 return false;
41 }
42
43 Node DetTrace::DetTraceTrie::constructFormula(const std::vector<Node>& vars,
44 unsigned index)
45 {
46 NodeManager* nm = NodeManager::currentNM();
47 if (index == vars.size())
48 {
49 return nm->mkConst(true);
50 }
51 std::vector<Node> disj;
52 for (std::pair<const Node, DetTraceTrie>& p : d_children)
53 {
54 Node eq = vars[index].eqNode(p.first);
55 if (index < vars.size() - 1)
56 {
57 Node conc = p.second.constructFormula(vars, index + 1);
58 disj.push_back(nm->mkNode(AND, eq, conc));
59 }
60 else
61 {
62 disj.push_back(eq);
63 }
64 }
65 Assert(!disj.empty());
66 return disj.size() == 1 ? disj[0] : nm->mkNode(OR, disj);
67 }
68
69 bool DetTrace::increment(Node loc, std::vector<Node>& vals)
70 {
71 if (d_trie.add(loc, vals))
72 {
73 for (unsigned i = 0, vsize = vals.size(); i < vsize; i++)
74 {
75 d_curr[i] = vals[i];
76 }
77 return true;
78 }
79 return false;
80 }
81
82 Node DetTrace::constructFormula(const std::vector<Node>& vars)
83 {
84 return d_trie.constructFormula(vars);
85 }
86
87 void DetTrace::print(const char* c) const
88 {
89 for (const Node& n : d_curr)
90 {
91 Trace(c) << n << " ";
92 }
93 }
94
95 Node TransitionInference::getFunction() const { return d_func; }
96
97 void TransitionInference::getVariables(std::vector<Node>& vars) const
98 {
99 vars.insert(vars.end(), d_vars.begin(), d_vars.end());
100 }
101
102 Node TransitionInference::getPreCondition() const { return d_pre.d_this; }
103 Node TransitionInference::getPostCondition() const { return d_post.d_this; }
104 Node TransitionInference::getTransitionRelation() const
105 {
106 return d_trans.d_this;
107 }
108
/**
 * Collects a substitution for variables of vars that is implied by the
 * literals in disjuncts.
 *
 * Disjuncts are processed in order. Each one is first rewritten under the
 * substitution collected so far (const_var -> const_subs). If the result is
 * an equality, possibly under a negation, whose polarity equals reqPol, we
 * try to solve it for a variable v of vars:
 * - directly, if one side is a variable of vars that does not occur in the
 *   other side (slit[0] is tried before slit[1]);
 * - otherwise via ArithMSum, by isolating a variable of vars in the monomial
 *   sum of the literal. If several variables qualify, the one visited last in
 *   the iteration order of msum wins (the loop does not break).
 * When a solution v -> s is found, v is first eliminated from the terms
 * already in const_subs, and then the pair is appended to const_var /
 * const_subs.
 *
 * @param vars The variables that may be solved for.
 * @param disjuncts The literals equalities are extracted from.
 * @param const_var In/out: the solved variables.
 * @param const_subs In/out: the terms they are equal to, index-aligned with
 * const_var.
 * @param reqPol The polarity an equality must have to be considered.
 */
void TransitionInference::getConstantSubstitution(
    const std::vector<Node>& vars,
    const std::vector<Node>& disjuncts,
    std::vector<Node>& const_var,
    std::vector<Node>& const_subs,
    bool reqPol)
{
  for (const Node& d : disjuncts)
  {
    // apply the substitution found so far to the current literal
    Node sn;
    if (!const_var.empty())
    {
      sn = d.substitute(const_var.begin(),
                        const_var.end(),
                        const_subs.begin(),
                        const_subs.end());
      sn = Rewriter::rewrite(sn);
    }
    else
    {
      sn = d;
    }
    // split the literal into its atom (slit) and polarity (slit_pol)
    bool slit_pol = sn.getKind() != NOT;
    Node slit = sn.getKind() == NOT ? sn[0] : sn;
    if (slit.getKind() == EQUAL && slit_pol == reqPol)
    {
      // check if it is a variable equality
      // NOTE(review): v is a TNode. It stays valid because the node it refers
      // to is also held by slit (first loop) or by vars (msum loop).
      TNode v;
      Node s;
      for (unsigned r = 0; r < 2; r++)
      {
        if (std::find(vars.begin(), vars.end(), slit[r]) != vars.end())
        {
          // the other side must not mention the variable itself
          if (!expr::hasSubterm(slit[1 - r], slit[r]))
          {
            v = slit[r];
            s = slit[1 - r];
            break;
          }
        }
      }
      if (v.isNull())
      {
        // solve for var
        std::map<Node, Node> msum;
        if (ArithMSum::getMonomialSumLit(slit, msum))
        {
          for (std::pair<const Node, Node>& m : msum)
          {
            if (std::find(vars.begin(), vars.end(), m.first) != vars.end())
            {
              Node veq_c;
              Node val;
              int ires = ArithMSum::isolate(m.first, msum, veq_c, val, EQUAL);
              // accept only if isolation succeeded with a null veq_c
              // (presumably: no coefficient left on m.first; see
              // ArithMSum::isolate) and val does not mention m.first
              if (ires != 0 && veq_c.isNull()
                  && !expr::hasSubterm(val, m.first))
              {
                v = m.first;
                s = val;
              }
            }
          }
        }
      }
      if (!v.isNull())
      {
        // eliminate v from the substitution terms found earlier
        TNode ts = s;
        for (unsigned k = 0, csize = const_subs.size(); k < csize; k++)
        {
          const_subs[k] = Rewriter::rewrite(const_subs[k].substitute(v, ts));
        }
        Trace("cegqi-inv-debug2")
            << "...substitution : " << v << " -> " << s << std::endl;
        const_var.push_back(v);
        const_subs.push_back(s);
      }
    }
  }
}
188
189 void TransitionInference::process(Node n, Node f)
190 {
191 // set the function
192 d_func = f;
193 process(n);
194 }
195
/**
 * Infers the pre-condition, post-condition and transition relation of the
 * function-to-synthesize from the formula n.
 *
 * n is split into clauses: its children if n is an AND, otherwise n itself.
 * Each clause is analyzed by processDisjunct, which collects the top-level
 * applications of the function by polarity (terms) and the other top-level
 * literals (disjuncts). The clause is then classified:
 *   - only a negative application  -> post-condition      (comp_num = -1)
 *   - negative and positive ones   -> transition relation (comp_num = 0)
 *   - only a positive application  -> pre-condition       (comp_num = 1)
 * Its remaining literals are rewritten over the canonical variables d_vars.
 * For a transition clause, the next-state application is additionally
 * normalized to d_prime_vars. The disjunction of the literals is added to
 * the conjuncts of the component. If the entailed constant substitution
 * covers as many variables as d_vars, it is stored in the component's
 * d_const_eq, keyed by that disjunction.
 *
 * Finally each component's d_this is set to the conjunction of its conjuncts
 * (true if there are none). For the pre-condition and the transition
 * relation, that conjunction is negated.
 *
 * d_complete is set to false if some clause could not be processed, or if
 * its result still contains a bound variable; such clauses are dropped.
 */
void TransitionInference::process(Node n)
{
  NodeManager* nm = NodeManager::currentNM();
  d_complete = true;
  // processDisjunct resets this when the function is first encountered
  d_trivial = true;
  // the clauses to analyze
  std::vector<Node> n_check;
  if (n.getKind() == AND)
  {
    for (const Node& nc : n)
    {
      n_check.push_back(nc);
    }
  }
  else
  {
    n_check.push_back(n);
  }
  for (const Node& nn : n_check)
  {
    std::map<bool, std::map<Node, bool> > visited;
    // top-level applications of the function, indexed by polarity
    std::map<bool, Node> terms;
    // the other top-level literals of the clause
    std::vector<Node> disjuncts;
    Trace("cegqi-inv") << "TransitionInference : Process disjunct : " << nn
                       << std::endl;
    if (!processDisjunct(nn, terms, disjuncts, visited, true))
    {
      d_complete = false;
      continue;
    }
    if (terms.empty())
    {
      // the clause does not mention the function; skip it
      continue;
    }
    Node curr;
    // The component that this disjunct contributes to, where
    // 1 : pre-condition, -1 : post-condition, 0 : transition relation
    int comp_num;
    std::map<bool, Node>::iterator itt = terms.find(false);
    if (itt != terms.end())
    {
      curr = itt->second;
      if (terms.find(true) != terms.end())
      {
        comp_num = 0;
      }
      else
      {
        comp_num = -1;
      }
    }
    else
    {
      curr = terms[true];
      comp_num = 1;
    }
    Trace("cegqi-inv-debug2") << "  normalize based on " << curr << std::endl;
    // rename the arguments of curr to d_vars (non-variable arguments become
    // extra disjuncts)
    std::vector<Node> vars;
    std::vector<Node> svars;
    getNormalizedSubstitution(curr, d_vars, vars, svars, disjuncts);
    for (unsigned j = 0, dsize = disjuncts.size(); j < dsize; j++)
    {
      Trace("cegqi-inv-debug2") << "  apply " << disjuncts[j] << std::endl;
      disjuncts[j] = Rewriter::rewrite(disjuncts[j].substitute(
          vars.begin(), vars.end(), svars.begin(), svars.end()));
      Trace("cegqi-inv-debug2") << "  ..." << disjuncts[j] << std::endl;
    }
    std::vector<Node> const_var;
    std::vector<Node> const_subs;
    if (comp_num == 0)
    {
      // transition
      Assert(terms.find(true) != terms.end());
      // the next-state application, with current-state variables renamed
      Node next = terms[true];
      next = Rewriter::rewrite(next.substitute(
          vars.begin(), vars.end(), svars.begin(), svars.end()));
      Trace("cegqi-inv-debug")
          << "transition next predicate : " << next << std::endl;
      // make the primed variables if we have not already
      if (d_prime_vars.empty())
      {
        for (unsigned j = 0, nchild = next.getNumChildren(); j < nchild; j++)
        {
          Node v = nm->mkSkolem(
              "ir", next[j].getType(), "template inference rev argument");
          d_prime_vars.push_back(v);
        }
      }
      // normalize the other direction
      Trace("cegqi-inv-debug2") << "  normalize based on " << next << std::endl;
      std::vector<Node> rvars;
      std::vector<Node> rsvars;
      getNormalizedSubstitution(next, d_prime_vars, rvars, rsvars, disjuncts);
      Assert(rvars.size() == rsvars.size());
      for (unsigned j = 0, dsize = disjuncts.size(); j < dsize; j++)
      {
        Trace("cegqi-inv-debug2") << "  apply " << disjuncts[j] << std::endl;
        disjuncts[j] = Rewriter::rewrite(disjuncts[j].substitute(
            rvars.begin(), rvars.end(), rsvars.begin(), rsvars.end()));
        Trace("cegqi-inv-debug2") << "  ..." << disjuncts[j] << std::endl;
      }
      // solve for the next-state variables
      getConstantSubstitution(
          d_prime_vars, disjuncts, const_var, const_subs, false);
    }
    else
    {
      getConstantSubstitution(d_vars, disjuncts, const_var, const_subs, false);
    }
    // the clause without the function applications
    Node res;
    if (disjuncts.empty())
    {
      res = nm->mkConst(false);
    }
    else if (disjuncts.size() == 1)
    {
      res = disjuncts[0];
    }
    else
    {
      res = nm->mkNode(OR, disjuncts);
    }
    if (expr::hasBoundVar(res))
    {
      // some bound variable could not be renamed to a canonical variable
      Trace("cegqi-inv-debug2") << "...failed, free variable." << std::endl;
      d_complete = false;
      continue;
    }
    Trace("cegqi-inv") << "*** inferred "
                       << (comp_num == 1 ? "pre"
                                         : (comp_num == -1 ? "post" : "trans"))
                       << "-condition : " << res << std::endl;
    Component& c =
        (comp_num == 1 ? d_pre : (comp_num == -1 ? d_post : d_trans));
    c.d_conjuncts.push_back(res);
    if (!const_var.empty())
    {
      // only a substitution covering as many variables as d_vars is recorded
      bool has_const_eq = const_var.size() == d_vars.size();
      Trace("cegqi-inv") << "    with constant substitution, complete = "
                         << has_const_eq << " : " << std::endl;
      for (unsigned i = 0, csize = const_var.size(); i < csize; i++)
      {
        Trace("cegqi-inv") << "      " << const_var[i] << " -> "
                           << const_subs[i] << std::endl;
        if (has_const_eq)
        {
          c.d_const_eq[res][const_var[i]] = const_subs[i];
        }
      }
      Trace("cegqi-inv") << "...size = " << const_var.size()
                         << ", #vars = " << d_vars.size() << std::endl;
    }
  }

  // finalize the components
  for (int i = -1; i <= 1; i++)
  {
    Component& c = (i == 1 ? d_pre : (i == -1 ? d_post : d_trans));
    Node ret;
    if (c.d_conjuncts.empty())
    {
      ret = nm->mkConst(true);
    }
    else if (c.d_conjuncts.size() == 1)
    {
      ret = c.d_conjuncts[0];
    }
    else
    {
      ret = nm->mkNode(AND, c.d_conjuncts);
    }
    if (i == 0 || i == 1)
    {
      // pre-condition and transition are negated
      ret = TermUtil::simpleNegate(ret);
    }
    c.d_this = ret;
  }
}
373 void TransitionInference::getNormalizedSubstitution(
374 Node curr,
375 const std::vector<Node>& pvars,
376 std::vector<Node>& vars,
377 std::vector<Node>& subs,
378 std::vector<Node>& disjuncts)
379 {
380 for (unsigned j = 0, nchild = curr.getNumChildren(); j < nchild; j++)
381 {
382 if (curr[j].getKind() == BOUND_VARIABLE)
383 {
384 // if the argument is a bound variable, add to the renaming
385 vars.push_back(curr[j]);
386 subs.push_back(pvars[j]);
387 }
388 else
389 {
390 // otherwise, treat as a constraint on the variable
391 // For example, this transforms e.g. a precondition clause
392 // I( 0, 1 ) to x1 != 0 OR x2 != 1 OR I( x1, x2 ).
393 Node eq = curr[j].eqNode(pvars[j]);
394 disjuncts.push_back(eq.negate());
395 }
396 }
397 }
398
399 bool TransitionInference::processDisjunct(
400 Node n,
401 std::map<bool, Node>& terms,
402 std::vector<Node>& disjuncts,
403 std::map<bool, std::map<Node, bool> >& visited,
404 bool topLevel)
405 {
406 if (visited[topLevel].find(n) != visited[topLevel].end())
407 {
408 return true;
409 }
410 visited[topLevel][n] = true;
411 bool childTopLevel = n.getKind() == OR && topLevel;
412 // if another part mentions UF or a free variable, then fail
413 bool lit_pol = n.getKind() != NOT;
414 Node lit = n.getKind() == NOT ? n[0] : n;
415 // is it an application of the function-to-synthesize? Yes if we haven't
416 // encountered a function or if it matches the existing d_func.
417 if (lit.getKind() == APPLY_UF
418 && (d_func.isNull() || lit.getOperator() == d_func))
419 {
420 Node op = lit.getOperator();
421 // initialize the variables
422 if (d_trivial)
423 {
424 d_trivial = false;
425 d_func = op;
426 Trace("cegqi-inv-debug") << "Use " << op << " with args ";
427 NodeManager* nm = NodeManager::currentNM();
428 for (const Node& l : lit)
429 {
430 Node v = nm->mkSkolem("i", l.getType(), "template inference argument");
431 d_vars.push_back(v);
432 Trace("cegqi-inv-debug") << v << " ";
433 }
434 Trace("cegqi-inv-debug") << std::endl;
435 }
436 Assert(!d_func.isNull());
437 if (topLevel)
438 {
439 if (terms.find(lit_pol) == terms.end())
440 {
441 terms[lit_pol] = lit;
442 return true;
443 }
444 else
445 {
446 Trace("cegqi-inv-debug")
447 << "...failed, repeated inv-app : " << lit << std::endl;
448 return false;
449 }
450 }
451 Trace("cegqi-inv-debug")
452 << "...failed, non-entailed inv-app : " << lit << std::endl;
453 return false;
454 }
455 else if (topLevel && !childTopLevel)
456 {
457 disjuncts.push_back(n);
458 }
459 for (const Node& nc : n)
460 {
461 if (!processDisjunct(nc, terms, disjuncts, visited, childTopLevel))
462 {
463 return false;
464 }
465 }
466 return true;
467 }
468
469 TraceIncStatus TransitionInference::initializeTrace(DetTrace& dt,
470 Node loc,
471 bool fwd)
472 {
473 Component& c = fwd ? d_pre : d_post;
474 Assert(c.has(loc));
475 std::map<Node, std::map<Node, Node> >::iterator it = c.d_const_eq.find(loc);
476 if (it != c.d_const_eq.end())
477 {
478 std::vector<Node> next;
479 for (const Node& v : d_vars)
480 {
481 Assert(it->second.find(v) != it->second.end());
482 next.push_back(it->second[v]);
483 dt.d_curr.push_back(it->second[v]);
484 }
485 Trace("cegqi-inv-debug2") << "dtrace : initial increment" << std::endl;
486 bool ret = dt.increment(loc, next);
487 AlwaysAssert(ret);
488 return TRACE_INC_SUCCESS;
489 }
490 return TRACE_INC_INVALID;
491 }
492
493 TraceIncStatus TransitionInference::incrementTrace(DetTrace& dt,
494 Node loc,
495 bool fwd)
496 {
497 Assert(d_trans.has(loc));
498 // check if it satisfies the pre/post condition
499 Node cc = fwd ? getPostCondition() : getPreCondition();
500 Assert(!cc.isNull());
501 Node ccr = Rewriter::rewrite(cc.substitute(
502 d_vars.begin(), d_vars.end(), dt.d_curr.begin(), dt.d_curr.end()));
503 if (ccr.isConst())
504 {
505 if (ccr.getConst<bool>() == (fwd ? false : true))
506 {
507 Trace("cegqi-inv-debug2") << "dtrace : counterexample" << std::endl;
508 return TRACE_INC_CEX;
509 }
510 }
511
512 // terminates?
513 Node c = getTransitionRelation();
514 Assert(!c.isNull());
515
516 Assert(d_vars.size() == dt.d_curr.size());
517 Node cr = Rewriter::rewrite(c.substitute(
518 d_vars.begin(), d_vars.end(), dt.d_curr.begin(), dt.d_curr.end()));
519 if (cr.isConst())
520 {
521 if (!cr.getConst<bool>())
522 {
523 Trace("cegqi-inv-debug2") << "dtrace : terminated" << std::endl;
524 return TRACE_INC_TERMINATE;
525 }
526 return TRACE_INC_INVALID;
527 }
528 if (!fwd)
529 {
530 // only implemented in forward direction
531 Assert(false);
532 return TRACE_INC_INVALID;
533 }
534 Component& cm = d_trans;
535 std::map<Node, std::map<Node, Node> >::iterator it = cm.d_const_eq.find(loc);
536 if (it == cm.d_const_eq.end())
537 {
538 return TRACE_INC_INVALID;
539 }
540 std::vector<Node> next;
541 for (const Node& pv : d_prime_vars)
542 {
543 Assert(it->second.find(pv) != it->second.end());
544 Node pvs = it->second[pv];
545 Assert(d_vars.size() == dt.d_curr.size());
546 Node pvsr = Rewriter::rewrite(pvs.substitute(
547 d_vars.begin(), d_vars.end(), dt.d_curr.begin(), dt.d_curr.end()));
548 next.push_back(pvsr);
549 }
550 if (dt.increment(loc, next))
551 {
552 Trace("cegqi-inv-debug2") << "dtrace : success increment" << std::endl;
553 return TRACE_INC_SUCCESS;
554 }
555 // looped
556 Trace("cegqi-inv-debug2") << "dtrace : looped" << std::endl;
557 return TRACE_INC_TERMINATE;
558 }
559
560 TraceIncStatus TransitionInference::initializeTrace(DetTrace& dt, bool fwd)
561 {
562 Trace("cegqi-inv-debug2") << "Initialize trace" << std::endl;
563 Component& c = fwd ? d_pre : d_post;
564 if (c.d_conjuncts.size() == 1)
565 {
566 return initializeTrace(dt, c.d_conjuncts[0], fwd);
567 }
568 return TRACE_INC_INVALID;
569 }
570
571 TraceIncStatus TransitionInference::incrementTrace(DetTrace& dt, bool fwd)
572 {
573 if (d_trans.d_conjuncts.size() == 1)
574 {
575 return incrementTrace(dt, d_trans.d_conjuncts[0], fwd);
576 }
577 return TRACE_INC_INVALID;
578 }
579
580 Node TransitionInference::constructFormulaTrace(DetTrace& dt) const
581 {
582 return dt.constructFormula(d_vars);
583 }
584
585 } // namespace quantifiers
586 } // namespace theory
587 } // namespace CVC4