Migrate basic EqualityEngine management from CongruenceManager to EqSolver (#8684)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 3 May 2022 00:23:04 +0000 (19:23 -0500)
committerGitHub <noreply@github.com>
Tue, 3 May 2022 00:23:04 +0000 (00:23 +0000)
This is work towards having the linear arithmetic solver not impose restrictions on equalities.

The linear arithmetic solver using a CongruenceManager which involves many non-standard uses of the equality engine.

The responsibilities of the CongruenceManager should be migrated to the arithmetic EqSolver, which manages the equality engine in the default way.

This PR is the first step. It makes it so that the memory management and notifications of the equality engine are now solely the responsibility of the EqSolver.

All relevant notifications from the EqSolver are directly forwarded to CongruenceManager. Thus there are no significant behavior changes in this PR.

This PR required removing the experimental option arithCongMan, which forces having the CongruenceManager and the EqSolver both use equality engines.

src/options/arith_options.toml
src/smt/set_defaults.cpp
src/theory/arith/equality_solver.cpp
src/theory/arith/equality_solver.h
src/theory/arith/linear/congruence_manager.cpp
src/theory/arith/linear/congruence_manager.h
src/theory/arith/linear/theory_arith_private.cpp
src/theory/arith/linear/theory_arith_private.h
src/theory/arith/theory_arith.cpp

index 4461a69df2cddebe6baab8a48f1d33cbbf92de3a..7bce3bca535eabfe0135b94f75c6eff5fb3829a7 100644 (file)
@@ -590,11 +590,3 @@ name   = "Arithmetic Theory"
   type       = "bool"
   default    = "false"
   help       = "whether to use the equality solver in the theory of arithmetic"
-
-[[option]]
-  name       = "arithCongMan"
-  category   = "expert"
-  long       = "arith-cong-man"
-  type       = "bool"
-  default    = "true"
-  help       = "(experimental) whether to use the congruence manager when the equality solver is enabled"
index e284597081a23339feef9095ae24ace68ef0dacb..32fd86f2460652b3c8c7f2da17032c614b12288a 100644 (file)
@@ -686,15 +686,6 @@ void SetDefaults::setDefaultsPost(const LogicInfo& logic, Options& opts) const
       opts.arith.arithEqSolver = true;
     }
   }
