1 /********************* */
2 /*! \file transition_inference.cpp
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
12 ** \brief Implementation of utility for inferring whether a synthesis conjecture
13 ** encodes a transition system.
16 #include "theory/quantifiers/sygus/transition_inference.h"
18 #include "expr/node_algorithm.h"
19 #include "theory/arith/arith_msum.h"
20 #include "theory/quantifiers/term_util.h"
22 using namespace CVC4::kind
;
26 namespace quantifiers
{
28 bool DetTrace::DetTraceTrie::add(Node loc
, const std::vector
<Node
>& val
)
30 DetTraceTrie
* curr
= this;
31 for (const Node
& v
: val
)
33 curr
= &(curr
->d_children
[v
]);
35 if (curr
->d_children
.empty())
37 curr
->d_children
[loc
].clear();
43 Node
DetTrace::DetTraceTrie::constructFormula(const std::vector
<Node
>& vars
,
46 NodeManager
* nm
= NodeManager::currentNM();
47 if (index
== vars
.size())
49 return nm
->mkConst(true);
51 std::vector
<Node
> disj
;
52 for (std::pair
<const Node
, DetTraceTrie
>& p
: d_children
)
54 Node eq
= vars
[index
].eqNode(p
.first
);
55 if (index
< vars
.size() - 1)
57 Node conc
= p
.second
.constructFormula(vars
, index
+ 1);
58 disj
.push_back(nm
->mkNode(AND
, eq
, conc
));
65 Assert(!disj
.empty());
66 return disj
.size() == 1 ? disj
[0] : nm
->mkNode(OR
, disj
);
69 bool DetTrace::increment(Node loc
, std::vector
<Node
>& vals
)
71 if (d_trie
.add(loc
, vals
))
73 for (unsigned i
= 0, vsize
= vals
.size(); i
< vsize
; i
++)
82 Node
DetTrace::constructFormula(const std::vector
<Node
>& vars
)
84 return d_trie
.constructFormula(vars
);
87 void DetTrace::print(const char* c
) const
89 for (const Node
& n
: d_curr
)
95 Node
TransitionInference::getFunction() const { return d_func
; }
97 void TransitionInference::getVariables(std::vector
<Node
>& vars
) const
99 vars
.insert(vars
.end(), d_vars
.begin(), d_vars
.end());
102 Node
TransitionInference::getPreCondition() const { return d_pre
.d_this
; }
103 Node
TransitionInference::getPostCondition() const { return d_post
.d_this
; }
104 Node
TransitionInference::getTransitionRelation() const
106 return d_trans
.d_this
;
109 void TransitionInference::getConstantSubstitution(
110 const std::vector
<Node
>& vars
,
111 const std::vector
<Node
>& disjuncts
,
112 std::vector
<Node
>& const_var
,
113 std::vector
<Node
>& const_subs
,
116 for (const Node
& d
: disjuncts
)
119 if (!const_var
.empty())
121 sn
= d
.substitute(const_var
.begin(),
125 sn
= Rewriter::rewrite(sn
);
131 bool slit_pol
= sn
.getKind() != NOT
;
132 Node slit
= sn
.getKind() == NOT
? sn
[0] : sn
;
133 if (slit
.getKind() == EQUAL
&& slit_pol
== reqPol
)
135 // check if it is a variable equality
138 for (unsigned r
= 0; r
< 2; r
++)
140 if (std::find(vars
.begin(), vars
.end(), slit
[r
]) != vars
.end())
142 if (!expr::hasSubterm(slit
[1 - r
], slit
[r
]))
153 std::map
<Node
, Node
> msum
;
154 if (ArithMSum::getMonomialSumLit(slit
, msum
))
156 for (std::pair
<const Node
, Node
>& m
: msum
)
158 if (std::find(vars
.begin(), vars
.end(), m
.first
) != vars
.end())
162 int ires
= ArithMSum::isolate(m
.first
, msum
, veq_c
, val
, EQUAL
);
163 if (ires
!= 0 && veq_c
.isNull()
164 && !expr::hasSubterm(val
, m
.first
))
176 for (unsigned k
= 0, csize
= const_subs
.size(); k
< csize
; k
++)
178 const_subs
[k
] = Rewriter::rewrite(const_subs
[k
].substitute(v
, ts
));
180 Trace("cegqi-inv-debug2")
181 << "...substitution : " << v
<< " -> " << s
<< std::endl
;
182 const_var
.push_back(v
);
183 const_subs
.push_back(s
);
189 void TransitionInference::process(Node n
, Node f
)
196 void TransitionInference::process(Node n
)
198 NodeManager
* nm
= NodeManager::currentNM();
201 std::vector
<Node
> n_check
;
202 if (n
.getKind() == AND
)
204 for (const Node
& nc
: n
)
206 n_check
.push_back(nc
);
211 n_check
.push_back(n
);
213 for (const Node
& nn
: n_check
)
215 std::map
<bool, std::map
<Node
, bool> > visited
;
216 std::map
<bool, Node
> terms
;
217 std::vector
<Node
> disjuncts
;
218 Trace("cegqi-inv") << "TransitionInference : Process disjunct : " << nn
220 if (!processDisjunct(nn
, terms
, disjuncts
, visited
, true))
230 // The component that this disjunct contributes to, where
231 // 1 : pre-condition, -1 : post-condition, 0 : transition relation
233 std::map
<bool, Node
>::iterator itt
= terms
.find(false);
234 if (itt
!= terms
.end())
237 if (terms
.find(true) != terms
.end())
251 Trace("cegqi-inv-debug2") << " normalize based on " << curr
<< std::endl
;
252 std::vector
<Node
> vars
;
253 std::vector
<Node
> svars
;
254 getNormalizedSubstitution(curr
, d_vars
, vars
, svars
, disjuncts
);
255 for (unsigned j
= 0, dsize
= disjuncts
.size(); j
< dsize
; j
++)
257 Trace("cegqi-inv-debug2") << " apply " << disjuncts
[j
] << std::endl
;
258 disjuncts
[j
] = Rewriter::rewrite(disjuncts
[j
].substitute(
259 vars
.begin(), vars
.end(), svars
.begin(), svars
.end()));
260 Trace("cegqi-inv-debug2") << " ..." << disjuncts
[j
] << std::endl
;
262 std::vector
<Node
> const_var
;
263 std::vector
<Node
> const_subs
;
267 Assert(terms
.find(true) != terms
.end());
268 Node next
= terms
[true];
269 next
= Rewriter::rewrite(next
.substitute(
270 vars
.begin(), vars
.end(), svars
.begin(), svars
.end()));
271 Trace("cegqi-inv-debug")
272 << "transition next predicate : " << next
<< std::endl
;
273 // make the primed variables if we have not already
274 if (d_prime_vars
.empty())
276 for (unsigned j
= 0, nchild
= next
.getNumChildren(); j
< nchild
; j
++)
278 Node v
= nm
->mkSkolem(
279 "ir", next
[j
].getType(), "template inference rev argument");
280 d_prime_vars
.push_back(v
);
283 // normalize the other direction
284 Trace("cegqi-inv-debug2") << " normalize based on " << next
<< std::endl
;
285 std::vector
<Node
> rvars
;
286 std::vector
<Node
> rsvars
;
287 getNormalizedSubstitution(next
, d_prime_vars
, rvars
, rsvars
, disjuncts
);
288 Assert(rvars
.size() == rsvars
.size());
289 for (unsigned j
= 0, dsize
= disjuncts
.size(); j
< dsize
; j
++)
291 Trace("cegqi-inv-debug2") << " apply " << disjuncts
[j
] << std::endl
;
292 disjuncts
[j
] = Rewriter::rewrite(disjuncts
[j
].substitute(
293 rvars
.begin(), rvars
.end(), rsvars
.begin(), rsvars
.end()));
294 Trace("cegqi-inv-debug2") << " ..." << disjuncts
[j
] << std::endl
;
296 getConstantSubstitution(
297 d_prime_vars
, disjuncts
, const_var
, const_subs
, false);
301 getConstantSubstitution(d_vars
, disjuncts
, const_var
, const_subs
, false);
304 if (disjuncts
.empty())
306 res
= nm
->mkConst(false);
308 else if (disjuncts
.size() == 1)
314 res
= nm
->mkNode(OR
, disjuncts
);
316 if (expr::hasBoundVar(res
))
318 Trace("cegqi-inv-debug2") << "...failed, free variable." << std::endl
;
322 Trace("cegqi-inv") << "*** inferred "
323 << (comp_num
== 1 ? "pre"
324 : (comp_num
== -1 ? "post" : "trans"))
325 << "-condition : " << res
<< std::endl
;
327 (comp_num
== 1 ? d_pre
: (comp_num
== -1 ? d_post
: d_trans
));
328 c
.d_conjuncts
.push_back(res
);
329 if (!const_var
.empty())
331 bool has_const_eq
= const_var
.size() == d_vars
.size();
332 Trace("cegqi-inv") << " with constant substitution, complete = "
333 << has_const_eq
<< " : " << std::endl
;
334 for (unsigned i
= 0, csize
= const_var
.size(); i
< csize
; i
++)
336 Trace("cegqi-inv") << " " << const_var
[i
] << " -> "
337 << const_subs
[i
] << std::endl
;
340 c
.d_const_eq
[res
][const_var
[i
]] = const_subs
[i
];
343 Trace("cegqi-inv") << "...size = " << const_var
.size()
344 << ", #vars = " << d_vars
.size() << std::endl
;
348 // finalize the components
349 for (int i
= -1; i
<= 1; i
++)
351 Component
& c
= (i
== 1 ? d_pre
: (i
== -1 ? d_post
: d_trans
));
353 if (c
.d_conjuncts
.empty())
355 ret
= nm
->mkConst(true);
357 else if (c
.d_conjuncts
.size() == 1)
359 ret
= c
.d_conjuncts
[0];
363 ret
= nm
->mkNode(AND
, c
.d_conjuncts
);
365 if (i
== 0 || i
== 1)
367 // pre-condition and transition are negated
368 ret
= TermUtil::simpleNegate(ret
);
373 void TransitionInference::getNormalizedSubstitution(
375 const std::vector
<Node
>& pvars
,
376 std::vector
<Node
>& vars
,
377 std::vector
<Node
>& subs
,
378 std::vector
<Node
>& disjuncts
)
380 for (unsigned j
= 0, nchild
= curr
.getNumChildren(); j
< nchild
; j
++)
382 if (curr
[j
].getKind() == BOUND_VARIABLE
)
384 // if the argument is a bound variable, add to the renaming
385 vars
.push_back(curr
[j
]);
386 subs
.push_back(pvars
[j
]);
390 // otherwise, treat as a constraint on the variable
391 // For example, this transforms e.g. a precondition clause
392 // I( 0, 1 ) to x1 != 0 OR x2 != 1 OR I( x1, x2 ).
393 Node eq
= curr
[j
].eqNode(pvars
[j
]);
394 disjuncts
.push_back(eq
.negate());
399 bool TransitionInference::processDisjunct(
401 std::map
<bool, Node
>& terms
,
402 std::vector
<Node
>& disjuncts
,
403 std::map
<bool, std::map
<Node
, bool> >& visited
,
406 if (visited
[topLevel
].find(n
) != visited
[topLevel
].end())
410 visited
[topLevel
][n
] = true;
411 bool childTopLevel
= n
.getKind() == OR
&& topLevel
;
412 // if another part mentions UF or a free variable, then fail
413 bool lit_pol
= n
.getKind() != NOT
;
414 Node lit
= n
.getKind() == NOT
? n
[0] : n
;
415 // is it an application of the function-to-synthesize? Yes if we haven't
416 // encountered a function or if it matches the existing d_func.
417 if (lit
.getKind() == APPLY_UF
418 && (d_func
.isNull() || lit
.getOperator() == d_func
))
420 Node op
= lit
.getOperator();
421 // initialize the variables
426 Trace("cegqi-inv-debug") << "Use " << op
<< " with args ";
427 NodeManager
* nm
= NodeManager::currentNM();
428 for (const Node
& l
: lit
)
430 Node v
= nm
->mkSkolem("i", l
.getType(), "template inference argument");
432 Trace("cegqi-inv-debug") << v
<< " ";
434 Trace("cegqi-inv-debug") << std::endl
;
436 Assert(!d_func
.isNull());
439 if (terms
.find(lit_pol
) == terms
.end())
441 terms
[lit_pol
] = lit
;
446 Trace("cegqi-inv-debug")
447 << "...failed, repeated inv-app : " << lit
<< std::endl
;
451 Trace("cegqi-inv-debug")
452 << "...failed, non-entailed inv-app : " << lit
<< std::endl
;
455 else if (topLevel
&& !childTopLevel
)
457 disjuncts
.push_back(n
);
459 for (const Node
& nc
: n
)
461 if (!processDisjunct(nc
, terms
, disjuncts
, visited
, childTopLevel
))
469 TraceIncStatus
TransitionInference::initializeTrace(DetTrace
& dt
,
473 Component
& c
= fwd
? d_pre
: d_post
;
475 std::map
<Node
, std::map
<Node
, Node
> >::iterator it
= c
.d_const_eq
.find(loc
);
476 if (it
!= c
.d_const_eq
.end())
478 std::vector
<Node
> next
;
479 for (const Node
& v
: d_vars
)
481 Assert(it
->second
.find(v
) != it
->second
.end());
482 next
.push_back(it
->second
[v
]);
483 dt
.d_curr
.push_back(it
->second
[v
]);
485 Trace("cegqi-inv-debug2") << "dtrace : initial increment" << std::endl
;
486 bool ret
= dt
.increment(loc
, next
);
488 return TRACE_INC_SUCCESS
;
490 return TRACE_INC_INVALID
;
493 TraceIncStatus
TransitionInference::incrementTrace(DetTrace
& dt
,
497 Assert(d_trans
.has(loc
));
498 // check if it satisfies the pre/post condition
499 Node cc
= fwd
? getPostCondition() : getPreCondition();
500 Assert(!cc
.isNull());
501 Node ccr
= Rewriter::rewrite(cc
.substitute(
502 d_vars
.begin(), d_vars
.end(), dt
.d_curr
.begin(), dt
.d_curr
.end()));
505 if (ccr
.getConst
<bool>() == (fwd
? false : true))
507 Trace("cegqi-inv-debug2") << "dtrace : counterexample" << std::endl
;
508 return TRACE_INC_CEX
;
513 Node c
= getTransitionRelation();
516 Assert(d_vars
.size() == dt
.d_curr
.size());
517 Node cr
= Rewriter::rewrite(c
.substitute(
518 d_vars
.begin(), d_vars
.end(), dt
.d_curr
.begin(), dt
.d_curr
.end()));
521 if (!cr
.getConst
<bool>())
523 Trace("cegqi-inv-debug2") << "dtrace : terminated" << std::endl
;
524 return TRACE_INC_TERMINATE
;
526 return TRACE_INC_INVALID
;
530 // only implemented in forward direction
532 return TRACE_INC_INVALID
;
534 Component
& cm
= d_trans
;
535 std::map
<Node
, std::map
<Node
, Node
> >::iterator it
= cm
.d_const_eq
.find(loc
);
536 if (it
== cm
.d_const_eq
.end())
538 return TRACE_INC_INVALID
;
540 std::vector
<Node
> next
;
541 for (const Node
& pv
: d_prime_vars
)
543 Assert(it
->second
.find(pv
) != it
->second
.end());
544 Node pvs
= it
->second
[pv
];
545 Assert(d_vars
.size() == dt
.d_curr
.size());
546 Node pvsr
= Rewriter::rewrite(pvs
.substitute(
547 d_vars
.begin(), d_vars
.end(), dt
.d_curr
.begin(), dt
.d_curr
.end()));
548 next
.push_back(pvsr
);
550 if (dt
.increment(loc
, next
))
552 Trace("cegqi-inv-debug2") << "dtrace : success increment" << std::endl
;
553 return TRACE_INC_SUCCESS
;
556 Trace("cegqi-inv-debug2") << "dtrace : looped" << std::endl
;
557 return TRACE_INC_TERMINATE
;
560 TraceIncStatus
TransitionInference::initializeTrace(DetTrace
& dt
, bool fwd
)
562 Trace("cegqi-inv-debug2") << "Initialize trace" << std::endl
;
563 Component
& c
= fwd
? d_pre
: d_post
;
564 if (c
.d_conjuncts
.size() == 1)
566 return initializeTrace(dt
, c
.d_conjuncts
[0], fwd
);
568 return TRACE_INC_INVALID
;
571 TraceIncStatus
TransitionInference::incrementTrace(DetTrace
& dt
, bool fwd
)
573 if (d_trans
.d_conjuncts
.size() == 1)
575 return incrementTrace(dt
, d_trans
.d_conjuncts
[0], fwd
);
577 return TRACE_INC_INVALID
;
580 Node
TransitionInference::constructFormulaTrace(DetTrace
& dt
) const
582 return dt
.constructFormula(d_vars
);
585 } // namespace quantifiers
586 } // namespace theory