Properly set up equality engine for BV bitblast solver. (#5905)
[cvc5.git] / src / theory / bv / theory_bv.cpp
index fd91d034625978f0348f966b46358db0f633671d..f6e056f4241e3fb7da90f0705873e568a8ad1cb3 100644 (file)
@@ -16,6 +16,8 @@
 #include "theory/bv/theory_bv.h"
 
 #include "options/bv_options.h"
+#include "options/smt_options.h"
+#include "theory/bv/bv_solver_bitblast.h"
 #include "theory/bv/bv_solver_lazy.h"
 #include "theory/bv/bv_solver_simple.h"
 #include "theory/bv/theory_bv_utils.h"
@@ -37,20 +39,25 @@ TheoryBV::TheoryBV(context::Context* c,
       d_ufRemByZero(),
       d_rewriter(),
       d_state(c, u, valuation),
-      d_inferMgr(*this, d_state, nullptr)
+      d_im(*this, d_state, nullptr),
+      d_notify(d_im)
 {
   switch (options::bvSolver())
   {
+    case options::BVSolver::BITBLAST:
+      d_internal.reset(new BVSolverBitblast(&d_state, d_im, pnm));
+      break;
+
     case options::BVSolver::LAZY:
       d_internal.reset(new BVSolverLazy(*this, c, u, pnm, name));
       break;
 
     default:
       AlwaysAssert(options::bvSolver() == options::BVSolver::SIMPLE);
-      d_internal.reset(new BVSolverSimple(&d_state, d_inferMgr, pnm));
+      d_internal.reset(new BVSolverSimple(&d_state, d_im, pnm));
   }
   d_theoryState = &d_state;
-  d_inferManager = &d_inferMgr;
+  d_inferManager = &d_im;
 }
 
 TheoryBV::~TheoryBV() {}
@@ -59,7 +66,16 @@ TheoryRewriter* TheoryBV::getTheoryRewriter() { return &d_rewriter; }
 
 bool TheoryBV::needsEqualityEngine(EeSetupInfo& esi)
 {
-  return d_internal->needsEqualityEngine(esi);
+  bool need_ee = d_internal->needsEqualityEngine(esi);
+
+  /* Set up default notify class for equality engine. */
+  if (need_ee && esi.d_notify == nullptr)
+  {
+    esi.d_notify = &d_notify;
+    esi.d_name = "theory::bv::ee";
+  }
+
+  return need_ee;
 }
 
 void TheoryBV::finishInit()
@@ -69,6 +85,44 @@ void TheoryBV::finishInit()
   getValuation().setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UDIV);
   getValuation().setSemiEvaluatedKind(kind::BITVECTOR_ACKERMANNIZE_UREM);
   d_internal->finishInit();
+
+  eq::EqualityEngine* ee = getEqualityEngine();
+  if (ee)
+  {
+    // The kinds we are treating as function application in congruence
+    ee->addFunctionKind(kind::BITVECTOR_CONCAT, true);
+    //    ee->addFunctionKind(kind::BITVECTOR_AND);
+    //    ee->addFunctionKind(kind::BITVECTOR_OR);
+    //    ee->addFunctionKind(kind::BITVECTOR_XOR);
+    //    ee->addFunctionKind(kind::BITVECTOR_NOT);
+    //    ee->addFunctionKind(kind::BITVECTOR_NAND);
+    //    ee->addFunctionKind(kind::BITVECTOR_NOR);
+    //    ee->addFunctionKind(kind::BITVECTOR_XNOR);
+    //    ee->addFunctionKind(kind::BITVECTOR_COMP);
+    ee->addFunctionKind(kind::BITVECTOR_MULT, true);
+    ee->addFunctionKind(kind::BITVECTOR_PLUS, true);
+    ee->addFunctionKind(kind::BITVECTOR_EXTRACT, true);
+    //    ee->addFunctionKind(kind::BITVECTOR_SUB);
+    //    ee->addFunctionKind(kind::BITVECTOR_NEG);
+    //    ee->addFunctionKind(kind::BITVECTOR_UDIV);
+    //    ee->addFunctionKind(kind::BITVECTOR_UREM);
+    //    ee->addFunctionKind(kind::BITVECTOR_SDIV);
+    //    ee->addFunctionKind(kind::BITVECTOR_SREM);
+    //    ee->addFunctionKind(kind::BITVECTOR_SMOD);
+    //    ee->addFunctionKind(kind::BITVECTOR_SHL);
+    //    ee->addFunctionKind(kind::BITVECTOR_LSHR);
+    //    ee->addFunctionKind(kind::BITVECTOR_ASHR);
+    //    ee->addFunctionKind(kind::BITVECTOR_ULT);
+    //    ee->addFunctionKind(kind::BITVECTOR_ULE);
+    //    ee->addFunctionKind(kind::BITVECTOR_UGT);
+    //    ee->addFunctionKind(kind::BITVECTOR_UGE);
+    //    ee->addFunctionKind(kind::BITVECTOR_SLT);
+    //    ee->addFunctionKind(kind::BITVECTOR_SLE);
+    //    ee->addFunctionKind(kind::BITVECTOR_SGT);
+    //    ee->addFunctionKind(kind::BITVECTOR_SGE);
+    ee->addFunctionKind(kind::BITVECTOR_TO_NAT);
+    ee->addFunctionKind(kind::INT_TO_BITVECTOR);
+  }
 }
 
 Node TheoryBV::getUFDivByZero(Kind k, unsigned width)
