Make ExtTheory independent of Theory (#5010)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 3 Sep 2020 01:18:48 +0000 (20:18 -0500)
committerGitHub <noreply@github.com>
Thu, 3 Sep 2020 01:18:48 +0000 (18:18 -0700)
This makes it so that ExtTheory uses a generic callback instead of relying on Theory.

The primary purpose of this commit is to eliminate the connection of TheoryBV and ExtTheory. This commit moves all things related to ExtTheory in BV into CoreSolver.

It also refactors the use of ExtTheory in strings and arithmetic.

22 files changed:
src/CMakeLists.txt
src/theory/arith/nl/ext_theory_callback.cpp [new file with mode: 0644]
src/theory/arith/nl/ext_theory_callback.h [new file with mode: 0644]
src/theory/arith/nl/nonlinear_extension.cpp
src/theory/arith/nl/nonlinear_extension.h
src/theory/arith/theory_arith.cpp
src/theory/arith/theory_arith.h
src/theory/arith/theory_arith_private.cpp
src/theory/arith/theory_arith_private.h
src/theory/bv/bv_subtheory_core.cpp
src/theory/bv/bv_subtheory_core.h
src/theory/bv/theory_bv.cpp
src/theory/bv/theory_bv.h
src/theory/datatypes/theory_datatypes.cpp
src/theory/datatypes/theory_datatypes.h
src/theory/ext_theory.cpp
src/theory/ext_theory.h
src/theory/strings/extf_solver.cpp
src/theory/strings/extf_solver.h
src/theory/strings/theory_strings.cpp
src/theory/strings/theory_strings.h
src/theory/theory.h

index 692ae09ac17e2079c5243aaa98efb8690783fc8b..971648839f692c025c406fc924e894d750dac5d2 100644 (file)
@@ -309,6 +309,8 @@ libcvc4_add_sources(
   theory/arith/nl/cad/projections.h
   theory/arith/nl/cad/variable_ordering.cpp
   theory/arith/nl/cad/variable_ordering.h
+  theory/arith/nl/ext_theory_callback.cpp
+  theory/arith/nl/ext_theory_callback.h
   theory/arith/nl/iand_solver.cpp
   theory/arith/nl/iand_solver.h
   theory/arith/nl/inference.cpp
diff --git a/src/theory/arith/nl/ext_theory_callback.cpp b/src/theory/arith/nl/ext_theory_callback.cpp
new file mode 100644 (file)
index 0000000..4518df0
--- /dev/null
@@ -0,0 +1,131 @@
+/*********************                                                        */
+/*! \file ext_theory_callback.cpp
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The extended theory callback for non-linear arithmetic
+ **/
+
+#include "theory/arith/nl/ext_theory_callback.h"
+
+#include "theory/arith/arith_utilities.h"
+
+using namespace CVC4::kind;
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+namespace nl {
+
+NlExtTheoryCallback::NlExtTheoryCallback(eq::EqualityEngine* ee) : d_ee(ee)
+{
+  d_zero = NodeManager::currentNM()->mkConst(Rational(0));
+}
+
+bool NlExtTheoryCallback::getCurrentSubstitution(
+    int effort,
+    const std::vector<Node>& vars,
+    std::vector<Node>& subs,
+    std::map<Node, std::vector<Node>>& exp)
+{
+  // get the constant equivalence classes
+  std::map<Node, std::vector<int>> rep_to_subs_index;
+
+  bool retVal = false;
+  for (unsigned i = 0; i < vars.size(); i++)
+  {
+    Node n = vars[i];
+    if (d_ee->hasTerm(n))
+    {
+      Node nr = d_ee->getRepresentative(n);
+      if (nr.isConst())
+      {
+        subs.push_back(nr);
+        Trace("nl-subs") << "Basic substitution : " << n << " -> " << nr
+                         << std::endl;
+        exp[n].push_back(n.eqNode(nr));
+        retVal = true;
+      }
+      else
+      {
+        rep_to_subs_index[nr].push_back(i);
+        subs.push_back(n);
+      }
+    }
+    else
+    {
+      subs.push_back(n);
+    }
+  }
+
+  // return true if the substitution is non-trivial
+  return retVal;
+}
+
+bool NlExtTheoryCallback::isExtfReduced(int effort,
+                                        Node n,
+                                        Node on,
+                                        std::vector<Node>& exp)
+{
+  if (n != d_zero)
+  {
+    Kind k = n.getKind();
+    return k != NONLINEAR_MULT && !isTranscendentalKind(k) && k != IAND;
+  }
+  Assert(n == d_zero);
+  if (on.getKind() == NONLINEAR_MULT)
+  {
+    Trace("nl-ext-zero-exp")
+        << "Infer zero : " << on << " == " << n << std::endl;
+    // minimize explanation if a substitution+rewrite results in zero
+    const std::set<Node> vars(on.begin(), on.end());
+
+    for (unsigned i = 0, size = exp.size(); i < size; i++)
+    {
+      Trace("nl-ext-zero-exp")
+          << "  exp[" << i << "] = " << exp[i] << std::endl;
+      std::vector<Node> eqs;
+      if (exp[i].getKind() == EQUAL)
+      {
+        eqs.push_back(exp[i]);
+      }
+      else if (exp[i].getKind() == AND)
+      {
+        for (const Node& ec : exp[i])
+        {
+          if (ec.getKind() == EQUAL)
+          {
+            eqs.push_back(ec);
+          }
+        }
+      }
+
+      for (unsigned j = 0; j < eqs.size(); j++)
+      {
+        for (unsigned r = 0; r < 2; r++)
+        {
+          if (eqs[j][r] == d_zero && vars.find(eqs[j][1 - r]) != vars.end())
+          {
+            Trace("nl-ext-zero-exp")
+                << "...single exp : " << eqs[j] << std::endl;
+            exp.clear();
+            exp.push_back(eqs[j]);
+            return true;
+          }
+        }
+      }
+    }
+  }
+  return true;
+}
+
+}  // namespace nl
+}  // namespace arith
+}  // namespace theory
+}  // namespace CVC4
diff --git a/src/theory/arith/nl/ext_theory_callback.h b/src/theory/arith/nl/ext_theory_callback.h
new file mode 100644 (file)
index 0000000..0d95db1
--- /dev/null
@@ -0,0 +1,86 @@
+/*********************                                                        */
+/*! \file ext_theory_callback.h
+ ** \verbatim
+ ** Top contributors (to current version):
+ **   Andrew Reynolds
+ ** This file is part of the CVC4 project.
+ ** Copyright (c) 2009-2020 by the authors listed in the file AUTHORS
+ ** in the top-level source directory) and their institutional affiliations.
+ ** All rights reserved.  See the file COPYING in the top-level source
+ ** directory for licensing information.\endverbatim
+ **
+ ** \brief The extended theory callback for non-linear arithmetic
+ **/
+
+#ifndef CVC4__THEORY__ARITH__NL__EXT_THEORY_CALLBACK_H
+#define CVC4__THEORY__ARITH__NL__EXT_THEORY_CALLBACK_H
+
+#include "expr/node.h"
+#include "theory/ext_theory.h"
+
+namespace CVC4 {
+namespace theory {
+namespace arith {
+namespace nl {
+
+class NlExtTheoryCallback : public ExtTheoryCallback
+{
+ public:
+  NlExtTheoryCallback(eq::EqualityEngine* ee);
+  ~NlExtTheoryCallback() {}
+  /** Get current substitution
+   *
+   * This function and the one below are
+   * used for context-dependent
+   * simplification, see Section 3.1 of
+   * "Designing Theory Solvers with Extensions"
+   * by Reynolds et al. FroCoS 2017.
+   *
+   * effort : an identifier indicating the stage where
+   *          we are performing context-dependent simplification,
+   * vars : a set of arithmetic variables.
+   *
+   * This function populates subs and exp, such that for 0 <= i < vars.size():
+   *   ( exp[vars[i]] ) => vars[i] = subs[i]
+   * where exp[vars[i]] is a set of assertions
+   * that hold in the current context. We call { vars -> subs } a "derivable
+   * substituion" (see Reynolds et al. FroCoS 2017).
+   */
+  bool getCurrentSubstitution(int effort,
+                              const std::vector<Node>& vars,
+                              std::vector<Node>& subs,
+                              std::map<Node, std::vector<Node>>& exp) override;
+  /** Is the term n in reduced form?
+   *
+   * Used for context-dependent simplification.
+   *
+   * effort : an identifier indicating the stage where
+   *          we are performing context-dependent simplification,
+   * on : the original term that we reduced to n,
+   * exp : an explanation such that ( exp => on = n ).
+   *
+   * We return a pair ( b, exp' ) such that
+   *   if b is true, then:
+   *     n is in reduced form
+   *     if exp' is non-null, then ( exp' => on = n )
+   * The second part of the pair is used for constructing
+   * minimal explanations for context-dependent simplifications.
+   */
+  bool isExtfReduced(int effort,
+                     Node n,
+                     Node on,
+                     std::vector<Node>& exp) override;
+
+ private:
+  /** The underlying equality engine. */
+  eq::EqualityEngine* d_ee;
+  /** Commonly used nodes */
+  Node d_zero;
+};
+
+}  // namespace nl
+}  // namespace arith
+}  // namespace theory
+}  // namespace CVC4
+
+#endif /* CVC4__THEORY__ARITH__NL__EXT_THEORY_CALLBACK_H */
index ada6aa11a184fe48d5f5aab7d158dbb38c6208b6..733912969a99d281854a999df8653ad30bde4c13 100644 (file)
@@ -39,7 +39,11 @@ NonlinearExtension::NonlinearExtension(TheoryArith& containing,
       d_ee(ee),
       d_needsLastCall(false),
       d_checkCounter(0),
-      d_extTheory(&containing),
+      d_extTheoryCb(ee),
+      d_extTheory(d_extTheoryCb,
+                  containing.getSatContext(),
+                  containing.getUserContext(),
+                  containing.getOutputChannel()),
       d_model(containing.getSatContext()),
       d_trSlv(d_model),
       d_nlSlv(containing, d_model),
@@ -67,101 +71,6 @@ void NonlinearExtension::preRegisterTerm(TNode n)
   d_extTheory.registerTermRec(n);
 }
 
