Add implementation for SyGuS interpolation module (step4) (#4811)
[cvc5.git] / src / theory / quantifiers / sygus / sygus_interpol.cpp
1 /********************* */
2 /*! \file sygus_interpol.cpp
3 ** \verbatim
4 ** Top contributors (to current version):
5 ** Ying Sheng
6 ** This file is part of the CVC4 project.
7 ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
8 ** in the top-level source directory) and their institutional affiliations.
9 ** All rights reserved. See the file COPYING in the top-level source
10 ** directory for licensing information.\endverbatim
11 **
12 ** \brief Implementation of the sygus interpolation utility, which
13 ** transforms an input of axioms and a conjecture into an interpolation
14 ** problem and solves it.
15 **/
16
17 #include "theory/quantifiers/sygus/sygus_interpol.h"
18
19 #include "expr/datatype.h"
20 #include "expr/dtype.h"
21 #include "expr/node_algorithm.h"
22 #include "expr/sygus_datatype.h"
23 #include "options/smt_options.h"
24 #include "theory/datatypes/sygus_datatype_utils.h"
25 #include "theory/quantifiers/quantifiers_attributes.h"
26 #include "theory/quantifiers/quantifiers_rewriter.h"
27 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
28 #include "theory/quantifiers/term_util.h"
29 #include "theory/rewriter.h"
30
31 namespace CVC4 {
32 namespace theory {
33 namespace quantifiers {
34
35 SygusInterpol::SygusInterpol() {}
36
37 SygusInterpol::SygusInterpol(LogicInfo logic) : d_logic(logic) {}
38
39 void SygusInterpol::collectSymbols(const std::vector<Node>& axioms,
40 const Node& conj)
41 {
42 Trace("sygus-interpol-debug") << "Collect symbols..." << std::endl;
43 std::unordered_set<Node, NodeHashFunction> symSetAxioms;
44 std::unordered_set<Node, NodeHashFunction> symSetConj;
45 for (size_t i = 0, size = axioms.size(); i < size; i++)
46 {
47 expr::getSymbols(axioms[i], symSetAxioms);
48 }
49 expr::getSymbols(conj, symSetConj);
50 d_syms.insert(d_syms.end(), symSetAxioms.begin(), symSetAxioms.end());
51 d_syms.insert(d_syms.end(), symSetConj.begin(), symSetConj.end());
52 for (const Node& elem : symSetConj)
53 {
54 if (symSetAxioms.find(elem) != symSetAxioms.end())
55 {
56 d_symSetShared.insert(elem);
57 }
58 }
59 Trace("sygus-interpol-debug")
60 << "...finish, got " << d_syms.size() << " symbols in total. And "
61 << d_symSetShared.size() << " shared symbols." << std::endl;
62 }
63
64 void SygusInterpol::createVariables(bool needsShared)
65 {
66 NodeManager* nm = NodeManager::currentNM();
67 for (const Node& s : d_syms)
68 {
69 TypeNode tn = s.getType();
70 if (tn.isConstructor() || tn.isSelector() || tn.isTester())
71 {
72 // datatype symbols should be considered interpreted symbols here, not
73 // (higher-order) variables.
74 continue;
75 }
76 // Notice that we allow for non-first class (e.g. function) variables here.
77 std::stringstream ss;
78 ss << s;
79 Node var = nm->mkBoundVar(tn);
80 d_vars.push_back(var);
81 Node vlv = nm->mkBoundVar(ss.str(), tn);
82 d_vlvs.push_back(vlv);
83 if (!needsShared || d_symSetShared.find(s) != d_symSetShared.end())
84 {
85 d_varsShared.push_back(var);
86 d_vlvsShared.push_back(vlv);
87 d_varTypesShared.push_back(tn);
88 }
89 }
90 // make the sygus variable list
91 d_ibvlShared = nm->mkNode(kind::BOUND_VAR_LIST, d_vlvsShared);
92 Trace("sygus-interpol-debug") << "...finish" << std::endl;
93 }
94
95 void SygusInterpol::getIncludeCons(
96 const std::vector<Node>& axioms,
97 const Node& conj,
98 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& result)
99 {
100 NodeManager* nm = NodeManager::currentNM();
101 Assert(options::produceInterpols() != options::ProduceInterpols::NONE);
102 // ASSUMPTIONS
103 if (options::produceInterpols() == options::ProduceInterpols::ASSUMPTIONS)
104 {
105 Node tmpAssumptions =
106 (axioms.size() == 1 ? axioms[0] : nm->mkNode(kind::AND, axioms));
107 expr::getOperatorsMap(tmpAssumptions, result);
108 }
109 // CONJECTURE
110 else if (options::produceInterpols() == options::ProduceInterpols::CONJECTURE)
111 {
112 expr::getOperatorsMap(conj, result);
113 }
114 // SHARED
115 else if (options::produceInterpols() == options::ProduceInterpols::SHARED)
116 {
117 // Get operators from axioms
118 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>
119 include_cons_axioms;
120 Node tmpAssumptions =
121 (axioms.size() == 1 ? axioms[0] : nm->mkNode(kind::AND, axioms));
122 expr::getOperatorsMap(tmpAssumptions, include_cons_axioms);
123
124 // Get operators from conj
125 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>
126 include_cons_conj;
127 expr::getOperatorsMap(conj, include_cons_conj);
128
129 // Compute intersection
130 for (std::map<TypeNode,
131 std::unordered_set<Node, NodeHashFunction>>::iterator it =
132 include_cons_axioms.begin();
133 it != include_cons_axioms.end();
134 it++)
135 {
136 TypeNode tn = it->first;
137 std::unordered_set<Node, NodeHashFunction> axiomsOps = it->second;
138 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>::iterator
139 concIter = include_cons_conj.find(tn);
140 if (concIter != include_cons_conj.end())
141 {
142 std::unordered_set<Node, NodeHashFunction> conjOps = concIter->second;
143 for (const Node& n : axiomsOps)
144 {
145 if (conjOps.find(n) != conjOps.end())
146 {
147 if (result.find(tn) == result.end())
148 {
149 result[tn] = std::unordered_set<Node, NodeHashFunction>();
150 }
151 result[tn].insert(n);
152 }
153 }
154 }
155 }
156 }
157 // ALL
158 else if (options::produceInterpols() == options::ProduceInterpols::ALL)
159 {
160 Node tmpAssumptions =
161 (axioms.size() == 1 ? axioms[0] : nm->mkNode(kind::AND, axioms));
162 Node tmpAll = nm->mkNode(kind::AND, tmpAssumptions, conj);
163 expr::getOperatorsMap(tmpAll, result);
164 }
165 }
166
167 TypeNode SygusInterpol::setSynthGrammar(const TypeNode& itpGType,
168 const std::vector<Node>& axioms,
169 const Node& conj)
170 {
171 Trace("sygus-interpol-debug") << "Setup grammar..." << std::endl;
172 TypeNode itpGTypeS;
173 if (!itpGType.isNull())
174 {
175 // set user-defined grammar
176 Assert(itpGType.isDatatype() && itpGType.getDType().isSygus());
177 itpGTypeS = datatypes::utils::substituteAndGeneralizeSygusType(
178 itpGType, d_syms, d_vlvs);
179 Assert(itpGTypeS.isDatatype() && itpGTypeS.getDType().isSygus());
180 // TODO(Ying Sheng) check if the vars in user-defined grammar, are
181 // consistent with the shared vars
182 }
183 else
184 {
185 // set default grammar
186 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
187 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exclude_cons;
188 std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
189 getIncludeCons(axioms, conj, include_cons);
190 std::unordered_set<Node, NodeHashFunction> terms_irrelevant;
191 itpGTypeS =
192 CVC4::theory::quantifiers::CegGrammarConstructor::mkSygusDefaultType(
193 NodeManager::currentNM()->booleanType(),
194 d_ibvlShared,
195 "interpolation_grammar",
196 extra_cons,
197 exclude_cons,
198 include_cons,
199 terms_irrelevant);
200 }
201 Trace("sygus-interpol-debug") << "...finish setting up grammar" << std::endl;
202 return itpGTypeS;
203 }
204
205 Node SygusInterpol::mkPredicate(const std::string& name)
206 {
207 NodeManager* nm = NodeManager::currentNM();
208 // make the interpolation predicate to synthesize
209 Trace("sygus-interpol-debug")
210 << "Make interpolation predicate..." << std::endl;
211 TypeNode itpType = d_varTypesShared.empty()
212 ? nm->booleanType()
213 : nm->mkPredicateType(d_varTypesShared);
214 Node itp = nm->mkBoundVar(name.c_str(), itpType);
215 Trace("sygus-interpol-debug") << "...finish" << std::endl;
216 return itp;
217 }
218
219 void SygusInterpol::mkSygusConjecture(Node itp,
220 const std::vector<Node>& axioms,
221 const Node& conj)
222 {
223 NodeManager* nm = NodeManager::currentNM();
224 // make the interpolation application to synthesize
225 Trace("sygus-interpol-debug")
226 << "Make interpolation predicate app..." << std::endl;
227 std::vector<Node> ichildren;
228 ichildren.push_back(itp);
229 ichildren.insert(ichildren.end(), d_varsShared.begin(), d_varsShared.end());
230 Node itpApp =
231 d_varsShared.empty() ? itp : nm->mkNode(kind::APPLY_UF, ichildren);
232 Trace("sygus-interpol-debug") << "itpApp: " << itpApp << std::endl
233 << std::endl;
234 Trace("sygus-interpol-debug") << "...finish" << std::endl;
235
236 // set the sygus bound variable list
237 Trace("sygus-interpol-debug") << "Set attributes..." << std::endl;
238 itp.setAttribute(theory::SygusSynthFunVarListAttribute(), d_ibvlShared);
239 // sygus attribute
240 Node sygusVar = nm->mkSkolem("sygus", nm->booleanType());
241 theory::SygusAttribute ca;
242 sygusVar.setAttribute(ca, true);
243 Node instAttr = nm->mkNode(kind::INST_ATTRIBUTE, sygusVar);
244 std::vector<Node> iplc;
245 iplc.push_back(instAttr);
246 Node instAttrList = nm->mkNode(kind::INST_PATTERN_LIST, iplc);
247 Trace("sygus-interpol-debug") << "...finish" << std::endl;
248
249 // Fa( x )
250 Trace("sygus-interpol-debug") << "Make conjecture body..." << std::endl;
251 Node Fa = axioms.size() == 1 ? axioms[0] : nm->mkNode(kind::AND, axioms);
252 // Fa( x ) => A( x )
253 Node firstImplication = nm->mkNode(kind::IMPLIES, Fa, itpApp);
254 Trace("sygus-interpol-debug")
255 << "first implication: " << firstImplication << std::endl
256 << std::endl;
257 // A( x ) => Fc( x )
258 Node Fc = conj;
259 Node secondImplication = nm->mkNode(kind::IMPLIES, itpApp, Fc);
260 Trace("sygus-interpol-debug")
261 << "second implication: " << secondImplication << std::endl
262 << std::endl;
263 // Fa( x ) => A( x ) ^ A( x ) => Fc( x )
264 Node constraint = nm->mkNode(kind::AND, firstImplication, secondImplication);
265 constraint = constraint.substitute(
266 d_syms.begin(), d_syms.end(), d_vars.begin(), d_vars.end());
267 Trace("sygus-interpol-debug") << constraint << "...finish" << std::endl;
268 constraint = theory::Rewriter::rewrite(constraint);
269
270 d_sygusConj = constraint;
271 Trace("sygus-interpol") << "Generate: " << d_sygusConj << std::endl;
272 }
273
274 bool SygusInterpol::findInterpol(Expr& interpol, Node itp)
275 {
276 // get the synthesis solution
277 std::map<Expr, Expr> sols;
278 d_subSolver->getSynthSolutions(sols);
279 Assert(sols.size() == 1);
280 std::map<Expr, Expr>::iterator its = sols.find(itp.toExpr());
281 if (its == sols.end())
282 {
283 Trace("sygus-interpol")
284 << "SmtEngine::getInterpol: could not find solution!" << std::endl;
285 throw RecoverableModalException(
286 "Could not find solution for get-interpol.");
287 return false;
288 }
289 Trace("sygus-interpol") << "SmtEngine::getInterpol: solution is "
290 << its->second << std::endl;
291 Node interpoln = Node::fromExpr(its->second);
292 // replace back the created variables to original symbols.
293 Node interpoln_reduced;
294 if (interpoln.getKind() == kind::LAMBDA)
295 {
296 interpoln_reduced = interpoln[1];
297 }
298 else
299 {
300 interpoln_reduced = interpoln;
301 }
302 if (interpoln.getNumChildren() != 0 && interpoln[0].getNumChildren() != 0)
303 {
304 std::vector<Node> formals;
305 for (const Node& n : interpoln[0])
306 {
307 formals.push_back(n);
308 }
309 interpoln_reduced = interpoln_reduced.substitute(formals.begin(),
310 formals.end(),
311 d_symSetShared.begin(),
312 d_symSetShared.end());
313 }
314 // convert to expression
315 interpol = interpoln_reduced.toExpr();
316 return true;
317 }
318
319 bool SygusInterpol::SolveInterpolation(const std::string& name,
320 const std::vector<Node>& axioms,
321 const Node& conj,
322 const TypeNode& itpGType,
323 Expr& interpol)
324 {
325 NodeManager* nm = NodeManager::currentNM();
326 // we generate a new smt engine to do the interpolation query
327 d_subSolver.reset(new SmtEngine(nm->toExprManager()));
328 d_subSolver->setIsInternalSubsolver();
329 // get the logic
330 LogicInfo l = d_logic.getUnlockedCopy();
331 // enable everything needed for sygus
332 l.enableSygus();
333 d_subSolver->setLogic(l);
334
335 collectSymbols(axioms, conj);
336 createVariables(itpGType.isNull());
337 for (Node var : d_vars)
338 {
339 d_subSolver->declareSygusVar(name, var.toExpr(), var.getType().toType());
340 }
341 std::vector<Expr> vars_empty;
342 TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj);
343 Node itp = mkPredicate(name);
344 d_subSolver->declareSynthFun(
345 name, itp.toExpr(), grammarType.toType(), false, vars_empty);
346 mkSygusConjecture(itp, axioms, conj);
347 Trace("sygus-interpol") << "SmtEngine::getInterpol: made conjecture : "
348 << d_sygusConj << ", solving for "
349 << d_sygusConj[0][0].toExpr() << std::endl;
350 d_subSolver->assertSygusConstraint(d_sygusConj.toExpr());
351
352 Trace("sygus-interpol") << " SmtEngine::getInterpol check sat..."
353 << std::endl;
354 Result r = d_subSolver->checkSynth();
355 Trace("sygus-interpol") << " SmtEngine::getInterpol result: " << r
356 << std::endl;
357 if (r.asSatisfiabilityResult().isSat() == Result::UNSAT)
358 {
359 return findInterpol(interpol, itp);
360 }
361 return false;
362 }
363
364 } // namespace quantifiers
365 } // namespace theory
366 } // namespace CVC4