bv2int: improving bvand tables (#5235)
[cvc5.git] / src / preprocessing / passes / sygus_inference.cpp
1 /********************* */
2 /*! \file sygus_inference.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Andrew Reynolds, Mathias Preiner, Andres Noetzli
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Sygus inference module
13 **/
14
15 #include "preprocessing/passes/sygus_inference.h"
16
17 #include "smt/smt_engine.h"
18 #include "smt/smt_engine_scope.h"
19 #include "smt/smt_statistics_registry.h"
20 #include "theory/quantifiers/quantifiers_attributes.h"
21 #include "theory/quantifiers/quantifiers_rewriter.h"
22 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
23 #include "theory/smt_engine_subsolver.h"
24
25 using namespace std;
26 using namespace CVC4::kind;
27 using namespace CVC4::theory;
28
29 namespace CVC4 {
30 namespace preprocessing {
31 namespace passes {
32
33 SygusInference::SygusInference(PreprocessingPassContext* preprocContext)
34 : PreprocessingPass(preprocContext, "sygus-infer"){};
35
36 PreprocessingPassResult SygusInference::applyInternal(
37 AssertionPipeline* assertionsToPreprocess)
38 {
39 Trace("sygus-infer") << "Run sygus inference..." << std::endl;
40 std::vector<Node> funs;
41 std::vector<Node> sols;
42 // see if we can succesfully solve the input as a sygus problem
43 if (solveSygus(assertionsToPreprocess->ref(), funs, sols))
44 {
45 Assert(funs.size() == sols.size());
46 // if so, sygus gives us function definitions
47 SmtEngine* master_smte = d_preprocContext->getSmt();
48 for (unsigned i = 0, size = funs.size(); i < size; i++)
49 {
50 std::vector<Node> args;
51 Node sol = sols[i];
52 // if it is a non-constant function
53 if (sol.getKind() == LAMBDA)
54 {
55 for (const Node& v : sol[0])
56 {
57 args.push_back(v);
58 }
59 sol = sol[1];
60 }
61 master_smte->defineFunction(funs[i], args, sol);
62 }
63
64 // apply substitution to everything, should result in SAT
65 for (unsigned i = 0, size = assertionsToPreprocess->ref().size(); i < size;
66 i++)
67 {
68 Node prev = (*assertionsToPreprocess)[i];
69 Node curr =
70 prev.substitute(funs.begin(), funs.end(), sols.begin(), sols.end());
71 if (curr != prev)
72 {
73 curr = theory::Rewriter::rewrite(curr);
74 Trace("sygus-infer-debug")
75 << "...rewrote " << prev << " to " << curr << std::endl;
76 assertionsToPreprocess->replace(i, curr);
77 }
78 }
79 }
80 return PreprocessingPassResult::NO_CONFLICT;
81 }
82
83 bool SygusInference::solveSygus(const std::vector<Node>& assertions,
84 std::vector<Node>& funs,
85 std::vector<Node>& sols)
86 {
87 if (assertions.empty())
88 {
89 Trace("sygus-infer") << "...fail: empty assertions." << std::endl;
90 return false;
91 }
92
93 NodeManager* nm = NodeManager::currentNM();
94
95 // collect free variables in all assertions
96 std::vector<Node> qvars;
97 std::map<TypeNode, std::vector<Node> > qtvars;
98 std::vector<Node> free_functions;
99
100 std::vector<TNode> visit;
101 std::unordered_set<TNode, TNodeHashFunction> visited;
102
103 // add top-level conjuncts to eassertions
104 std::vector<Node> assertions_proc = assertions;
105 std::vector<Node> eassertions;
106 unsigned index = 0;
107 while (index < assertions_proc.size())
108 {
109 Node ca = assertions_proc[index];
110 if (ca.getKind() == AND)
111 {
112 for (const Node& ai : ca)
113 {
114 assertions_proc.push_back(ai);
115 }
116 }
117 else
118 {
119 eassertions.push_back(ca);
120 }
121 index++;
122 }
123
124 // process eassertions
125 std::vector<Node> processed_assertions;
126 for (const Node& as : eassertions)
127 {
128 // substitution for this assertion
129 std::vector<Node> vars;
130 std::vector<Node> subs;
131 std::map<TypeNode, unsigned> type_count;
132 Node pas = as;
133 // rewrite
134 pas = theory::Rewriter::rewrite(pas);
135 Trace("sygus-infer") << "assertion : " << pas << std::endl;
136 if (pas.getKind() == FORALL)
137 {
138 // preprocess the quantified formula
139 TrustNode trn = quantifiers::QuantifiersRewriter::preprocess(pas);
140 if (!trn.isNull())
141 {
142 pas = trn.getNode();
143 }
144 Trace("sygus-infer-debug") << " ...preprocessed to " << pas << std::endl;
145 }
146 if (pas.getKind() == FORALL)
147 {
148 // it must be a standard quantifier
149 theory::quantifiers::QAttributes qa;
150 theory::quantifiers::QuantAttributes::computeQuantAttributes(pas, qa);
151 if (!qa.isStandard())
152 {
153 Trace("sygus-infer")
154 << "...fail: non-standard top-level quantifier." << std::endl;
155 return false;
156 }
157 // infer prefix
158 for (const Node& v : pas[0])
159 {
160 TypeNode tnv = v.getType();
161 unsigned vnum = type_count[tnv];
162 type_count[tnv]++;
163 vars.push_back(v);
164 if (vnum < qtvars[tnv].size())
165 {
166 subs.push_back(qtvars[tnv][vnum]);
167 }
168 else
169 {
170 Assert(vnum == qtvars[tnv].size());
171 Node bv = nm->mkBoundVar(tnv);
172 qtvars[tnv].push_back(bv);
173 qvars.push_back(bv);
174 subs.push_back(bv);
175 }
176 }
177 pas = pas[1];
178 if (!vars.empty())
179 {
180 pas =
181 pas.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
182 }
183 }
184 Trace("sygus-infer-debug") << " ...substituted to " << pas << std::endl;
185
186 // collect free functions, ensure no quantified formulas
187 TNode cur = pas;
188 // compute free variables
189 visit.push_back(cur);
190 do
191 {
192 cur = visit.back();
193 visit.pop_back();
194 if (visited.find(cur) == visited.end())
195 {
196 visited.insert(cur);
197 if (cur.getKind() == APPLY_UF)
198 {
199 Node op = cur.getOperator();
200 // visit the operator, which might not be a variable
201 visit.push_back(op);
202 }
203 else if (cur.isVar() && cur.getKind() != BOUND_VARIABLE)
204 {
205 // We are either in the case of a free first-order constant or a
206 // function in a higher-order context. We add to free_functions
207 // in either case. Note that a free constant that is not in a
208 // higher-order context is a 0-argument function-to-synthesize.
209 // We should not have traversed here before due to our visited cache.
210 Assert(std::find(free_functions.begin(), free_functions.end(), cur)
211 == free_functions.end());
212 free_functions.push_back(cur);
213 }
214 else if (cur.isClosure())
215 {
216 Trace("sygus-infer")
217 << "...fail: non-top-level quantifier." << std::endl;
218 return false;
219 }
220 for (const TNode& cn : cur)
221 {
222 visit.push_back(cn);
223 }
224 }
225 } while (!visit.empty());
226 processed_assertions.push_back(pas);
227 }
228
229 // no functions to synthesize
230 if (free_functions.empty())
231 {
232 Trace("sygus-infer") << "...fail: no free function symbols." << std::endl;
233 return false;
234 }
235
236 // Ensure the type of all free functions is handled by the sygus grammar
237 // constructor utility.
238 bool typeSuccess = true;
239 for (const Node& f : free_functions)
240 {
241 TypeNode tn = f.getType();
242 if (!theory::quantifiers::CegGrammarConstructor::isHandledType(tn))
243 {
244 Trace("sygus-infer") << "...fail: unhandled type " << tn << std::endl;
245 typeSuccess = false;
246 break;
247 }
248 }
249 if (!typeSuccess)
250 {
251 return false;
252 }
253
254 Assert(!processed_assertions.empty());
255 // conjunction of the assertions
256 Trace("sygus-infer") << "Construct body..." << std::endl;
257 Node body;
258 if (processed_assertions.size() == 1)
259 {
260 body = processed_assertions[0];
261 }
262 else
263 {
264 body = nm->mkNode(AND, processed_assertions);
265 }
266
267 // for each free function symbol, make a bound variable of the same type
268 Trace("sygus-infer") << "Do free function substitution..." << std::endl;
269 std::vector<Node> ff_vars;
270 std::map<Node, Node> ff_var_to_ff;
271 for (const Node& ff : free_functions)
272 {
273 Node ffv = nm->mkBoundVar(ff.getType());
274 ff_vars.push_back(ffv);
275 Trace("sygus-infer") << " synth-fun: " << ff << " as " << ffv << std::endl;
276 ff_var_to_ff[ffv] = ff;
277 }
278 // substitute free functions -> variables
279 body = body.substitute(free_functions.begin(),
280 free_functions.end(),
281 ff_vars.begin(),
282 ff_vars.end());
283 Trace("sygus-infer-debug") << "...got : " << body << std::endl;
284
285 // quantify the body
286 Trace("sygus-infer") << "Make inner sygus conjecture..." << std::endl;
287 body = body.negate();
288 if (!qvars.empty())
289 {
290 Node bvl = nm->mkNode(BOUND_VAR_LIST, qvars);
291 body = nm->mkNode(EXISTS, bvl, body);
292 }
293
294 // sygus attribute to mark the conjecture as a sygus conjecture
295 Trace("sygus-infer") << "Make outer sygus conjecture..." << std::endl;
296 Node sygusVar = nm->mkSkolem("sygus", nm->booleanType());
297 theory::SygusAttribute ca;
298 sygusVar.setAttribute(ca, true);
299 Node instAttr = nm->mkNode(INST_ATTRIBUTE, sygusVar);
300 Node instAttrList = nm->mkNode(INST_PATTERN_LIST, instAttr);
301
302 Node fbvl = nm->mkNode(BOUND_VAR_LIST, ff_vars);
303
304 body = nm->mkNode(FORALL, fbvl, body, instAttrList);
305
306 Trace("sygus-infer") << "*** Return sygus inference : " << body << std::endl;
307
308 // make a separate smt call
309 std::unique_ptr<SmtEngine> rrSygus;
310 theory::initializeSubsolver(rrSygus);
311 rrSygus->assertFormula(body);
312 Trace("sygus-infer") << "*** Check sat..." << std::endl;
313 Result r = rrSygus->checkSat();
314 Trace("sygus-infer") << "...result : " << r << std::endl;
315 if (r.asSatisfiabilityResult().isSat() != Result::UNSAT)
316 {
317 // failed, conjecture was infeasible
318 return false;
319 }
320 // get the synthesis solutions
321 std::map<Node, Node> synth_sols;
322 rrSygus->getSynthSolutions(synth_sols);
323
324 std::vector<Node> final_ff;
325 std::vector<Node> final_ff_sol;
326 for (std::map<Node, Node>::iterator it = synth_sols.begin();
327 it != synth_sols.end();
328 ++it)
329 {
330 Trace("sygus-infer") << " synth sol : " << it->first << " -> "
331 << it->second << std::endl;
332 Node ffv = it->first;
333 std::map<Node, Node>::iterator itffv = ff_var_to_ff.find(ffv);
334 // all synthesis solutions should correspond to a variable we introduced
335 Assert(itffv != ff_var_to_ff.end());
336 if (itffv != ff_var_to_ff.end())
337 {
338 Node ff = itffv->second;
339 Node body2 = it->second;
340 Trace("sygus-infer") << "Define " << ff << " as " << body2 << std::endl;
341 funs.push_back(ff);
342 sols.push_back(body2);
343 }
344 }
345 return true;
346 }
347
348
349 } // namespace passes
350 } // namespace preprocessing
351 } // namespace CVC4