-bool NonlinearExtension::getCurrentSubstitution(
-    int effort,
-    const std::vector<Node>& vars,
-    std::vector<Node>& subs,
-    std::map<Node, std::vector<Node>>& exp)
-{
-  // get the constant equivalence classes
-  std::map<Node, std::vector<int>> rep_to_subs_index;
-
-  bool retVal = false;
-  for (unsigned i = 0; i < vars.size(); i++)
-  {
-    Node n = vars[i];
-    if (d_ee->hasTerm(n))
-    {
-      Node nr = d_ee->getRepresentative(n);
-      if (nr.isConst())
-      {
-        subs.push_back(nr);
-        Trace("nl-subs") << "Basic substitution : " << n << " -> " << nr
-                         << std::endl;
-        exp[n].push_back(n.eqNode(nr));
-        retVal = true;
-      }
-      else
-      {
-        rep_to_subs_index[nr].push_back(i);
-        subs.push_back(n);
-      }
-    }
-    else
-    {
-      subs.push_back(n);
-    }
-  }
-
-  // return true if the substitution is non-trivial
-  return retVal;
-}
-
-std::pair<bool, Node> NonlinearExtension::isExtfReduced(
-    int effort, Node n, Node on, const std::vector<Node>& exp) const
-{
-  if (n != d_zero)
-  {
-    Kind k = n.getKind();
-    return std::make_pair(
-        k != NONLINEAR_MULT && !isTranscendentalKind(k) && k != IAND,
-        Node::null());
-  }
-  Assert(n == d_zero);
-  if (on.getKind() == NONLINEAR_MULT)
-  {
-    Trace("nl-ext-zero-exp")
-        << "Infer zero : " << on << " == " << n << std::endl;
-    // minimize explanation if a substitution+rewrite results in zero
-    const std::set<Node> vars(on.begin(), on.end());
-
-    for (unsigned i = 0, size = exp.size(); i < size; i++)
-    {
-      Trace("nl-ext-zero-exp")
-          << "  exp[" << i << "] = " << exp[i] << std::endl;
-      std::vector<Node> eqs;
-      if (exp[i].getKind() == EQUAL)
-      {
-        eqs.push_back(exp[i]);
-      }
-      else if (exp[i].getKind() == AND)
-      {
-        for (const Node& ec : exp[i])
-        {
-          if (ec.getKind() == EQUAL)
-          {
-            eqs.push_back(ec);
-          }
-        }
-      }
-
-      for (unsigned j = 0; j < eqs.size(); j++)
-      {
-        for (unsigned r = 0; r < 2; r++)
-        {
-          if (eqs[j][r] == d_zero && vars.find(eqs[j][1 - r]) != vars.end())
-          {
-            Trace("nl-ext-zero-exp")
-                << "...single exp : " << eqs[j] << std::endl;
-            return std::make_pair(true, eqs[j]);
-          }
-        }
-      }
-    }
-  }
-  return std::make_pair(true, Node::null());
-}
-
 void NonlinearExtension::sendLemmas(const std::vector<NlLemma>& out)
 {
   for (const NlLemma& nlem : out)
index d035b1056eb200859d5515f4a5725d5bbf30fbbd..41f24e76900aa81877782a0cefd682962f405351 100644 (file)
@@ -27,6 +27,7 @@
 #include "expr/kind.h"
 #include "expr/node.h"
 #include "theory/arith/nl/cad_solver.h"
+#include "theory/arith/nl/ext_theory_callback.h"
 #include "theory/arith/nl/iand_solver.h"
 #include "theory/arith/nl/nl_lemma_utils.h"
 #include "theory/arith/nl/nl_model.h"
@@ -77,48 +78,6 @@ class NonlinearExtension
    * Does non-context dependent setup for a node connected to a theory.
    */
   void preRegisterTerm(TNode n);
-  /** Get current substitution
-   *
-   * This function and the one below are
-   * used for context-dependent
-   * simplification, see Section 3.1 of
-   * "Designing Theory Solvers with Extensions"
-   * by Reynolds et al. FroCoS 2017.
-   *
-   * effort : an identifier indicating the stage where
-   *          we are performing context-dependent simplification,
-   * vars : a set of arithmetic variables.
-   *
-   * This function populates subs and exp, such that for 0 <= i < vars.size():
-   *   ( exp[vars[i]] ) => vars[i] = subs[i]
-   * where exp[vars[i]] is a set of assertions
-   * that hold in the current context. We call { vars -> subs } a "derivable
-   * substituion" (see Reynolds et al. FroCoS 2017).
-   */
-  bool getCurrentSubstitution(int effort,
-                              const std::vector<Node>& vars,
-                              std::vector<Node>& subs,
-                              std::map<Node, std::vector<Node>>& exp);
-  /** Is the term n in reduced form?
-   *
-   * Used for context-dependent simplification.
-   *
-   * effort : an identifier indicating the stage where
-   *          we are performing context-dependent simplification,
-   * on : the original term that we reduced to n,
-   * exp : an explanation such that ( exp => on = n ).
-   *
-   * We return a pair ( b, exp' ) such that
-   *   if b is true, then:
-   *     n is in reduced form
-   *     if exp' is non-null, then ( exp' => on = n )
-   * The second part of the pair is used for constructing
-   * minimal explanations for context-dependent simplifications.
-   */
-  std::pair<bool, Node> isExtfReduced(int effort,
-                                      Node n,
-                                      Node on,
-                                      const std::vector<Node>& exp) const;
   /** Check at effort level e.
    *
    * This call may result in (possibly multiple) calls to d_out->lemma(...)
@@ -300,6 +259,8 @@ class NonlinearExtension
    * (modelBasedRefinement). This counter is used for interleaving strategies.
    */
   unsigned d_checkCounter;
+  /** The callback for the extended theory below */
+  NlExtTheoryCallback d_extTheoryCb;
   /** Extended theory, responsible for context-dependent simplification. */
   ExtTheory d_extTheory;
   /** The non-linear model object
index 762634ce71549132e8cff0a74dd221ad9d22bff9..fbf25705c7d4a468067f56224d914ce86a0d9d1d 100644 (file)
@@ -117,14 +117,6 @@ TrustNode TheoryArith::explain(TNode n)
   return TrustNode::mkTrustPropExp(n, exp, nullptr);
 }
 
-bool TheoryArith::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) {
-  return d_internal->getCurrentSubstitution( effort, vars, subs, exp );
-}
-
-bool TheoryArith::isExtfReduced( int effort, Node n, Node on, std::vector< Node >& exp ) {
-  return d_internal->isExtfReduced( effort, n, on, exp );
-}
-
 void TheoryArith::propagate(Effort e) {
   d_internal->propagate(e);
 }
index 6adf8f66abb4e86204233359c9217a0b4d3419f5..71a25ac123b44e42e57460ee754177547527dcb9 100644 (file)
@@ -75,14 +75,6 @@ class TheoryArith : public Theory {
   bool needsCheckLastEffort() override;
   void propagate(Effort e) override;
   TrustNode explain(TNode n) override;
-  bool getCurrentSubstitution(int effort,
-                              std::vector<Node>& vars,
-                              std::vector<Node>& subs,
-                              std::map<Node, std::vector<Node> >& exp) override;
-  bool isExtfReduced(int effort,
-                     Node n,
-                     Node on,
-                     std::vector<Node>& exp) override;
 
   bool collectModelInfo(TheoryModel* m) override;
 
index 8a780116cab6223e4b13af3825c59d64f439db05..1b49b73508efa2e4d7c4a661aa157cbcb20ee8a8 100644 (file)
@@ -3877,31 +3877,6 @@ Node TheoryArithPrivate::explain(TNode n)
   }
 }
 
-bool TheoryArithPrivate::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) {
-  if (d_nonlinearExtension != nullptr)
-  {
-    return d_nonlinearExtension->getCurrentSubstitution( effort, vars, subs, exp );
-  }else{
-    return false;
-  }
-}
-
-bool TheoryArithPrivate::isExtfReduced(int effort, Node n, Node on,
-                                       std::vector<Node>& exp) {
-  if (d_nonlinearExtension != nullptr)
-  {
-    std::pair<bool, Node> reduced =
-        d_nonlinearExtension->isExtfReduced(effort, n, on, exp);
-    if (!reduced.second.isNull()) {
-      exp.clear();
-      exp.push_back(reduced.second);
-    }
-    return reduced.first;
-  } else {
-    return false;  // d_containing.isExtfReduced( effort, n, on );
-  }
-}
-
 void TheoryArithPrivate::propagate(Theory::Effort e) {
   // This uses model values for safety. Disable for now.
   if (d_qflraStatus == Result::SAT
index d96b5e2d3de8496327d8a1a14756b3b1c2d47225..d0428f2ef6447587ad66b1f684517f89eba96d4b 100644 (file)
@@ -452,8 +452,6 @@ public:
   bool needsCheckLastEffort();
   void propagate(Theory::Effort e);
   Node explain(TNode n);
-  bool getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp );
-  bool isExtfReduced( int effort, Node n, Node on, std::vector< Node >& exp );
 
   Rational deltaValueForTotalOrder() const;
 
index 38c5cb4820363b3d2df03b5fa72b80a702feabac..b341b067194c215de65361cb168e6b21dba87c98 100644 (file)
@@ -31,7 +31,65 @@ using namespace CVC4::theory;
 using namespace CVC4::theory::bv;
 using namespace CVC4::theory::bv::utils;
 
-CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt)
+bool CoreSolverExtTheoryCallback::getCurrentSubstitution(
+    int effort,
+    const std::vector<Node>& vars,
+    std::vector<Node>& subs,
+    std::map<Node, std::vector<Node> >& exp)
+{
+  if (d_equalityEngine == nullptr)
+  {
+    return false;
+  }
+  // get the constant equivalence classes
+  bool retVal = false;
+  for (const Node& n : vars)
+  {
+    if (d_equalityEngine->hasTerm(n))
+    {
+      Node nr = d_equalityEngine->getRepresentative(n);
+      if (nr.isConst())
+      {
+        subs.push_back(nr);
+        exp[n].push_back(n.eqNode(nr));
+        retVal = true;
+      }
+      else
+      {
+        subs.push_back(n);
+      }
+    }
+    else
+    {
+      subs.push_back(n);
+    }
+  }
+  // return true if the substitution is non-trivial
+  return retVal;
+}
+
+bool CoreSolverExtTheoryCallback::getReduction(int effort,
+                                               Node n,
+                                               Node& nr,
+                                               bool& satDep)
+{
+  Trace("bv-ext") << "TheoryBV::checkExt : non-reduced : " << n << std::endl;
+  if (n.getKind() == kind::BITVECTOR_TO_NAT)
+  {
+    nr = utils::eliminateBv2Nat(n);
+    satDep = false;
+    return true;
+  }
+  else if (n.getKind() == kind::INT_TO_BITVECTOR)
+  {
+    nr = utils::eliminateInt2Bv(n);
+    satDep = false;
+    return true;
+  }
+  return false;
+}
+
+CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv)
     : SubtheorySolver(c, bv),
       d_notify(*this),
       d_isComplete(c, true),
@@ -39,9 +97,18 @@ CoreSolver::CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt)
       d_preregisterCalled(false),
       d_checkCalled(false),
       d_bv(bv),
-      d_extTheory(extt),
-      d_reasons(c)
+      d_extTheoryCb(),
+      d_extTheory(new ExtTheory(d_extTheoryCb,
+                                bv->getSatContext(),
+                                bv->getUserContext(),
+                                bv->getOutputChannel())),
+      d_reasons(c),
+      d_needsLastCallCheck(false),
+      d_extf_range_infer(bv->getUserContext()),
+      d_extf_collapse_infer(bv->getUserContext())
 {
+  d_extTheory->addFunctionKind(kind::BITVECTOR_TO_NAT);
+  d_extTheory->addFunctionKind(kind::INT_TO_BITVECTOR);
 }
 
 CoreSolver::~CoreSolver() {}
@@ -431,3 +498,141 @@ CoreSolver::Statistics::Statistics()
 CoreSolver::Statistics::~Statistics() {
   smtStatisticsRegistry()->unregisterStat(&d_numCallstoCheck);
 }
+
+void CoreSolver::checkExtf(Theory::Effort e)
+{
+  if (e == Theory::EFFORT_LAST_CALL)
+  {
+    std::vector<Node> nred = d_extTheory->getActive();
+    doExtfReductions(nred);
+  }
+  Assert(e == Theory::EFFORT_FULL);
+  // do inferences (adds external lemmas)  TODO: this can be improved to add
+  // internal inferences
+  std::vector<Node> nred;
+  if (d_extTheory->doInferences(0, nred))
+  {
+    return;
+  }
+  d_needsLastCallCheck = false;
+  if (!nred.empty())
+  {
+    // other inferences involving bv2nat, int2bv
+    if (options::bvAlgExtf())
+    {
+      if (doExtfInferences(nred))
+      {
+        return;
+      }
+    }
+    if (!options::bvLazyReduceExtf())
+    {
+      if (doExtfReductions(nred))
+      {
+        return;
+      }
+    }
+    else
+    {
+      d_needsLastCallCheck = true;
+    }
+  }
+}
+
+bool CoreSolver::needsCheckLastEffort() const { return d_needsLastCallCheck; }
+
+bool CoreSolver::doExtfInferences(std::vector<Node>& terms)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  bool sentLemma = false;
+  eq::EqualityEngine* ee = d_equalityEngine;
+  std::map<Node, Node> op_map;
+  for (unsigned j = 0; j < terms.size(); j++)
+  {
+    TNode n = terms[j];
+    Assert(n.getKind() == kind::BITVECTOR_TO_NAT
+           || n.getKind() == kind::INT_TO_BITVECTOR);
+    if (n.getKind() == kind::BITVECTOR_TO_NAT)
+    {
+      // range lemmas
+      if (d_extf_range_infer.find(n) == d_extf_range_infer.end())
+      {
+        d_extf_range_infer.insert(n);
+        unsigned bvs = n[0].getType().getBitVectorSize();
+        Node min = nm->mkConst(Rational(0));
+        Node max = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs)));
+        Node lem = nm->mkNode(kind::AND,
+                              nm->mkNode(kind::GEQ, n, min),
+                              nm->mkNode(kind::LT, n, max));
+        Trace("bv-extf-lemma")
+            << "BV extf lemma (range) : " << lem << std::endl;
+        d_bv->getOutputChannel().lemma(lem);
+        sentLemma = true;
+      }
+    }
+    Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n[0]) : n[0];
+    op_map[r] = n;
+  }
+  for (unsigned j = 0; j < terms.size(); j++)
+  {
+    TNode n = terms[j];
+    Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n) : n;
+    std::map<Node, Node>::iterator it = op_map.find(r);
+    if (it != op_map.end())
+    {
+      Node parent = it->second;
+      // Node cterm = parent[0]==n ? parent : nm->mkNode( parent.getOperator(),
+      // n );
+      Node cterm = parent[0].eqNode(n);
+      Trace("bv-extf-lemma-debug")
+          << "BV extf collapse based on : " << cterm << std::endl;
+      if (d_extf_collapse_infer.find(cterm) == d_extf_collapse_infer.end())
+      {
+        d_extf_collapse_infer.insert(cterm);
+
+        Node t = n[0];
+        if (t.getType() == parent.getType())
+        {
+          if (n.getKind() == kind::INT_TO_BITVECTOR)
+          {
+            Assert(t.getType().isInteger());
+            // congruent modulo 2^( bv width )
+            unsigned bvs = n.getType().getBitVectorSize();
+            Node coeff = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs)));
+            Node k = nm->mkSkolem(
+                "int_bv_cong", t.getType(), "for int2bv/bv2nat congruence");
+            t = nm->mkNode(kind::PLUS, t, nm->mkNode(kind::MULT, coeff, k));
+          }
+          Node lem = parent.eqNode(t);
+
+          if (parent[0] != n)
+          {
+            Assert(ee->areEqual(parent[0], n));
+            lem = nm->mkNode(kind::IMPLIES, parent[0].eqNode(n), lem);
+          }
+          // this handles inferences of the form, e.g.:
+          //   ((_ int2bv w) (bv2nat x)) == x (if x is bit-width w)
+          //   (bv2nat ((_ int2bv w) x)) == x + k*2^w for some k
+          Trace("bv-extf-lemma")
+              << "BV extf lemma (collapse) : " << lem << std::endl;
+          d_bv->getOutputChannel().lemma(lem);
+          sentLemma = true;
+        }
+      }
+      Trace("bv-extf-lemma-debug")
+          << "BV extf f collapse based on : " << cterm << std::endl;
+    }
+  }
+  return sentLemma;
+}
+
+bool CoreSolver::doExtfReductions(std::vector<Node>& terms)
+{
+  std::vector<Node> nredr;
+  if (d_extTheory->doReductions(0, terms, nredr))
+  {
+    return true;
+  }
+  Assert(nredr.empty());
+  return false;
+}
index 3818046812dd1c861a7eb3f828a0378ad084d808..32bc36164c0ea4ff477392a4c1eeac81ca6b5538 100644 (file)
@@ -31,6 +31,23 @@ namespace theory {
 namespace bv {
 
 class Base;
+
+/** An extended theory callback used by the core solver */
+class CoreSolverExtTheoryCallback : public ExtTheoryCallback
+{
+ public:
+  CoreSolverExtTheoryCallback() : d_equalityEngine(nullptr) {}
+  /** Get current substitution based on the underlying equality engine. */
+  bool getCurrentSubstitution(int effort,
+                              const std::vector<Node>& vars,
+                              std::vector<Node>& subs,
+                              std::map<Node, std::vector<Node> >& exp) override;
+  /** Get reduction. */
+  bool getReduction(int effort, Node n, Node& nr, bool& satDep) override;
+  /** The underlying equality engine */
+  eq::EqualityEngine* d_equalityEngine;
+};
+
 /**
  * Bitvector equality solver
  */
@@ -83,8 +100,10 @@ class CoreSolver : public SubtheorySolver {
   TheoryBV* d_bv;
   /** Pointer to the equality engine of the parent */
   eq::EqualityEngine* d_equalityEngine;
-  /** Pointer to the extended theory module. */
-  ExtTheory* d_extTheory;
+  /** The extended theory callback */
+  CoreSolverExtTheoryCallback d_extTheoryCb;
+  /** Extended theory module, for context-dependent simplification. */
+  std::unique_ptr<ExtTheory> d_extTheory;
 
   /** To make sure we keep the explanations */
   context::CDHashSet<Node, NodeHashFunction> d_reasons;
@@ -96,8 +115,38 @@ class CoreSolver : public SubtheorySolver {
   bool isCompleteForTerm(TNode term, TNodeBoolMap& seen);
   Statistics d_statistics;
 
+  /** Whether we need a last call effort check */
+  bool d_needsLastCallCheck;
+  /** For extended functions */
+  context::CDHashSet<Node, NodeHashFunction> d_extf_range_infer;
+  context::CDHashSet<Node, NodeHashFunction> d_extf_collapse_infer;
+
+  /** do extended function inferences
+   *
+   * This method adds lemmas on the output channel of TheoryBV based on
+   * reasoning about extended functions, such as bv2nat and int2bv. Examples
+   * of lemmas added by this method include:
+   *   0 <= ((_ int2bv w) x) < 2^w
+   *   ((_ int2bv w) (bv2nat x)) = x
+   *   (bv2nat ((_ int2bv w) x)) == x + k*2^w
+   * The purpose of these lemmas is to recognize easy conflicts before fully
+   * reducing extended functions based on their full semantics.
+   */
+  bool doExtfInferences(std::vector<Node>& terms);
+  /** do extended function reductions
+   *
+   * This method adds lemmas on the output channel of TheoryBV based on
+   * reducing all extended function applications that are preregistered to
+   * this theory and have not already been reduced by context-dependent
+   * simplification (see theory/ext_theory.h). Examples of lemmas added by
+   * this method include:
+   *   (bv2nat x) = (ite ((_ extract w w-1) x) 2^{w-1} 0) + ... +
+   *                (ite ((_ extract 1 0) x) 1 0)
+   */
+  bool doExtfReductions(std::vector<Node>& terms);
+
  public:
-  CoreSolver(context::Context* c, TheoryBV* bv, ExtTheory* extt);
+  CoreSolver(context::Context* c, TheoryBV* bv);
   ~CoreSolver();
   bool needsEqualityEngine(EeSetupInfo& esi);
   void finishInit();
@@ -111,9 +160,11 @@ class CoreSolver : public SubtheorySolver {
   EqualityStatus getEqualityStatus(TNode a, TNode b) override;
   bool hasTerm(TNode node) const;
   void addTermToEqualityEngine(TNode node);
+  /** check extended functions at the given effort */
+  void checkExtf(Theory::Effort e);
+  bool needsCheckLastEffort() const;
 };
 
-
 }
 }
 }
