Use std:unique_ptr instead of raw pointers in theory/bv. (#2385)
[cvc5.git] / src / theory / bv / bv_subtheory_algebraic.cpp
index 00d33739565a565c4170392b40bcc91ad5216ee6..df7ba29b5cc06e52654e65921c5ac5f08ac376ba 100644 (file)
@@ -2,9 +2,9 @@
 /*! \file bv_subtheory_algebraic.cpp
  ** \verbatim
  ** Top contributors (to current version):
- **   Liana Hadarean, Tim King, Morgan Deters
+ **   Liana Hadarean, Aina Niemetz, Tim King
  ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2016 by the authors listed in the file AUTHORS
+ ** Copyright (c) 2009-2018 by the authors listed in the file AUTHORS
  ** in the top-level source directory) and their institutional affiliations.
  ** All rights reserved.  See the file COPYING in the top-level source
  ** directory for licensing information.\endverbatim
@@ -15,6 +15,9 @@
  **/
 #include "theory/bv/bv_subtheory_algebraic.h"
 
+#include <unordered_set>
+
+#include "expr/node_algorithm.h"
 #include "options/bv_options.h"
 #include "smt/smt_statistics_registry.h"
 #include "smt_util/boolean_simplification.h"
@@ -23,7 +26,6 @@
 #include "theory/bv/theory_bv_utils.h"
 #include "theory/theory_model.h"
 
-
 using namespace CVC4::context;
 using namespace CVC4::prop;
 using namespace CVC4::theory::bv::utils;
@@ -33,6 +35,38 @@ namespace CVC4 {
 namespace theory {
 namespace bv {
 
+/* ------------------------------------------------------------------------- */
+
+namespace {
+
+/* Collect all variables under a given a node.  */
+void collectVariables(TNode node, utils::NodeSet& vars)
+{
+  std::vector<TNode> stack;
+  std::unordered_set<TNode, TNodeHashFunction> visited;
+
+  stack.push_back(node);
+  while (!stack.empty())
+  {
+    Node n = stack.back();
+    stack.pop_back();
+
+    if (vars.find(n) != vars.end()) continue;
+    if (visited.find(n) != visited.end()) continue;
+    visited.insert(n);
+
+    if (Theory::isLeafOf(n, THEORY_BV) && n.getKind() != kind::CONST_BITVECTOR)
+    {
+      vars.insert(n);
+      continue;
+    }
+    stack.insert(stack.end(), n.begin(), n.end());
+  }
+}
+
+};
+
+/* ------------------------------------------------------------------------- */
 
 bool hasExpensiveBVOperators(TNode fact);
 Node mergeExplanations(const std::vector<Node>& expls);
@@ -194,46 +228,39 @@ void SubstitutionEx::storeCache(TNode from, TNode to, Node reason) {
 }
 
 AlgebraicSolver::AlgebraicSolver(context::Context* c, TheoryBV* bv)
-  : SubtheorySolver(c, bv)
-  , d_modelMap(NULL)
-  , d_quickSolver(new BVQuickCheck("alg", bv))
-  , d_isComplete(c, false)
-  , d_isDifficult(c, false)
-  , d_budget(options::bitvectorAlgebraicBudget())
-  , d_explanations()
-  , d_inputAssertions()
-  , d_ids()
-  , d_numSolved(0)
-  , d_numCalls(0)
-  , d_ctx(new context::Context())
-  , d_quickXplain(options::bitvectorQuickXplain() ? new QuickXPlain("alg", d_quickSolver) : NULL)
-  , d_statistics()
-{}
-
-AlgebraicSolver::~AlgebraicSolver() {
-  if(d_modelMap != NULL) { delete d_modelMap; }
-  delete d_quickXplain;
-  delete d_quickSolver;
-  delete d_ctx;
+    : SubtheorySolver(c, bv),
+      d_modelMap(),
+      d_quickSolver(new BVQuickCheck("theory::bv::algebraic", bv)),
+      d_isComplete(c, false),
+      d_isDifficult(c, false),
+      d_budget(options::bitvectorAlgebraicBudget()),
+      d_explanations(),
+      d_inputAssertions(),
+      d_ids(),
+      d_numSolved(0),
+      d_numCalls(0),
+      d_quickXplain(),
+      d_statistics()
+{
+  if (options::bitvectorQuickXplain())
+  {
+    d_quickXplain.reset(
+        new QuickXPlain("theory::bv::algebraic", d_quickSolver.get()));
+  }
 }
 
+AlgebraicSolver::~AlgebraicSolver() {}
 
-
-bool AlgebraicSolver::check(Theory::Effort e) {
+bool AlgebraicSolver::check(Theory::Effort e)
+{
   Assert(options::bitblastMode() == theory::bv::BITBLAST_MODE_LAZY);
 
-  if (!Theory::fullEffort(e)) {
-    return true;
-  }
-
-  if (!useHeuristic()) {
-    return true;
-  }
-
-  ++(d_numCalls);
+  if (!Theory::fullEffort(e)) { return true; }
+  if (!useHeuristic()) { return true; }
 
   TimerStat::CodeTimer algebraicTimer(d_statistics.d_solveTime);
   Debug("bv-subtheory-algebraic") << "AlgebraicSolver::check (" << e << ")\n";
+  ++(d_numCalls);
   ++(d_statistics.d_numCallstoCheck);
 
   d_explanations.clear();
@@ -244,6 +271,7 @@ bool AlgebraicSolver::check(Theory::Effort e) {
 
   uint64_t original_bb_cost = 0;
 
+  NodeManager* nm = NodeManager::currentNM();
   NodeSet seen_assertions;
   // Processing assertions from scratch
   for (AssertionQueue::const_iterator it = assertionsBegin(); it != assertionsEnd(); ++it) {
@@ -268,16 +296,15 @@ bool AlgebraicSolver::check(Theory::Effort e) {
 
   Assert (d_explanations.size() == worklist.size());
 
-  delete d_modelMap;
-  d_modelMap = new SubstitutionMap(d_context);
-  SubstitutionEx subst(d_modelMap);
+  d_modelMap.reset(new SubstitutionMap(d_context));
+  SubstitutionEx subst(d_modelMap.get());
 
   // first round of substitutions
   processAssertions(worklist, subst);
 
   if (!d_isDifficult.get()) {
     // skolemize all possible extracts
-    ExtractSkolemizer skolemizer(d_modelMap);
+    ExtractSkolemizer skolemizer(d_modelMap.get());
     skolemizer.skolemize(worklist);
     // second round of substitutions
     processAssertions(worklist, subst);
@@ -296,7 +323,7 @@ bool AlgebraicSolver::check(Theory::Effort e) {
 
     if (Dump.isOn("bv-algebraic")) {
       Node expl = d_explanations[id];
-      Node query = utils::mkNot(utils::mkNode(kind::IMPLIES, expl, fact));
+      Node query = utils::mkNot(nm->mkNode(kind::IMPLIES, expl, fact));
       Dump("bv-algebraic") << EchoCommand("ThoeryBV::AlgebraicSolver::substitution explanation");
       Dump("bv-algebraic") << PushCommand();
       Dump("bv-algebraic") << AssertCommand(query.toExpr());
@@ -457,15 +484,17 @@ void AlgebraicSolver::setConflict(TNode conflict) {
 bool AlgebraicSolver::solve(TNode fact, TNode reason, SubstitutionEx& subst) {
   if (fact.getKind() != kind::EQUAL) return false;
 
+  NodeManager* nm = NodeManager::currentNM();
   TNode left = fact[0];
   TNode right = fact[1];
 
-
-  if (left.isVar() && !right.hasSubterm(left)) {
-    bool changed  = subst.addSubstitution(left, right, reason);
+  if (left.isVar() && !expr::hasSubterm(right, left))
+  {
+    bool changed = subst.addSubstitution(left, right, reason);
     return changed;
   }
-  if (right.isVar() && !left.hasSubterm(right)) {
+  if (right.isVar() && !expr::hasSubterm(left, right))
+  {
     bool changed = subst.addSubstitution(right, left, reason);
     return changed;
   }
@@ -478,22 +507,21 @@ bool AlgebraicSolver::solve(TNode fact, TNode reason, SubstitutionEx& subst) {
       return false;
 
     // simplify xor with same variable on both sides
-    if (right.hasSubterm(var)) {
+    if (expr::hasSubterm(right, var))
+    {
       std::vector<Node> right_children;
       for (unsigned i = 0; i < right.getNumChildren(); ++i) {
         if (right[i] != var)
           right_children.push_back(right[i]);
       }
       Assert (right_children.size());
-      Node new_right = right_children.size() > 1 ? utils::mkNode(kind::BITVECTOR_XOR, right_children)
-                                                 : right_children[0];
+      Node new_right = utils::mkNaryNode(kind::BITVECTOR_XOR, right_children);
       std::vector<Node> left_children;
       for (unsigned i = 1; i < left.getNumChildren(); ++i) {
         left_children.push_back(left[i]);
       }
-      Node new_left = left_children.size() > 1 ? utils::mkNode(kind::BITVECTOR_XOR, left_children)
-                                               : left_children[0];
-      Node new_fact = utils::mkNode(kind::EQUAL, new_left, new_right);
+      Node new_left = utils::mkNaryNode(kind::BITVECTOR_XOR, left_children);
+      Node new_fact = nm->mkNode(kind::EQUAL, new_left, new_right);
       bool changed = subst.addSubstitution(fact, new_fact, reason);
       return changed;
     }
@@ -503,11 +531,12 @@ bool AlgebraicSolver::solve(TNode fact, TNode reason, SubstitutionEx& subst) {
       nb << left[i];
     }
     Node inverse = left.getNumChildren() == 2? (Node)left[1] : (Node)nb;
-    Node new_right = utils::mkNode(kind::BITVECTOR_XOR, right, inverse);
+    Node new_right = nm->mkNode(kind::BITVECTOR_XOR, right, inverse);
     bool changed = subst.addSubstitution(var, new_right, reason);
 
     if (Dump.isOn("bv-algebraic")) {
-      Node query = utils::mkNot(utils::mkNode(kind::IFF, fact, utils::mkNode(kind::EQUAL, var, new_right)));
+      Node query = utils::mkNot(nm->mkNode(
+          kind::EQUAL, fact, nm->mkNode(kind::EQUAL, var, new_right)));
       Dump("bv-algebraic") << EchoCommand("ThoeryBV::AlgebraicSolver::substitution explanation");
       Dump("bv-algebraic") << PushCommand();
       Dump("bv-algebraic") << AssertCommand(query.toExpr());
@@ -520,24 +549,26 @@ bool AlgebraicSolver::solve(TNode fact, TNode reason, SubstitutionEx& subst) {
   }
 
   // (a xor t = a) <=> (t = 0)
-  if (left.getKind() == kind::BITVECTOR_XOR &&
-      right.getMetaKind() == kind::metakind::VARIABLE &&
-      left.hasSubterm(right)) {
+  if (left.getKind() == kind::BITVECTOR_XOR
+      && right.getMetaKind() == kind::metakind::VARIABLE
+      && expr::hasSubterm(left, right))
+  {
     TNode var = right;
-    Node new_left = utils::mkNode(kind::BITVECTOR_XOR, var, left);
+    Node new_left = nm->mkNode(kind::BITVECTOR_XOR, var, left);
     Node zero = utils::mkConst(utils::getSize(var), 0u);
-    Node new_fact = utils::mkNode(kind::EQUAL, zero, new_left);
+    Node new_fact = nm->mkNode(kind::EQUAL, zero, new_left);
     bool changed = subst.addSubstitution(fact, new_fact, reason);
     return changed;
   }
 
-  if (right.getKind() == kind::BITVECTOR_XOR &&
-      left.getMetaKind() == kind::metakind::VARIABLE &&
-      right.hasSubterm(left)) {
+  if (right.getKind() == kind::BITVECTOR_XOR
+      && left.getMetaKind() == kind::metakind::VARIABLE
+      && expr::hasSubterm(right, left))
+  {
     TNode var = left;
-    Node new_right = utils::mkNode(kind::BITVECTOR_XOR, var, right);
+    Node new_right = nm->mkNode(kind::BITVECTOR_XOR, var, right);
     Node zero = utils::mkConst(utils::getSize(var), 0u);
-    Node new_fact = utils::mkNode(kind::EQUAL, zero, new_right);
+    Node new_fact = nm->mkNode(kind::EQUAL, zero, new_right);
     bool changed = subst.addSubstitution(fact, new_fact, reason);
     return changed;
   }
@@ -547,7 +578,7 @@ bool AlgebraicSolver::solve(TNode fact, TNode reason, SubstitutionEx& subst) {
       left.getNumChildren() == 2 &&
       right.getKind() == kind::CONST_BITVECTOR &&
       right.getConst<BitVector>() == BitVector(utils::getSize(left), 0u)) {
-    Node new_fact = utils::mkNode(kind::EQUAL, left[0], left[1]);
+    Node new_fact = nm->mkNode(kind::EQUAL, left[0], left[1]);
     bool changed = subst.addSubstitution(fact, new_fact, reason);
     return changed;
   }
@@ -556,14 +587,16 @@ bool AlgebraicSolver::solve(TNode fact, TNode reason, SubstitutionEx& subst) {
   return false;
 }
 
-bool AlgebraicSolver::isSubstitutableIn(TNode node, TNode in) {
-  if (node.getMetaKind() == kind::metakind::VARIABLE &&
-      !in.hasSubterm(node))
+bool AlgebraicSolver::isSubstitutableIn(TNode node, TNode in)
+{
+  if (node.getMetaKind() == kind::metakind::VARIABLE
+      && !expr::hasSubterm(in, node))
     return true;
   return false;
 }
 
 void AlgebraicSolver::processAssertions(std::vector<WorklistElement>& worklist, SubstitutionEx& subst) {
+  NodeManager* nm = NodeManager::currentNM();
   bool changed = true;
   while(changed) {
     // d_bv->spendResource();
@@ -613,7 +646,7 @@ void AlgebraicSolver::processAssertions(std::vector<WorklistElement>& worklist,
       }
 
       for (unsigned j = 0; j < left.getNumChildren(); ++j) {
-        Node eq_j = utils::mkNode(kind::EQUAL, left[j], right[j]);
+        Node eq_j = nm->mkNode(kind::EQUAL, left[j], right[j]);
         unsigned id = d_explanations.size();
         TNode expl = d_explanations[current_id];
         storeExplanation(expl);
@@ -676,7 +709,9 @@ void AlgebraicSolver::assertFact(TNode fact) {
 EqualityStatus AlgebraicSolver::getEqualityStatus(TNode a, TNode b) {
   return EQUALITY_UNKNOWN;
 }
-void AlgebraicSolver::collectModelInfo(TheoryModel* model, bool fullModel) {
+
+bool AlgebraicSolver::collectModelInfo(TheoryModel* model, bool fullModel)
+{
   Debug("bitvector-model") << "AlgebraicSolver::collectModelInfo\n";
   AlwaysAssert (!d_quickSolver->inConflict());
   set<Node> termSet;
@@ -703,7 +738,7 @@ void AlgebraicSolver::collectModelInfo(TheoryModel* model, bool fullModel) {
     TNode subst = Rewriter::rewrite(d_modelMap->apply(current));
     Debug("bitvector-model") << "   " << current << " => " << subst << "\n";
     values[i] = subst;
-    utils::collectVariables(subst, leaf_vars);
+    collectVariables(subst, leaf_vars);
   }
 
   Debug("bitvector-model") << "Model:\n";
@@ -714,7 +749,8 @@ void AlgebraicSolver::collectModelInfo(TheoryModel* model, bool fullModel) {
     Assert (!value.isNull() || !fullModel);
 
     // may be a shared term that did not appear in the current assertions
-    if (!value.isNull()) {
+    // AJR: need to check whether already in map for cases where collectModelInfo is called multiple times in the same context
+    if (!value.isNull() && !d_modelMap->hasSubstitution(var)) {
       Debug("bitvector-model") << "   " << var << " => " << value << "\n";
       Assert (value.getKind() == kind::CONST_BITVECTOR);
       d_modelMap->addSubstitution(var, value);
@@ -728,9 +764,12 @@ void AlgebraicSolver::collectModelInfo(TheoryModel* model, bool fullModel) {
     Debug("bitvector-model") << "AlgebraicSolver:   " << variables[i] << " => " << subst << "\n";
     // Doesn't have to be constant as it may be irrelevant
     Assert (subst.getKind() == kind::CONST_BITVECTOR);
-    model->assertEquality(variables[i], subst, true);
+    if (!model->assertEquality(variables[i], subst, true))
+    {
+      return false;
+    }
   }
-
+  return true;
  }
 
 Node AlgebraicSolver::getModelValue(TNode node) {
@@ -738,14 +777,14 @@ Node AlgebraicSolver::getModelValue(TNode node) {
 }
 
 AlgebraicSolver::Statistics::Statistics()
-  : d_numCallstoCheck("theory::bv::AlgebraicSolver::NumCallsToCheck", 0)
-  , d_numSimplifiesToTrue("theory::bv::AlgebraicSolver::NumSimplifiesToTrue", 0)
-  , d_numSimplifiesToFalse("theory::bv::AlgebraicSolver::NumSimplifiesToFalse", 0)
-  , d_numUnsat("theory::bv::AlgebraicSolver::NumUnsat", 0)
-  , d_numSat("theory::bv::AlgebraicSolver::NumSat", 0)
-  , d_numUnknown("theory::bv::AlgebraicSolver::NumUnknown", 0)
-  , d_solveTime("theory::bv::AlgebraicSolver::SolveTime")
-  , d_useHeuristic("theory::bv::AlgebraicSolver::UseHeuristic", 0.2)
+  : d_numCallstoCheck("theory::bv::algebraic::NumCallsToCheck", 0)
+  , d_numSimplifiesToTrue("theory::bv::algebraic::NumSimplifiesToTrue", 0)
+  , d_numSimplifiesToFalse("theory::bv::algebraic::NumSimplifiesToFalse", 0)
+  , d_numUnsat("theory::bv::algebraic::NumUnsat", 0)
+  , d_numSat("theory::bv::algebraic::NumSat", 0)
+  , d_numUnknown("theory::bv::algebraic::NumUnknown", 0)
+  , d_solveTime("theory::bv::algebraic::SolveTime")
+  , d_useHeuristic("theory::bv::algebraic::UseHeuristic", 0.2)
 {
   smtStatisticsRegistry()->registerStat(&d_numCallstoCheck);
   smtStatisticsRegistry()->registerStat(&d_numSimplifiesToTrue);