Replace Expr-level datatype with Node-level DType (#4875)
[cvc5.git] / src / theory / quantifiers / sygus / sygus_interpol.cpp
index cc97bb9748fe0733193a4a57d25b15a0d1298bff..c2ca83e41ae2152aea45dd0567042ecb7cf66183 100644 (file)
 
 #include "theory/quantifiers/sygus/sygus_interpol.h"
 
-#include "expr/datatype.h"
 #include "expr/dtype.h"
 #include "expr/node_algorithm.h"
-#include "expr/sygus_datatype.h"
 #include "options/smt_options.h"
 #include "theory/datatypes/sygus_datatype_utils.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
-#include "theory/quantifiers/quantifiers_rewriter.h"
 #include "theory/quantifiers/sygus/sygus_grammar_cons.h"
-#include "theory/quantifiers/term_util.h"
 #include "theory/rewriter.h"
 
 namespace CVC4 {
@@ -188,15 +184,14 @@ TypeNode SygusInterpol::setSynthGrammar(const TypeNode& itpGType,
     std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
     getIncludeCons(axioms, conj, include_cons);
     std::unordered_set<Node, NodeHashFunction> terms_irrelevant;
-    itpGTypeS =
-        CVC4::theory::quantifiers::CegGrammarConstructor::mkSygusDefaultType(
-            NodeManager::currentNM()->booleanType(),
-            d_ibvlShared,
-            "interpolation_grammar",
-            extra_cons,
-            exclude_cons,
-            include_cons,
-            terms_irrelevant);
+    itpGTypeS = CegGrammarConstructor::mkSygusDefaultType(
+        NodeManager::currentNM()->booleanType(),
+        d_ibvlShared,
+        "interpolation_grammar",
+        extra_cons,
+        exclude_cons,
+        include_cons,
+        terms_irrelevant);
   }
   Trace("sygus-interpol-debug") << "...finish setting up grammar" << std::endl;
   return itpGTypeS;
@@ -204,7 +199,15 @@ TypeNode SygusInterpol::setSynthGrammar(const TypeNode& itpGType,
 
 Node SygusInterpol::mkPredicate(const std::string& name)
 {
-  Node itp;
+  NodeManager* nm = NodeManager::currentNM();
+  // make the interpolation predicate to synthesize
+  Trace("sygus-interpol-debug")
+      << "Make interpolation predicate..." << std::endl;
+  TypeNode itpType = d_varTypesShared.empty()
+                         ? nm->booleanType()
+                         : nm->mkPredicateType(d_varTypesShared);
+  Node itp = nm->mkBoundVar(name.c_str(), itpType);
+  Trace("sygus-interpol-debug") << "...finish" << std::endl;
   return itp;
 }
 
@@ -212,11 +215,100 @@ void SygusInterpol::mkSygusConjecture(Node itp,
                                       const std::vector<Node>& axioms,
                                       const Node& conj)
 {
+  NodeManager* nm = NodeManager::currentNM();
+  // make the interpolation application to synthesize
+  Trace("sygus-interpol-debug")
+      << "Make interpolation predicate app..." << std::endl;
+  std::vector<Node> ichildren;
+  ichildren.push_back(itp);
+  ichildren.insert(ichildren.end(), d_varsShared.begin(), d_varsShared.end());
+  Node itpApp =
+      d_varsShared.empty() ? itp : nm->mkNode(kind::APPLY_UF, ichildren);
+  Trace("sygus-interpol-debug") << "itpApp: " << itpApp << std::endl
+                                << std::endl;
+  Trace("sygus-interpol-debug") << "...finish" << std::endl;
+
+  // set the sygus bound variable list
+  Trace("sygus-interpol-debug") << "Set attributes..." << std::endl;
+  itp.setAttribute(SygusSynthFunVarListAttribute(), d_ibvlShared);
+  // sygus attribute
+  Node sygusVar = nm->mkSkolem("sygus", nm->booleanType());
+  SygusAttribute ca;
+  sygusVar.setAttribute(ca, true);
+  Node instAttr = nm->mkNode(kind::INST_ATTRIBUTE, sygusVar);
+  std::vector<Node> iplc;
+  iplc.push_back(instAttr);
+  Node instAttrList = nm->mkNode(kind::INST_PATTERN_LIST, iplc);
+  Trace("sygus-interpol-debug") << "...finish" << std::endl;
+
+  // Fa( x )
+  Trace("sygus-interpol-debug") << "Make conjecture body..." << std::endl;
+  Node Fa = axioms.size() == 1 ? axioms[0] : nm->mkNode(kind::AND, axioms);
+  // Fa( x ) => A( x )
+  Node firstImplication = nm->mkNode(kind::IMPLIES, Fa, itpApp);
+  Trace("sygus-interpol-debug")
+      << "first implication: " << firstImplication << std::endl
+      << std::endl;
+  // A( x ) => Fc( x )
+  Node Fc = conj;
+  Node secondImplication = nm->mkNode(kind::IMPLIES, itpApp, Fc);
+  Trace("sygus-interpol-debug")
+      << "second implication: " << secondImplication << std::endl
+      << std::endl;
+  // Fa( x ) => A( x ) ^ A( x ) => Fc( x )
+  Node constraint = nm->mkNode(kind::AND, firstImplication, secondImplication);
+  constraint = constraint.substitute(
+      d_syms.begin(), d_syms.end(), d_vars.begin(), d_vars.end());
+  Trace("sygus-interpol-debug") << constraint << "...finish" << std::endl;
+  constraint = Rewriter::rewrite(constraint);
+
+  d_sygusConj = constraint;
+  Trace("sygus-interpol") << "Generate: " << d_sygusConj << std::endl;
 }
 
 bool SygusInterpol::findInterpol(Expr& interpol, Node itp)
 {
-  return false;
+  // get the synthesis solution
+  std::map<Node, Node> sols;
+  d_subSolver->getSynthSolutions(sols);
+  Assert(sols.size() == 1);
+  std::map<Node, Node>::iterator its = sols.find(itp);
+  if (its == sols.end())
+  {
+    Trace("sygus-interpol")
+        << "SmtEngine::getInterpol: could not find solution!" << std::endl;
+    throw RecoverableModalException(
+        "Could not find solution for get-interpol.");
+    return false;
+  }
+  Trace("sygus-interpol") << "SmtEngine::getInterpol: solution is "
+                          << its->second << std::endl;
+  Node interpoln = its->second;
+  // replace back the created variables to original symbols.
+  Node interpoln_reduced;
+  if (interpoln.getKind() == kind::LAMBDA)
+  {
+    interpoln_reduced = interpoln[1];
+  }
+  else
+  {
+    interpoln_reduced = interpoln;
+  }
+  if (interpoln.getNumChildren() != 0 && interpoln[0].getNumChildren() != 0)
+  {
+    std::vector<Node> formals;
+    for (const Node& n : interpoln[0])
+    {
+      formals.push_back(n);
+    }
+    interpoln_reduced = interpoln_reduced.substitute(formals.begin(),
+                                                     formals.end(),
+                                                     d_symSetShared.begin(),
+                                                     d_symSetShared.end());
+  }
+  // convert to expression
+  interpol = interpoln_reduced.toExpr();
+  return true;
 }
 
 bool SygusInterpol::SolveInterpolation(const std::string& name,
@@ -225,6 +317,41 @@ bool SygusInterpol::SolveInterpolation(const std::string& name,
                                        const TypeNode& itpGType,
                                        Expr& interpol)
 {
+  NodeManager* nm = NodeManager::currentNM();
+  // we generate a new smt engine to do the interpolation query
+  d_subSolver.reset(new SmtEngine(nm->toExprManager()));
+  d_subSolver->setIsInternalSubsolver();
+  // get the logic
+  LogicInfo l = d_logic.getUnlockedCopy();
+  // enable everything needed for sygus
+  l.enableSygus();
+  d_subSolver->setLogic(l);
+
+  collectSymbols(axioms, conj);
+  createVariables(itpGType.isNull());
+  for (Node var : d_vars)
+  {
+    d_subSolver->declareSygusVar(name, var, var.getType());
+  }
+  std::vector<Node> vars_empty;
+  TypeNode grammarType = setSynthGrammar(itpGType, axioms, conj);
+  Node itp = mkPredicate(name);
+  d_subSolver->declareSynthFun(name, itp, grammarType, false, vars_empty);
+  mkSygusConjecture(itp, axioms, conj);
+  Trace("sygus-interpol") << "SmtEngine::getInterpol: made conjecture : "
+                          << d_sygusConj << ", solving for "
+                          << d_sygusConj[0][0] << std::endl;
+  d_subSolver->assertSygusConstraint(d_sygusConj);
+
+  Trace("sygus-interpol") << "  SmtEngine::getInterpol check sat..."
+                          << std::endl;
+  Result r = d_subSolver->checkSynth();
+  Trace("sygus-interpol") << "  SmtEngine::getInterpol result: " << r
+                          << std::endl;
+  if (r.asSatisfiabilityResult().isSat() == Result::UNSAT)
+  {
+    return findInterpol(interpol, itp);
+  }
   return false;
 }