@@ -129,27 +183,12 @@ TrustNode TheoryBV::expandDefinition(Node node)
     case kind::BITVECTOR_UREM:
     {
       NodeManager* nm = NodeManager::currentNM();
-      unsigned width = node.getType().getBitVectorSize();
-
-      if (options::bitvectorDivByZeroConst())
-      {
-        Kind kind = node.getKind() == kind::BITVECTOR_UDIV
-                        ? kind::BITVECTOR_UDIV_TOTAL
-                        : kind::BITVECTOR_UREM_TOTAL;
-        ret = nm->mkNode(kind, node[0], node[1]);
-        break;
-      }
-
-      TNode num = node[0], den = node[1];
-      Node den_eq_0 = nm->mkNode(kind::EQUAL, den, utils::mkZero(width));
-      Node divTotalNumDen = nm->mkNode(node.getKind() == kind::BITVECTOR_UDIV
-                                           ? kind::BITVECTOR_UDIV_TOTAL
-                                           : kind::BITVECTOR_UREM_TOTAL,
-                                       num,
-                                       den);
-      Node divByZero = getUFDivByZero(node.getKind(), width);
-      Node divByZeroNum = nm->mkNode(kind::APPLY_UF, divByZero, num);
-      ret = nm->mkNode(kind::ITE, den_eq_0, divByZeroNum, divTotalNumDen);
+
+      Kind kind = node.getKind() == kind::BITVECTOR_UDIV
+                      ? kind::BITVECTOR_UDIV_TOTAL
+                      : kind::BITVECTOR_UREM_TOTAL;
+      ret = nm->mkNode(kind, node[0], node[1]);
+      break;
     }
     break;
 
@@ -165,6 +204,19 @@ TrustNode TheoryBV::expandDefinition(Node node)
 void TheoryBV::preRegisterTerm(TNode node)
 {
   d_internal->preRegisterTerm(node);
+
+  eq::EqualityEngine* ee = getEqualityEngine();
+  if (ee)
+  {
+    if (node.getKind() == kind::EQUAL)
+    {
+      ee->addTriggerPredicate(node);
+    }
+    else
+    {
+      ee->addTerm(node);
+    }
+  }
 }
 
 bool TheoryBV::preCheck(Effort e) { return d_internal->preCheck(e); }
@@ -213,6 +265,11 @@ TrustNode TheoryBV::ppRewrite(TNode t)
 
 void TheoryBV::presolve() { d_internal->presolve(); }
 
+EqualityStatus TheoryBV::getEqualityStatus(TNode a, TNode b)
+{
+  return d_internal->getEqualityStatus(a, b);
+}
+
 TrustNode TheoryBV::explain(TNode node) { return d_internal->explain(node); }
 
 void TheoryBV::notifySharedTerm(TNode t)