-  if (opts.arith.arithEqSolver)
-  {
-    if (!opts.arith.arithCongManWasSetByUser)
-    {
-      // if we are using the arithmetic equality solver, do not use the
-      // arithmetic congruence manager by default
-      opts.arith.arithCongMan = false;
-    }
-  }
 
   if (logic.isHigherOrder())
   {
index 30abe534bb663445e4cbd26432675c0811e32383..aa3e31fac43d20f5c3c5fd01b36abbced723b40c 100644 (file)
@@ -16,6 +16,7 @@
 #include "theory/arith/equality_solver.h"
 
 #include "theory/arith/inference_manager.h"
+#include "theory/arith/linear/congruence_manager.h"
 
 using namespace cvc5::internal::kind;
 
@@ -31,7 +32,8 @@ EqualitySolver::EqualitySolver(Env& env,
       d_aim(aim),
       d_notify(*this),
       d_ee(nullptr),
-      d_propLits(context())
+      d_propLits(context()),
+      d_acm(nullptr)
 {
 }
 
@@ -81,8 +83,19 @@ TrustNode EqualitySolver::explain(TNode lit)
   // if we did, explain with the arithmetic inference manager
   return d_aim.explainLit(lit);
 }
+
+void EqualitySolver::setCongruenceManager(linear::ArithCongruenceManager* acm)
+{
+  d_acm = acm;
+}
+
 bool EqualitySolver::propagateLit(Node lit)
 {
+  if (d_acm != nullptr)
+  {
+    // if we are using the congruence manager, notify it
+    return d_acm->propagate(lit);
+  }
   // if we've already propagated, ignore
   if (d_aim.hasPropagated(lit))
   {
@@ -96,6 +109,13 @@ bool EqualitySolver::propagateLit(Node lit)
 }
 void EqualitySolver::conflictEqConstantMerge(TNode a, TNode b)
 {
+  if (d_acm != nullptr)
+  {
+    // if we are using the congruence manager, notify it
+    Node eq = a.eqNode(b);
+    d_acm->propagate(eq);
+    return;
+  }
   d_aim.conflictEqConstantMerge(a, b);
 }
 
index 5faf90dc997dd6d3375880df76304889d71c32fe..dcc6ccb54e7fcbf26fe34c3f97d5286c610a915f 100644 (file)
@@ -32,9 +32,15 @@ namespace arith {
 
 class InferenceManager;
 
+namespace linear {
+class ArithCongruenceManager;
+}
+
 /**
  * The arithmetic equality solver. This class manages arithmetic equalities
- * in the default way via an equality engine.
+ * in the default way via an equality engine, or defers to the congruence
+ * manager of linear arithmetic if setCongruenceManager is called on a
+ * non-null congruence manager.
  *
  * Since arithmetic has multiple ways of propagating literals, it tracks
  * the literals that it propagates and only explains the literals that
@@ -70,6 +76,9 @@ class EqualitySolver : protected EnvObj
    */
   TrustNode explain(TNode lit);
 
+  /** Set the congruence manager, which will be notified of propagations */
+  void setCongruenceManager(linear::ArithCongruenceManager* acm);
+
  private:
   /** Notification class from the equality engine */
   class EqualitySolverNotify : public eq::EqualityEngineNotify
@@ -107,6 +116,8 @@ class EqualitySolver : protected EnvObj
   eq::EqualityEngine* d_ee;
   /** The literals we have propagated */
   NodeSet d_propLits;
+  /** Pointer to the congruence manager, for notifications of propagations */
+  linear::ArithCongruenceManager* d_acm;
 };
 
 }  // namespace arith
index da4d81aa7ea8bee2b1ac3501cb16e1b1da17b342..017fef67548743fb66c4f754e4c3c77b5012d774 100644 (file)
@@ -47,7 +47,6 @@ ArithCongruenceManager::ArithCongruenceManager(
     : EnvObj(env),
       d_inConflict(context()),
       d_raiseConflict(raiseConflict),
-      d_notify(*this),
       d_keepAlive(context()),
       d_propagatations(context()),
       d_explanationMap(context()),
@@ -70,41 +69,12 @@ ArithCongruenceManager::ArithCongruenceManager(
 
 ArithCongruenceManager::~ArithCongruenceManager() {}
 
-bool ArithCongruenceManager::needsEqualityEngine(EeSetupInfo& esi)
-{
-  Assert(!options().arith.arithEqSolver);
-  esi.d_notify = &d_notify;
-  esi.d_name = "arithCong::ee";
-  return true;
-}
-
 void ArithCongruenceManager::finishInit(eq::EqualityEngine* ee)
 {
-  if (options().arith.arithEqSolver)
-  {
-    // use our own copy
-    d_allocEe = std::make_unique<eq::EqualityEngine>(
-        d_env, context(), d_notify, "arithCong::ee", true);
-    d_ee = d_allocEe.get();
-    if (d_pnm != nullptr)
-    {
-      // allocate an internal proof equality engine
-      d_allocPfee = std::make_unique<eq::ProofEqEngine>(d_env, *d_ee);
-      d_ee->setProofEqualityEngine(d_allocPfee.get());
-    }
-  }
-  else
-  {
-    Assert(ee != nullptr);
-    // otherwise, we use the official one
-    d_ee = ee;
-  }
-  // set the congruence kinds on the separate equality engine
-  d_ee->addFunctionKind(kind::NONLINEAR_MULT);
-  d_ee->addFunctionKind(kind::EXPONENTIAL);
-  d_ee->addFunctionKind(kind::SINE);
-  d_ee->addFunctionKind(kind::IAND);
-  d_ee->addFunctionKind(kind::POW2);
+  Assert(ee != nullptr);
+  // otherwise, we use the official one
+  d_ee = ee;
+  // the congruence kinds are already set up
   // the proof equality engine is the one from the equality engine
   d_pfee = d_ee->getProofEqualityEngine();
   // have proof equality engine only if proofs are enabled
@@ -129,44 +99,6 @@ ArithCongruenceManager::Statistics::Statistics()
 {
 }
 
-ArithCongruenceManager::ArithCongruenceNotify::ArithCongruenceNotify(ArithCongruenceManager& acm)
-  : d_acm(acm)
-{}
-
-bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerPredicate(
-    TNode predicate, bool value)
-{
-  Assert(predicate.getKind() == kind::EQUAL);
-  Trace("arith::congruences")
-      << "ArithCongruenceNotify::eqNotifyTriggerPredicate(" << predicate << ", "
-      << (value ? "true" : "false") << ")" << std::endl;
-  if (value) {
-    return d_acm.propagate(predicate);
-  }
-  return d_acm.propagate(predicate.notNode());
-}
-
-bool ArithCongruenceManager::ArithCongruenceNotify::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1, TNode t2, bool value) {
-  Trace("arith::congruences") << "ArithCongruenceNotify::eqNotifyTriggerTermEquality(" << t1 << ", " << t2 << ", " << (value ? "true" : "false") << ")" << std::endl;
-  if (value) {
-    return d_acm.propagate(t1.eqNode(t2));
-  } else {
-    return d_acm.propagate(t1.eqNode(t2).notNode());
-  }
-}
-void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
-  Trace("arith::congruences") << "ArithCongruenceNotify::eqNotifyConstantTermMerge(" << t1 << ", " << t2 << std::endl;
-  d_acm.propagate(t1.eqNode(t2));
-}
-void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyNewClass(TNode t) {
-}
-void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyMerge(TNode t1,
-                                                                  TNode t2)
-{
-}
-void ArithCongruenceManager::ArithCongruenceNotify::eqNotifyDisequal(TNode t1, TNode t2, TNode reason) {
-}
-
 void ArithCongruenceManager::raiseConflict(Node conflict,
                                            std::shared_ptr<ProofNode> pf)
 {
index c9dfb158d0ea64a9d461cf3d255ac2d4a70fd0dc..180cbcef1c0e2f3b80656ff3bec2ca4d79ea7e8f 100644 (file)
@@ -72,27 +72,6 @@ class ArithCongruenceManager : protected EnvObj
   /** d_watchedVariables |-> (= x y) */
   ArithVarToNodeMap d_watchedEqualities;
 
-
-  class ArithCongruenceNotify : public eq::EqualityEngineNotify {
-  private:
-    ArithCongruenceManager& d_acm;
-  public:
-    ArithCongruenceNotify(ArithCongruenceManager& acm);
-
-    bool eqNotifyTriggerPredicate(TNode predicate, bool value) override;
-
-    bool eqNotifyTriggerTermEquality(TheoryId tag,
-                                     TNode t1,
-                                     TNode t2,
-                                     bool value) override;
-
-    void eqNotifyConstantTermMerge(TNode t1, TNode t2) override;
-    void eqNotifyNewClass(TNode t) override;
-    void eqNotifyMerge(TNode t1, TNode t2) override;
-    void eqNotifyDisequal(TNode t1, TNode t2, TNode reason) override;
-  };
-  ArithCongruenceNotify d_notify;
-
   context::CDList<Node> d_keepAlive;
 
   /** Store the propagations. */
@@ -113,8 +92,6 @@ class ArithCongruenceManager : protected EnvObj
 
   /** The equality engine being used by this class */
   eq::EqualityEngine* d_ee;
-  /** The equality engine we allocated */
-  std::unique_ptr<eq::EqualityEngine> d_allocEe;
   /** proof manager */
   ProofNodeManager* d_pnm;
   /** A proof generator for storing proofs of facts that are asserted to the EQ
@@ -167,7 +144,12 @@ class ArithCongruenceManager : protected EnvObj
 
   bool canExplain(TNode n) const;
 
-private:
+  /**
+   * Propagate. Called when the equality engine has inferred literal x.
+   */
+  bool propagate(TNode x);
+
+ private:
   Node externalToInternal(TNode n) const;
 
   void pushBack(TNode n);
@@ -176,7 +158,6 @@ private:
 
   void pushBack(TNode n, TNode r, TNode w);
 
-  bool propagate(TNode x);
   void explain(TNode literal, std::vector<TNode>& assumptions);
 
   /** Assert this literal to the eq engine. Common functionality for
@@ -233,11 +214,6 @@ private:
   ~ArithCongruenceManager();
 
   //--------------------------------- initialization
-  /**
-   * Returns true if we need an equality engine, see
-   * Theory::needsEqualityEngine.
-   */
-  bool needsEqualityEngine(EeSetupInfo& esi);
   /**
    * Finish initialize. This class is instructed by TheoryArithPrivate to use
    * the equality engine ee.
index 8a6d535c98ded3589612fe1e4c1bc7131dbfe2a3..0458a43ecb4040c8465f483101c674f69d583fda 100644 (file)
@@ -134,7 +134,7 @@ TheoryArithPrivate::TheoryArithPrivate(TheoryArith& containing,
                           SetupLiteralCallBack(*this),
                           d_partialModel,
                           RaiseEqualityEngineConflict(*this)),
-      d_cmEnabled(context(), options().arith.arithCongMan),
+      d_cmEnabled(context(), !options().arith.arithEqSolver),
 
       d_dualSimplex(
           env, d_linEq, d_errorSet, RaiseConflict(*this), TempVarMalloc(*this)),
@@ -176,14 +176,6 @@ TheoryArithPrivate::~TheoryArithPrivate(){
   if(d_approxStats != NULL) { delete d_approxStats; }
 }
 
-bool TheoryArithPrivate::needsEqualityEngine(EeSetupInfo& esi)
-{
-  if (!d_cmEnabled)
-  {
-    return false;
-  }
-  return d_congruenceManager.needsEqualityEngine(esi);
-}
 void TheoryArithPrivate::finishInit()
 {
   if (d_cmEnabled)
@@ -5005,6 +4997,11 @@ ArithProofRuleChecker* TheoryArithPrivate::getProofChecker()
   return &d_checker;
 }
 
+ArithCongruenceManager* TheoryArithPrivate::getCongruenceManager()
+{
+  return d_cmEnabled.get() ? &d_congruenceManager : nullptr;
+}
+
 }  // namespace arith
 }  // namespace theory
 }  // namespace cvc5::internal
index d8a3613154bc98429332fe8dbc294ecaa2423f63..253f4b3a5060efa4c51fa44ccc5b5799c105fcc4 100644 (file)
@@ -432,11 +432,6 @@ private:
   ~TheoryArithPrivate();
 
   //--------------------------------- initialization
-  /**
-   * Returns true if we need an equality engine, see
-   * Theory::needsEqualityEngine.
-   */
-  bool needsEqualityEngine(EeSetupInfo& esi);
   /** finish initialize */
   void finishInit();
   //--------------------------------- end initialization
@@ -507,6 +502,8 @@ private:
 
   /** get the proof checker of this theory */
   ArithProofRuleChecker* getProofChecker();
+  /** get the congruence manager, if we are using one */
+  ArithCongruenceManager* getCongruenceManager();
 
  private:
   /** The constant zero. */
index c1cdef9f14bb34deaed867919ce1249f6e8c34b3..1fca6459ba70439985b782afe6dc96f7bc88a8c6 100644 (file)
@@ -57,10 +57,8 @@ TheoryArith::TheoryArith(Env& env, OutputChannel& out, Valuation valuation)
   d_theoryState = &d_astate;
   d_inferManager = &d_im;
 
-  if (options().arith.arithEqSolver)
-  {
-    d_eqSolver.reset(new EqualitySolver(env, d_astate, d_im));
-  }
+  // construct the equality solver
+  d_eqSolver.reset(new EqualitySolver(env, d_astate, d_im));
 }
 
 TheoryArith::~TheoryArith(){
@@ -78,13 +76,7 @@ bool TheoryArith::needsEqualityEngine(EeSetupInfo& esi)
 {
   // if the equality solver is enabled, then it is responsible for setting
   // up the equality engine
-  if (d_eqSolver != nullptr)
-  {
-    return d_eqSolver->needsEqualityEngine(esi);
-  }
-  // otherwise, the linear arithmetic solver is responsible for setting up
-  // the equality engine
-  return d_internal->needsEqualityEngine(esi);
+  return d_eqSolver->needsEqualityEngine(esi);
 }
 void TheoryArith::finishInit()
 {
@@ -104,12 +96,14 @@ void TheoryArith::finishInit()
     d_nonlinearExtension.reset(
         new nl::NonlinearExtension(d_env, *this, d_astate));
   }
-  if (d_eqSolver != nullptr)
-  {
-    d_eqSolver->finishInit();
-  }
+  d_eqSolver->finishInit();
   // finish initialize in the old linear solver
   d_internal->finishInit();
+
+  // Set the congruence manager on the equality solver. If the congruence
+  // manager exists, it is responsible for managing the notifications from
+  // the equality engine, which the equality solver forwards to it.
+  d_eqSolver->setCongruenceManager(d_internal->getCongruenceManager());
 }
 
 void TheoryArith::preRegisterTerm(TNode n)
@@ -236,7 +230,7 @@ bool TheoryArith::preNotifyFact(
   // We do not assert to the equality engine of arithmetic in the standard way,
   // hence we return "true" to indicate we are finished with this fact.
   bool ret = true;
-  if (d_eqSolver != nullptr)
+  if (options().arith.arithEqSolver)
   {
     // the equality solver may indicate ret = false, after which the assertion
     // will be asserted to the equality engine in the default way.
@@ -257,7 +251,7 @@ bool TheoryArith::needsCheckLastEffort() {
 
 TrustNode TheoryArith::explain(TNode n)
 {
-  if (d_eqSolver != nullptr)
+  if (options().arith.arithEqSolver)
   {
     // if the equality solver has an explanation for it, use it
     TrustNode texp = d_eqSolver->explain(n);