index d6492f177c083586aa573ec884f2a6e7af314df5..815656d8fd3ed688c9c89e579eb2e0ddf739c3d6 100644 (file)
@@ -29,7 +29,6 @@
 #include "theory/bv/theory_bv_rewrite_rules_simplification.h"
 #include "theory/bv/theory_bv_rewriter.h"
 #include "theory/bv/theory_bv_utils.h"
-#include "theory/ext_theory.h"
 #include "theory/theory_model.h"
 #include "theory/valuation.h"
 
@@ -63,18 +62,12 @@ TheoryBV::TheoryBV(context::Context* c,
       d_invalidateModelCache(c, true),
       d_literalsToPropagate(c),
       d_literalsToPropagateIndex(c, 0),
-      d_extTheory(new ExtTheory(this)),
       d_propagatedBy(c),
       d_eagerSolver(),
       d_abstractionModule(new AbstractionModule(getStatsPrefix(THEORY_BV))),
       d_calledPreregister(false),
-      d_needsLastCallCheck(false),
-      d_extf_range_infer(u),
-      d_extf_collapse_infer(u),
       d_state(c, u, valuation)
 {
-  d_extTheory->addFunctionKind(kind::BITVECTOR_TO_NAT);
-  d_extTheory->addFunctionKind(kind::INT_TO_BITVECTOR);
   if (options::bitblastMode() == options::BitblastMode::EAGER)
   {
     d_eagerSolver.reset(new EagerBitblastSolver(c, this));
@@ -83,7 +76,7 @@ TheoryBV::TheoryBV(context::Context* c,
 
   if (options::bitvectorEqualitySolver())
   {
-    d_subtheories.emplace_back(new CoreSolver(c, this, d_extTheory.get()));
+    d_subtheories.emplace_back(new CoreSolver(c, this));
     d_subtheoryMap[SUB_CORE] = d_subtheories.back().get();
   }
 
@@ -331,8 +324,12 @@ void TheoryBV::check(Effort e)
 
   //last call : do reductions on extended bitvector functions
   if (e == Theory::EFFORT_LAST_CALL) {
-    std::vector<Node> nred = d_extTheory->getActive();
-    doExtfReductions(nred);
+    CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE];
+    if (core)
+    {
+      // check extended functions at last call effort
+      core->checkExtf(e);
+    }
     return;
   }
 
@@ -414,131 +411,24 @@ void TheoryBV::check(Effort e)
 
   //check extended functions
   if (Theory::fullEffort(e)) {
-    //do inferences (adds external lemmas)  TODO: this can be improved to add internal inferences
-    std::vector< Node > nred;
-    if (d_extTheory->doInferences(0, nred))
+    CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE];
+    if (core)
     {
-      return;
-    }
-    d_needsLastCallCheck = false;
-    if( !nred.empty() ){
-      //other inferences involving bv2nat, int2bv
-      if( options::bvAlgExtf() ){
-        if( doExtfInferences( nred ) ){
-          return;
-        }
-      }
-      if( !options::bvLazyReduceExtf() ){
-        if( doExtfReductions( nred ) ){
-          return;
-        }
-      }
-      else
-      {
-        d_needsLastCallCheck = true;
-      }
+      // check extended functions at full effort
+      core->checkExtf(e);
     }
   }
 }
 
-bool TheoryBV::doExtfInferences(std::vector<Node>& terms)
+bool TheoryBV::needsCheckLastEffort()
 {
-  NodeManager* nm = NodeManager::currentNM();
-  bool sentLemma = false;
-  eq::EqualityEngine* ee = getEqualityEngine();
-  std::map<Node, Node> op_map;
-  for (unsigned j = 0; j < terms.size(); j++)
-  {
-    TNode n = terms[j];
-    Assert(n.getKind() == kind::BITVECTOR_TO_NAT
-           || n.getKind() == kind::INT_TO_BITVECTOR);
-    if (n.getKind() == kind::BITVECTOR_TO_NAT)
-    {
-      // range lemmas
-      if (d_extf_range_infer.find(n) == d_extf_range_infer.end())
-      {
-        d_extf_range_infer.insert(n);
-        unsigned bvs = n[0].getType().getBitVectorSize();
-        Node min = nm->mkConst(Rational(0));
-        Node max = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs)));
-        Node lem = nm->mkNode(kind::AND,
-                              nm->mkNode(kind::GEQ, n, min),
-                              nm->mkNode(kind::LT, n, max));
-        Trace("bv-extf-lemma")
-            << "BV extf lemma (range) : " << lem << std::endl;
-        d_out->lemma(lem);
-        sentLemma = true;
-      }
-    }
-    Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n[0]) : n[0];
-    op_map[r] = n;
-  }
-  for (unsigned j = 0; j < terms.size(); j++)
-  {
-    TNode n = terms[j];
-    Node r = (ee && ee->hasTerm(n[0])) ? ee->getRepresentative(n) : n;
-    std::map<Node, Node>::iterator it = op_map.find(r);
-    if (it != op_map.end())
-    {
-      Node parent = it->second;
-      // Node cterm = parent[0]==n ? parent : nm->mkNode( parent.getOperator(),
-      // n );
-      Node cterm = parent[0].eqNode(n);
-      Trace("bv-extf-lemma-debug")
-          << "BV extf collapse based on : " << cterm << std::endl;
-      if (d_extf_collapse_infer.find(cterm) == d_extf_collapse_infer.end())
-      {
-        d_extf_collapse_infer.insert(cterm);
-
-        Node t = n[0];
-        if (t.getType() == parent.getType())
-        {
-          if (n.getKind() == kind::INT_TO_BITVECTOR)
-          {
-            Assert(t.getType().isInteger());
-            // congruent modulo 2^( bv width )
-            unsigned bvs = n.getType().getBitVectorSize();
-            Node coeff = nm->mkConst(Rational(Integer(1).multiplyByPow2(bvs)));
-            Node k = nm->mkSkolem(
-                "int_bv_cong", t.getType(), "for int2bv/bv2nat congruence");
-            t = nm->mkNode(kind::PLUS, t, nm->mkNode(kind::MULT, coeff, k));
-          }
-          Node lem = parent.eqNode(t);
-
-          if (parent[0] != n)
-          {
-            Assert(ee->areEqual(parent[0], n));
-            lem = nm->mkNode(kind::IMPLIES, parent[0].eqNode(n), lem);
-          }
-          // this handles inferences of the form, e.g.:
-          //   ((_ int2bv w) (bv2nat x)) == x (if x is bit-width w)
-          //   (bv2nat ((_ int2bv w) x)) == x + k*2^w for some k
-          Trace("bv-extf-lemma")
-              << "BV extf lemma (collapse) : " << lem << std::endl;
-          d_out->lemma(lem);
-          sentLemma = true;
-        }
-      }
-      Trace("bv-extf-lemma-debug")
-          << "BV extf f collapse based on : " << cterm << std::endl;
-    }
-  }
-  return sentLemma;
-}
-
-bool TheoryBV::doExtfReductions( std::vector< Node >& terms ) {
-  std::vector< Node > nredr;
-  if (d_extTheory->doReductions(0, terms, nredr))
+  CoreSolver* core = (CoreSolver*)d_subtheoryMap[SUB_CORE];
+  if (core)
   {
-    return true;
+    return core->needsCheckLastEffort();
   }
-  Assert(nredr.empty());
   return false;
 }
