Properly set up equality engine for BV bitblast solver. (#5905)
[cvc5.git] / src / theory / bv / theory_bv.cpp
index b27bd04e1d9eee063e0085e0d0c315bbed87c93b..f6e056f4241e3fb7da90f0705873e568a8ad1cb3 100644 (file)
@@ -39,12 +39,13 @@ TheoryBV::TheoryBV(context::Context* c,
       d_ufRemByZero(),
       d_rewriter(),
       d_state(c, u, valuation),
-      d_inferMgr(*this, d_state, nullptr)
+      d_im(*this, d_state, nullptr),
+      d_notify(d_im)
 {
   switch (options::bvSolver())
   {
     case options::BVSolver::BITBLAST:
-      d_internal.reset(new BVSolverBitblast(&d_state, d_inferMgr, pnm));
+      d_internal.reset(new BVSolverBitblast(&d_state, d_im, pnm));
       break;
 
     case options::BVSolver::LAZY:
@@ -53,10 +54,10 @@ TheoryBV::TheoryBV(context::Context* c,
 
     default:
       AlwaysAssert(options::bvSolver() == options::BVSolver::SIMPLE);
-      d_internal.reset(new BVSolverSimple(&d_state, d_inferMgr, pnm));
+      d_internal.reset(new BVSolverSimple(&d_state, d_im, pnm));
   }
   d_theoryState = &d_state;
-  d_inferManager = &d_inferMgr;
+  d_inferManager = &d_im;
 }
 
 TheoryBV::~TheoryBV() {}
@@ -65,7 +66,16 @@ TheoryRewriter* TheoryBV::getTheoryRewriter() { return &d_rewriter; }
 
 bool TheoryBV::needsEqualityEngine(EeSetupInfo& esi)
 {
-  return d_internal->needsEqualityEngine(esi);
+  bool need_ee = d_internal->needsEqualityEngine(esi);
+
+  /* Set up default notify class for equality engine. */
+  if (need_ee && esi.d_notify == nullptr)
+  {
+    esi.d_notify = &d_notify;
+    esi.d_name = "theory::bv::ee";
+  }
+
+  return need_ee;
 }
 
 void TheoryBV::finishInit()
@@ -194,6 +204,19 @@ TrustNode TheoryBV::expandDefinition(Node node)
 void TheoryBV::preRegisterTerm(TNode node)
 {
   d_internal->preRegisterTerm(node);
+
+  eq::EqualityEngine* ee = getEqualityEngine();
+  if (ee)
+  {
+    if (node.getKind() == kind::EQUAL)
+    {
+      ee->addTriggerPredicate(node);
+    }
+    else
+    {
+      ee->addTerm(node);
+    }
+  }
 }
 
 bool TheoryBV::preCheck(Effort e) { return d_internal->preCheck(e); }