-
-bool TheoryBV::needsCheckLastEffort() {
-  return d_needsLastCallCheck;
-}
 bool TheoryBV::collectModelInfo(TheoryModel* m)
 {
   Assert(!inConflict());
@@ -595,48 +485,6 @@ void TheoryBV::propagate(Effort e) {
   }
 }
 
-bool TheoryBV::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) {
-  eq::EqualityEngine * ee = getEqualityEngine();
-  if( ee ){
-    //get the constant equivalence classes
-    bool retVal = false;
-    for( unsigned i=0; i<vars.size(); i++ ){
-      Node n = vars[i];
-      if( ee->hasTerm( n ) ){
-        Node nr = ee->getRepresentative( n );
-        if( nr.isConst() ){
-          subs.push_back( nr );
-          exp[n].push_back( n.eqNode( nr ) );
-          retVal = true;
-        }else{
-          subs.push_back( n );
-        }
-      }else{
-        subs.push_back( n );
-      }
-    }
-    //return true if the substitution is non-trivial
-    return retVal;
-  }
-  return false;
-}
-
-int TheoryBV::getReduction(int effort, Node n, Node& nr)
-{
-  Trace("bv-ext") << "TheoryBV::checkExt : non-reduced : " << n << std::endl;
-  if (n.getKind() == kind::BITVECTOR_TO_NAT)
-  {
-    nr = utils::eliminateBv2Nat(n);
-    return -1;
-  }
-  else if (n.getKind() == kind::INT_TO_BITVECTOR)
-  {
-    nr = utils::eliminateInt2Bv(n);
-    return -1;
-  }
-  return 0;
-}
-
 Theory::PPAssertStatus TheoryBV::ppAssert(TNode in,
                                           SubstitutionMap& outSubstitutions)
 {
index 2f63f1a5229e26bd6f7dca760f11d5e56260f4ea..7475fecccb261aff6deb288c18a95f15d487583e 100644 (file)
 #include "util/statistics_registry.h"
 
 namespace CVC4 {
-
 namespace theory {
-
-class ExtTheory;
-
 namespace bv {
 
 class CoreSolver;
@@ -101,12 +97,6 @@ class TheoryBV : public Theory {
 
   std::string identify() const override { return std::string("TheoryBV"); }
 
-  bool getCurrentSubstitution(int effort,
-                              std::vector<Node>& vars,
-                              std::vector<Node>& subs,
-                              std::map<Node, std::vector<Node>>& exp) override;
-  int getReduction(int effort, Node n, Node& nr) override;
-
   PPAssertStatus ppAssert(TNode in, SubstitutionMap& outSubstitutions) override;
 
   TrustNode ppRewrite(TNode t) override;
@@ -177,9 +167,6 @@ class TheoryBV : public Theory {
   /** Index of the next literal to propagate */
   context::CDO<unsigned> d_literalsToPropagateIndex;
 
-  /** Extended theory module, for context-dependent simplification. */
-  std::unique_ptr<ExtTheory> d_extTheory;
-
   /**
    * Keeps a map from nodes to the subtheory that propagated it so that we can explain it
    * properly.
@@ -191,34 +178,6 @@ class TheoryBV : public Theory {
   std::unique_ptr<AbstractionModule> d_abstractionModule;
   bool d_calledPreregister;
 
-  //for extended functions
-  bool d_needsLastCallCheck;
-  context::CDHashSet<Node, NodeHashFunction> d_extf_range_infer;
-  context::CDHashSet<Node, NodeHashFunction> d_extf_collapse_infer;
-  /** do extended function inferences
-   *
-   * This method adds lemmas on the output channel of TheoryBV based on
-   * reasoning about extended functions, such as bv2nat and int2bv. Examples
-   * of lemmas added by this method include:
-   *   0 <= ((_ int2bv w) x) < 2^w
-   *   ((_ int2bv w) (bv2nat x)) = x
-   *   (bv2nat ((_ int2bv w) x)) == x + k*2^w
-   * The purpose of these lemmas is to recognize easy conflicts before fully
-   * reducing extended functions based on their full semantics.
-   */
-  bool doExtfInferences( std::vector< Node >& terms );
-  /** do extended function reductions
-   *
-   * This method adds lemmas on the output channel of TheoryBV based on
-   * reducing all extended function applications that are preregistered to
-   * this theory and have not already been reduced by context-dependent
-   * simplification (see theory/ext_theory.h). Examples of lemmas added by
-   * this method include:
-   *   (bv2nat x) = (ite ((_ extract w w-1) x) 2^{w-1} 0) + ... +
-   *                (ite ((_ extract 1 0) x) 1 0)
-   */
-  bool doExtfReductions( std::vector< Node >& terms );
-
   bool wasPropagatedBySubtheory(TNode literal) const {
     return d_propagatedBy.find(literal) != d_propagatedBy.end();
   }
index 5253414a9fc0533b24872c6c91a93507c2d7d3f5..585f13d822b581dfb7fde48e09031a4ae33af005 100644 (file)
@@ -1965,10 +1965,6 @@ TNode TheoryDatatypes::getRepresentative( TNode a ){
   }
 }
 
-bool TheoryDatatypes::getCurrentSubstitution( int effort, std::vector< Node >& vars, std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) {
-  return false;
-}
-
 void TheoryDatatypes::printModelDebug( const char* c ){
   if(! (Trace.isOn(c))) {
     return;
index 37a4f81f7963bad8399505aa0b7f49b1d97c32bf..bf5d33177d9c62202e846a7283b1df6865784cb3 100644 (file)
@@ -273,10 +273,6 @@ private:
   {
     return std::string("TheoryDatatypes");
   }
-  bool getCurrentSubstitution(int effort,
-                              std::vector<Node>& vars,
-                              std::vector<Node>& subs,
-                              std::map<Node, std::vector<Node> >& exp) override;
   /** debug print */
   void printModelDebug( const char* c );
   /** entailment check */
index bdcd5dcff8c954746e0a5d66295818b17ddfeba1..e8ed60ae4c01709a49c595fb08aa0feaffb77e38 100644 (file)
@@ -28,13 +28,41 @@ using namespace std;
 namespace CVC4 {
 namespace theory {
 
-ExtTheory::ExtTheory(Theory* p, bool cacheEnabled)
+bool ExtTheoryCallback::getCurrentSubstitution(
+    int effort,
+    const std::vector<Node>& vars,
+    std::vector<Node>& subs,
+    std::map<Node, std::vector<Node> >& exp)
+{
+  return false;
+}
+bool ExtTheoryCallback::isExtfReduced(int effort,
+                                      Node n,
+                                      Node on,
+                                      std::vector<Node>& exp)
+{
+  return n.isConst();
+}
+bool ExtTheoryCallback::getReduction(int effort,
+                                    Node n,
+                                    Node& nr,
+                                    bool& isSatDep)
+{
+  return false;
+}
+
+ExtTheory::ExtTheory(ExtTheoryCallback& p,
+                     context::Context* c,
+                     context::UserContext* u,
+                     OutputChannel& out,
+                     bool cacheEnabled)
     : d_parent(p),
-      d_ext_func_terms(p->getSatContext()),
-      d_ci_inactive(p->getUserContext()),
-      d_has_extf(p->getSatContext()),
-      d_lemmas(p->getUserContext()),
-      d_pp_lemmas(p->getUserContext()),
+      d_out(out),
+      d_ext_func_terms(c),
+      d_ci_inactive(u),
+      d_has_extf(c),
+      d_lemmas(u),
+      d_pp_lemmas(u),
       d_cacheEnabled(cacheEnabled)
 {
   d_true = NodeManager::currentNM()->mkConst(true);
@@ -61,7 +89,6 @@ std::vector<Node> ExtTheory::collectVars(Node n)
     // (commented below)
     if (current.getNumChildren() > 0)
     {
-      //&& Theory::theoryOf(n)==d_parent->getId() ){
       worklist.insert(worklist.end(), current.begin(), current.end());
     }
     else
@@ -140,7 +167,7 @@ void ExtTheory::getSubstitutedTerms(int effort,
           }
         }
       }
-      bool useSubs = d_parent->getCurrentSubstitution(effort, vars, sub, expc);
+      bool useSubs = d_parent.getCurrentSubstitution(effort, vars, sub, expc);
       // get the current substitution for all variables
       Assert(!useSubs || vars.size() == sub.size());
       for (const Node& n : terms)
@@ -206,8 +233,8 @@ bool ExtTheory::doInferencesInternal(int effort,
       {
         Node nr;
         // note: could do reduction with substitution here
-        int ret = d_parent->getReduction(effort, n, nr);
-        if (ret == 0)
+        bool satDep = false;
+        if (!d_parent.getReduction(effort, n, nr, satDep))
         {
           nred.push_back(n);
         }
@@ -223,7 +250,7 @@ bool ExtTheory::doInferencesInternal(int effort,
               addedLemma = true;
             }
           }
-          markReduced(n, ret < 0);
+          markReduced(n, satDep);
         }
       }
     }
@@ -242,7 +269,7 @@ bool ExtTheory::doInferencesInternal(int effort,
           Node sr = Rewriter::rewrite(sterms[i]);
           // ask the theory if this term is reduced, e.g. is it constant or it
           // is a non-extf term.
-          if (d_parent->isExtfReduced(effort, sr, terms[i], exp[i]))
+          if (d_parent.isExtfReduced(effort, sr, terms[i], exp[i]))
           {
             processed = true;
             markReduced(terms[i]);
@@ -344,7 +371,7 @@ bool ExtTheory::sendLemma(Node lem, bool preprocess)
     if (d_pp_lemmas.find(lem) == d_pp_lemmas.end())
     {
       d_pp_lemmas.insert(lem);
-      d_parent->getOutputChannel().lemma(lem, LemmaProperty::PREPROCESS);
+      d_out.lemma(lem, LemmaProperty::PREPROCESS);
       return true;
     }
   }
@@ -353,7 +380,7 @@ bool ExtTheory::sendLemma(Node lem, bool preprocess)
     if (d_lemmas.find(lem) == d_lemmas.end())
     {
       d_lemmas.insert(lem);
-      d_parent->getOutputChannel().lemma(lem);
+      d_out.lemma(lem);
       return true;
     }
   }
@@ -403,8 +430,7 @@ void ExtTheory::registerTerm(Node n)
   {
     if (d_ext_func_terms.find(n) == d_ext_func_terms.end())
     {
-      Trace("extt-debug") << "Found extended function : " << n << " in "
-                          << d_parent->getId() << std::endl;
+      Trace("extt-debug") << "Found extended function : " << n << std::endl;
       d_ext_func_terms[n] = true;
       d_has_extf = n;
       d_extf_info[n].d_vars = collectVars(n);
@@ -435,13 +461,13 @@ void ExtTheory::registerTermRec(Node n)
 }
 
 // mark reduced
-void ExtTheory::markReduced(Node n, bool contextDepend)
+void ExtTheory::markReduced(Node n, bool satDep)
 {
   Trace("extt-debug") << "Mark reduced " << n << std::endl;
   registerTerm(n);
   Assert(d_ext_func_terms.find(n) != d_ext_func_terms.end());
   d_ext_func_terms[n] = false;
-  if (!contextDepend)
+  if (!satDep)
   {
     d_ci_inactive.insert(n);
   }
index 2721bc89ebf6dc186be3d0a841565432dbe031e0..efd24e2c852b79f45eb9aa1fecc79b6599244cb1 100644 (file)
 namespace CVC4 {
 namespace theory {
 
+/**
+ * A callback class for ExtTheory below. This class is responsible for
+ * determining how to apply context-dependent simplification.
+ */
+class ExtTheoryCallback
+{
+ public:
+  virtual ~ExtTheoryCallback() {}
+  /*
+   * Get current substitution at an effort
+   * @param effort The effort identifier
+   * @param vars The variables to get a substitution for
+   * @param subs The terms to substitute for variables, in order. This vector
+   * should be updated to one the same size as vars.
+   * @param exp The map containing the explanation for each variable. Together
+   * with subs, we have that:
+   *   ( exp[vars[i]] => vars[i] = subs[i] ) holds for all i
+   * @return true if any (non-identity) substitution was added to subs.
+   */
+  virtual bool getCurrentSubstitution(int effort,
+                                      const std::vector<Node>& vars,
+                                      std::vector<Node>& subs,
+                                      std::map<Node, std::vector<Node> >& exp);
+
+  /*
+   * Is extended function n reduced? This returns true if n is reduced to a
+   * form that requires no further interaction from the theory.
+   *
+   * @param effort The effort identifier
+   * @param n The term to reduce
+   * @param on The original form of n, before substitution
+   * @param exp The explanation of on = n
+   * @return true if n is reduced.
+   */
+  virtual bool isExtfReduced(int effort,
+                             Node n,
+                             Node on,
+                             std::vector<Node>& exp);
+
+  /**
+   * Get reduction for node n.
+   * If return value is true, then n is reduced.
+   * If satDep is updated to false, then n is reduced independent of the
+   * SAT context (e.g. by a lemma that persists at this
+   * user-context level).
+   * If nr is non-null, then ( n = nr ) should be added as a lemma by caller,
+   * and return value of this method should be true.
+   */
+  virtual bool getReduction(int effort, Node n, Node& nr, bool& satDep);
+};
+
 /** Extended theory class
  *
  * This class is used for constructing generic extensions to theory solvers.
@@ -73,7 +124,11 @@ class ExtTheory
    *
    * If cacheEnabled is false, we do not cache results of getSubstitutedTerm.
    */
-  ExtTheory(Theory* p, bool cacheEnabled = false);
+  ExtTheory(ExtTheoryCallback& p,
+            context::Context* c,
+            context::UserContext* u,
+            OutputChannel& out,
+            bool cacheEnabled = false);
   virtual ~ExtTheory() {}
   /** Tells this class to treat terms with Kind k as extended functions */
   void addFunctionKind(Kind k) { d_extf_kind[k] = true; }
@@ -93,10 +148,10 @@ class ExtTheory
   void registerTermRec(Node n);
   /** set n as reduced/inactive
    *
-   * If contextDepend = false, then n remains inactive in the duration of this
+   * If satDep = false, then n remains inactive in the duration of this
    * user-context level
    */
-  void markReduced(Node n, bool contextDepend = true);
+  void markReduced(Node n, bool satDep = true);
   /**
    * Mark that a and b are congruent terms. This sets b inactive, and sets a to
    * inactive if b was inactive.
@@ -194,10 +249,12 @@ class ExtTheory
                             std::vector<Node>& nred,
                             bool batch,
                             bool isRed);
-  /** send lemma on the output channel of d_parent */
+  /** send lemma on the output channel */
   bool sendLemma(Node lem, bool preprocess = false);
-  /** reference to the underlying theory */
-  Theory* d_parent;
+  /** reference to the callback */
+  ExtTheoryCallback& d_parent;
+  /** Reference to the output channel we are using */
+  OutputChannel& d_out;
   /** the true node */
   Node d_true;
   /** extended function terms, map to whether they are active */
index b028da38a62d48b19def1d0b53cb5e104c05d0fb..6fcd5785daa6cd6bd9f8910a049a5176e3437294 100644 (file)
@@ -700,6 +700,23 @@ std::vector<Node> ExtfSolver::getActive(Kind k) const
   return d_extt.getActive(k);
 }
 
+bool StringsExtfCallback::getCurrentSubstitution(
+    int effort,
+    const std::vector<Node>& vars,
+    std::vector<Node>& subs,
+    std::map<Node, std::vector<Node> >& exp)
+{
+  Trace("strings-subs") << "getCurrentSubstitution, effort = " << effort
+                        << std::endl;
+  for (const Node& v : vars)
+  {
+    Trace("strings-subs") << "  get subs for " << v << "..." << std::endl;
+    Node s = d_esolver->getCurrentSubstitutionFor(effort, v, exp[v]);
+    subs.push_back(s);
+  }
+  return true;
+}
+
 }  // namespace strings
 }  // namespace theory
 }  // namespace CVC4
index 4ba38bfc647ab9b36f05ac870751a4c4c2995d76..5b11b6faff624a8c0538384f456ceac2198ea76a 100644 (file)
@@ -214,6 +214,23 @@ class ExtfSolver
   NodeSet d_reduced;
 };
 
+/** An extended theory callback */
+class StringsExtfCallback : public ExtTheoryCallback
+{
+ public:
+  StringsExtfCallback() : d_esolver(nullptr) {}
+  /**
+   * Get current substitution based on the underlying extended function
+   * solver.
+   */
+  bool getCurrentSubstitution(int effort,
+                              const std::vector<Node>& vars,
+                              std::vector<Node>& subs,
+                              std::map<Node, std::vector<Node> >& exp) override;
+  /** The extended function solver */
+  ExtfSolver* d_esolver;
+};
+
 }  // namespace strings
 }  // namespace theory
 }  // namespace CVC4
index 3e60cbc445ce5b52698f91e11bac28cd56728830..f248cb330370b4006832eb2135e270f1ea4bd80b 100644 (file)
@@ -45,7 +45,8 @@ TheoryStrings::TheoryStrings(context::Context* c,
       d_statistics(),
       d_state(c, u, d_valuation),
       d_termReg(d_state, out, d_statistics, nullptr),
-      d_extTheory(this),
+      d_extTheoryCb(),
+      d_extTheory(d_extTheoryCb, c, u, out),
       d_im(*this, d_state, d_termReg, d_extTheory, d_statistics, pnm),
       d_rewriter(&d_statistics.d_rewrites),
       d_bsolver(d_state, d_im),
@@ -75,6 +76,9 @@ TheoryStrings::TheoryStrings(context::Context* c,
 
   d_cardSize = utils::getAlphabetCardinality();
 
+  // set up the extended function callback
+  d_extTheoryCb.d_esolver = &d_esolver;
+
   ProofChecker* pc = pnm != nullptr ? pnm->getChecker() : nullptr;
   if (pc != nullptr)
   {
@@ -202,18 +206,6 @@ TrustNode TheoryStrings::explain(TNode literal)
   return TrustNode::mkTrustPropExp(literal, ret, nullptr);
 }
 
-bool TheoryStrings::getCurrentSubstitution( int effort, std::vector< Node >& vars, 
-                                            std::vector< Node >& subs, std::map< Node, std::vector< Node > >& exp ) {
-  Trace("strings-subs") << "getCurrentSubstitution, effort = " << effort << std::endl;
-  for( unsigned i=0; i<vars.size(); i++ ){
-    Node n = vars[i];
-    Trace("strings-subs") << "  get subs for " << n << "..." << std::endl;
-    Node s = d_esolver.getCurrentSubstitutionFor(effort, n, exp[n]);
-    subs.push_back(s);
-  }
-  return true;
-}
-
 void TheoryStrings::presolve() {
   Debug("strings-presolve") << "TheoryStrings::Presolving : get fmf options " << (options::stringFMF() ? "true" : "false") << std::endl;
   d_strat.initializeStrategy();
index 0f59e73dcc94cc8882912560fd6e5a27e45f726e..cbe6000bfe4f628bc9b73b20ab2a7baa0ec09c06 100644 (file)
@@ -86,11 +86,6 @@ class TheoryStrings : public Theory {
   std::string identify() const override;
   /** Explain */
   TrustNode explain(TNode literal) override;
-  /** Get current substitution */
-  bool getCurrentSubstitution(int effort,
-                              std::vector<Node>& vars,
-                              std::vector<Node>& subs,
-                              std::map<Node, std::vector<Node> >& exp) override;
   /** presolve */
   void presolve() override;
   /** shutdown */
@@ -262,6 +257,8 @@ class TheoryStrings : public Theory {
   SolverState d_state;
   /** The term registry for this theory */
   TermRegistry d_termReg;
+  /** The extended theory callback */
+  StringsExtfCallback d_extTheoryCb;
   /** Extended theory, responsible for context-dependent simplification. */
   ExtTheory d_extTheory;
   /** The (custom) output channel of the theory of strings */
index 176d4b6726cfe0bf79c016a3a3f456f4d70d9f39..77652f874b882823270087a67d9ffefaaa9fdfb8 100644 (file)
@@ -897,30 +897,6 @@ class Theory {
    * E |= lit in the theory.
    */
   virtual std::pair<bool, Node> entailmentCheck(TNode lit);
-
-  /* get current substitution at an effort
-   *   input : vars
-   *   output : subs, exp
-   *   where ( exp[vars[i]] => vars[i] = subs[i] ) holds for all i
-   */
-  virtual bool getCurrentSubstitution(int effort, std::vector<Node>& vars,
-                                      std::vector<Node>& subs,
-                                      std::map<Node, std::vector<Node> >& exp) {
-    return false;
-  }
-
-  /* is extended function reduced */
-  virtual bool isExtfReduced( int effort, Node n, Node on, std::vector< Node >& exp ) { return n.isConst(); }
-
-  /**
-   * Get reduction for node
-   * If return value is not 0, then n is reduced.
-   * If return value <0 then n is reduced SAT-context-independently (e.g. by a
-   *  lemma that persists at this user-context level).
-   * If nr is non-null, then ( n = nr ) should be added as a lemma by caller,
-   *  and return value should be <0.
-   */
-  virtual int getReduction( int effort, Node n, Node& nr ) { return 0; }
 };/* class Theory */
 
 std::ostream& operator<<(std::ostream& os, theory::Theory::Effort level);