(new theory) Update TheoryUF to new interface (#4944)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 27 Aug 2020 01:24:28 +0000 (20:24 -0500)
committerGitHub <noreply@github.com>
Thu, 27 Aug 2020 01:24:28 +0000 (20:24 -0500)
This updates TheoryUF to use the 4 check callbacks instead of implementing check, and uses the official TheoryState object instead of its context::CDO<bool> d_conflict field.

It also makes a minor change to collectModelValues for const and to preNotifyFact to include an isInternal flag.

src/theory/theory.cpp
src/theory/theory.h
src/theory/uf/ho_extension.cpp
src/theory/uf/ho_extension.h
src/theory/uf/theory_uf.cpp
src/theory/uf/theory_uf.h

index f65a7c45c328b1b8599b7f4a14616d986474d175..50a5c4493dbdea231b05bfbec777611cae5699fd 100644 (file)
@@ -425,7 +425,7 @@ void Theory::computeRelevantTerms(std::set<Node>& termSet, bool includeShared)
   computeRelevantTermsInternal(termSet, irrKinds, includeShared);
 }
 
-bool Theory::collectModelValues(TheoryModel* m, std::set<Node>& termSet)
+bool Theory::collectModelValues(TheoryModel* m, const std::set<Node>& termSet)
 {
   return true;
 }
@@ -538,7 +538,7 @@ void Theory::check(Effort level)
     bool polarity = fact.getKind() != kind::NOT;
     TNode atom = polarity ? fact : fact[0];
     // call the pre-notify method
-    if (preNotifyFact(atom, polarity, fact, assertion.d_isPreregistered))
+    if (preNotifyFact(atom, polarity, fact, assertion.d_isPreregistered, false))
     {
       // handled in theory-specific way that doesn't involve equality engine
       continue;
@@ -566,7 +566,8 @@ bool Theory::preCheck(Effort level) { return false; }
 
 void Theory::postCheck(Effort level) {}
 
-bool Theory::preNotifyFact(TNode atom, bool polarity, TNode fact, bool isPrereg)
+bool Theory::preNotifyFact(
+    TNode atom, bool polarity, TNode fact, bool isPrereg, bool isInternal)
 {
   return false;
 }
index 8ea64e724097da119646a14a0995009356e6cc6a..1fdf96331e0fef371a93b253fc60d15e846929ce 100644 (file)
@@ -644,18 +644,22 @@ class Theory {
    * @param polarity Its polarity
    * @param fact The original literal that was asserted
    * @param isPrereg Whether the assertion is preregistered
+   * @param isInternal Whether the origin of the fact was internal. If this
+   * is false, the fact was asserted via the fact queue of the theory.
    * @return true if the theory completely processed this fact, i.e. it does
    * not need to assert the fact to its equality engine.
    */
-  virtual bool preNotifyFact(TNode atom, bool pol, TNode fact, bool isPrereg);
+  virtual bool preNotifyFact(
+      TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal);
   /**
    * Notify fact, called immediately after the fact was pushed into the
    * equality engine.
    *
    * @param atom The atom
    * @param polarity Its polarity
-   * @param fact The original literal that was asserted
-   * @param isInternal Whether the origin of the fact was internal
+   * @param fact The original literal that was asserted.
+   * @param isInternal Whether the origin of the fact was internal. If this
+   * is false, the fact was asserted via the fact queue of the theory.
    */
   virtual void notifyFact(TNode atom, bool pol, TNode fact, bool isInternal);
   //--------------------------------- end check
@@ -690,7 +694,8 @@ class Theory {
    * The argument termSet is the set of relevant terms returned by
    * computeRelevantTerms.
    */
-  virtual bool collectModelValues(TheoryModel* m, std::set<Node>& termSet);
+  virtual bool collectModelValues(TheoryModel* m,
+                                  const std::set<Node>& termSet);
   /** if theories want to do something with model after building, do it here */
   virtual void postProcessModel( TheoryModel* m ){ }
   //--------------------------------- end collect model info
index d90a0248679c8c8796a2f848de2d58af25f965fd..11b872e724c8f18d29413bd22918ad770e749d50 100644 (file)
@@ -28,10 +28,11 @@ namespace CVC4 {
 namespace theory {
 namespace uf {
 
-HoExtension::HoExtension(TheoryUF& p,
-                         context::Context* c,
-                         context::UserContext* u)
-    : d_parent(p), d_extensionality(u), d_uf_std_skolem(u)
+HoExtension::HoExtension(TheoryUF& p, TheoryState& state)
+    : d_parent(p),
+      d_state(state),
+      d_extensionality(state.getUserContext()),
+      d_uf_std_skolem(state.getUserContext())
 {
   d_true = NodeManager::currentNM()->mkConst(true);
 }
@@ -191,7 +192,7 @@ Node HoExtension::getApplyUfForHoApply(Node node)
 
 unsigned HoExtension::checkExtensionality(TheoryModel* m)
 {
-  eq::EqualityEngine* ee = d_parent.getEqualityEngine();
+  eq::EqualityEngine* ee = d_state.getEqualityEngine();
   NodeManager* nm = NodeManager::currentNM();
   unsigned num_lemmas = 0;
   bool isCollectModel = (m != nullptr);
@@ -276,7 +277,7 @@ unsigned HoExtension::applyAppCompletion(TNode n)
 {
   Assert(n.getKind() == APPLY_UF);
 
-  eq::EqualityEngine* ee = d_parent.getEqualityEngine();
+  eq::EqualityEngine* ee = d_state.getEqualityEngine();
   // must expand into APPLY_HO version if not there already
   Node ret = TheoryUfRewriter::getHoApplyForApplyUf(n);
   if (!ee->hasTerm(ret) || !ee->areEqual(ret, n))
@@ -297,7 +298,7 @@ unsigned HoExtension::checkAppCompletion()
   Trace("uf-ho") << "HoExtension::checkApplyCompletion..." << std::endl;
   // compute the operators that are relevant (those for which an HO_APPLY exist)
   std::set<TNode> rlvOp;
-  eq::EqualityEngine* ee = d_parent.getEqualityEngine();
+  eq::EqualityEngine* ee = d_state.getEqualityEngine();
   eq::EqClassesIterator eqcs_i = eq::EqClassesIterator(ee);
   std::map<TNode, std::vector<Node> > apply_uf;
   while (!eqcs_i.isFinished())
@@ -388,7 +389,7 @@ unsigned HoExtension::check()
   do
   {
     num_facts = checkAppCompletion();
-    if (d_parent.inConflict())
+    if (d_state.isInConflict())
     {
       Trace("uf-ho") << "...conflict during app-completion." << std::endl;
       return 1;
@@ -413,7 +414,8 @@ unsigned HoExtension::check()
   return 0;
 }
 
-bool HoExtension::collectModelInfoHo(std::set<Node>& termSet, TheoryModel* m)
+bool HoExtension::collectModelInfoHo(TheoryModel* m,
+                                     const std::set<Node>& termSet)
 {
   for (std::set<Node>::iterator it = termSet.begin(); it != termSet.end(); ++it)
   {
index d00372c9850413f588dc32935aba747a859176a7..dc37808c9c763f3b9a2bccf3c0e65d044113d3b7 100644 (file)
@@ -50,7 +50,7 @@ class HoExtension
   typedef context::CDHashMap<Node, Node, NodeHashFunction> NodeNodeMap;
 
  public:
-  HoExtension(TheoryUF& p, context::Context* c, context::UserContext* u);
+  HoExtension(TheoryUF& p, TheoryState& state);
 
   /** expand definition
    *
@@ -110,7 +110,7 @@ class HoExtension
    * values in m. It returns false if any (dis)equality added to m led to
    * an inconsistency in m.
    */
-  bool collectModelInfoHo(std::set<Node>& termSet, TheoryModel* m);
+  bool collectModelInfoHo(TheoryModel* m, const std::set<Node>& termSet);
 
  protected:
   /** get apply uf for ho apply
@@ -182,6 +182,8 @@ class HoExtension
   Node d_true;
   /** the parent of this extension */
   TheoryUF& d_parent;
+  /** Reference to the state object */
+  TheoryState& d_state;
   /** extensionality has been applied to these disequalities */
   NodeSet d_extensionality;
 
index 7d554c613fc4424fe922ee5fbba402879bba7664..f94cc36af8fabfe2da6200d7ba37a664b675b723 100644 (file)
@@ -54,7 +54,6 @@ TheoryUF::TheoryUF(context::Context* c,
        * so make sure it's initialized first. */
       d_thss(nullptr),
       d_ho(nullptr),
-      d_conflict(c, false),
       d_functionsTerms(c),
       d_symb(u, instanceName),
       d_state(c, u, valuation)
@@ -89,7 +88,7 @@ void TheoryUF::finishInit() {
   // Initialize the cardinality constraints solver if the logic includes UF,
   // finite model finding is enabled, and it is not disabled by
   // options::ufssMode().
-  if (getLogicInfo().isTheoryEnabled(THEORY_UF) && options::finiteModelFind()
+  if (options::finiteModelFind()
       && options::ufssMode() != options::UfssMode::NONE)
   {
     d_thss.reset(new CardinalityExtension(
@@ -100,7 +99,7 @@ void TheoryUF::finishInit() {
   if (options::ufHo())
   {
     d_equalityEngine->addFunctionKind(kind::HO_APPLY);
-    d_ho.reset(new HoExtension(*this, getSatContext(), getUserContext()));
+    d_ho.reset(new HoExtension(*this, d_state));
   }
 }
 
@@ -126,79 +125,87 @@ static Node mkAnd(const std::vector<TNode>& conjunctions) {
   return conjunction;
 }/* mkAnd() */
 
-void TheoryUF::check(Effort level) {
-  if (done() && !fullEffort(level)) {
+//--------------------------------- standard check
+
+void TheoryUF::postCheck(Effort level)
+{
+  if (d_state.isInConflict())
+  {
     return;
   }
-  getOutputChannel().spendResource(ResourceManager::Resource::TheoryCheckStep);
-  TimerStat::CodeTimer checkTimer(d_checkTime);
-
-  while (!done() && !d_conflict)
+  // check with the cardinality constraints extension
+  if (d_thss != nullptr)
   {
-    // Get all the assertions
-    Assertion assertion = get();
-    TNode fact = assertion.d_assertion;
-
-    Debug("uf") << "TheoryUF::check(): processing " << fact << std::endl;
-    Debug("uf") << "Term's theory: " << theory::Theory::theoryOf(fact.toExpr()) << std::endl;
-
-    if (d_thss != NULL) {
-      bool isDecision = d_valuation.isSatLiteral(fact) && d_valuation.isDecision(fact);
-      d_thss->assertNode(fact, isDecision);
-      if( d_thss->isConflict() ){
-        d_conflict = true;
-        return;
-      }
+    d_thss->check(level);
+    if (d_thss->isConflict())
+    {
+      d_state.notifyInConflict();
     }
+  }
+  // check with the higher-order extension
+  if (!d_state.isInConflict() && fullEffort(level))
+  {
+    if (options::ufHo())
+    {
+      d_ho->check();
+    }
+  }
+}
 
-    // Do the work
-    bool polarity = fact.getKind() != kind::NOT;
-    TNode atom = polarity ? fact : fact[0];
-    if (atom.getKind() == kind::EQUAL) {
-      d_equalityEngine->assertEquality(atom, polarity, fact);
-      if( options::ufHo() && options::ufHoExt() ){
-        if( !polarity && !d_conflict && atom[0].getType().isFunction() ){
-          // apply extensionality eagerly using the ho extension
-          d_ho->applyExtensionality(fact);
-        }
-      }
-    } else if (atom.getKind() == kind::CARDINALITY_CONSTRAINT || atom.getKind() == kind::COMBINED_CARDINALITY_CONSTRAINT) {
-      if( d_thss == NULL ){
-        if( !getLogicInfo().hasCardinalityConstraints() ){
-          std::stringstream ss;
-          ss << "Cardinality constraint " << atom << " was asserted, but the logic does not allow it." << std::endl;
-          ss << "Try using a logic containing \"UFC\"." << std::endl;
-          throw Exception( ss.str() );
-        }else{
-          // support for cardinality constraints is not enabled, set incomplete
-          d_out->setIncomplete();
-        }
+bool TheoryUF::preNotifyFact(
+    TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
+{
+  if (d_thss != nullptr)
+  {
+    bool isDecision =
+        d_valuation.isSatLiteral(fact) && d_valuation.isDecision(fact);
+    d_thss->assertNode(fact, isDecision);
+    if (d_thss->isConflict())
+    {
+      d_state.notifyInConflict();
+      return true;
+    }
+  }
+  if (atom.getKind() == kind::CARDINALITY_CONSTRAINT
+      || atom.getKind() == kind::COMBINED_CARDINALITY_CONSTRAINT)
+  {
+    if (d_thss == nullptr)
+    {
+      if (!getLogicInfo().hasCardinalityConstraints())
+      {
+        std::stringstream ss;
+        ss << "Cardinality constraint " << atom
+           << " was asserted, but the logic does not allow it." << std::endl;
+        ss << "Try using a logic containing \"UFC\"." << std::endl;
+        throw Exception(ss.str());
       }
-      //needed for models
-      if( options::produceModels() ){
-        d_equalityEngine->assertPredicate(atom, polarity, fact);
+      else
+      {
+        // support for cardinality constraints is not enabled, set incomplete
+        d_out->setIncomplete();
       }
-    } else {
-      d_equalityEngine->assertPredicate(atom, polarity, fact);
     }
+    // don't need to assert cardinality constraints if not producing models
+    return !options::produceModels();
   }
+  return false;
+}
 
-  if(! d_conflict ){
-    // check with the cardinality constraints extension
-    if (d_thss != NULL) {
-      d_thss->check(level);
-      if( d_thss->isConflict() ){
-        d_conflict = true;
-      }
-    }
-    // check with the higher-order extension
-    if(! d_conflict && fullEffort(level) ){
-      if( options::ufHo() ){
-        d_ho->check();
+void TheoryUF::notifyFact(TNode atom, bool pol, TNode fact, bool isInternal)
+{
+  if (!d_state.isInConflict() && atom.getKind() == kind::EQUAL)
+  {
+    if (options::ufHo() && options::ufHoExt())
+    {
+      if (!pol && !d_state.isInConflict() && atom[0].getType().isFunction())
+      {
+        // apply extensionality eagerly using the ho extension
+        d_ho->applyExtensionality(fact);
       }
     }
   }
-}/* TheoryUF::check() */
+}
+//--------------------------------- end standard check
 
 TrustNode TheoryUF::expandDefinition(Node node)
 {
@@ -221,7 +228,8 @@ TrustNode TheoryUF::expandDefinition(Node node)
   return TrustNode::null();
 }
 
-void TheoryUF::preRegisterTerm(TNode node) {
+void TheoryUF::preRegisterTerm(TNode node)
+{
   Debug("uf") << "TheoryUF::preRegisterTerm(" << node << ")" << std::endl;
 
   if (d_thss != NULL) {
@@ -259,14 +267,15 @@ void TheoryUF::preRegisterTerm(TNode node) {
     d_equalityEngine->addTerm(node);
     break;
   }
-}/* TheoryUF::preRegisterTerm() */
+}
 
 bool TheoryUF::propagateLit(TNode literal)
 {
   Debug("uf::propagate") << "TheoryUF::propagateLit(" << literal << ")"
                          << std::endl;
   // If already in conflict, no more propagation
-  if (d_conflict) {
+  if (d_state.isInConflict())
+  {
     Debug("uf::propagate") << "TheoryUF::propagateLit(" << literal
                            << "): already in conflict" << std::endl;
     return false;
@@ -274,7 +283,7 @@ bool TheoryUF::propagateLit(TNode literal)
   // Propagate out
   bool ok = d_out->propagate(literal);
   if (!ok) {
-    d_conflict = true;
+    d_state.notifyInConflict();
   }
   return ok;
 }/* TheoryUF::propagate(TNode) */
@@ -314,24 +323,12 @@ Node TheoryUF::explain(TNode literal, eq::EqProof* pf) {
   return mkAnd(assumptions);
 }
 
-bool TheoryUF::collectModelInfo(TheoryModel* m)
+bool TheoryUF::collectModelValues(TheoryModel* m, const std::set<Node>& termSet)
 {
-  Debug("uf") << "UF : collectModelInfo " << std::endl;
-  set<Node> termSet;
-
-  // Compute terms appearing in assertions and shared terms
-  computeRelevantTerms(termSet);
-
-  if (!m->assertEqualityEngine(d_equalityEngine, &termSet))
-  {
-    Trace("uf") << "Collect model info fail UF" << std::endl;
-    return false;
-  }
-
   if( options::ufHo() ){
     // must add extensionality disequalities for all pairs of (non-disequal)
     // function equivalence classes.
-    if (!d_ho->collectModelInfoHo(termSet, m))
+    if (!d_ho->collectModelInfoHo(m, termSet))
     {
       Trace("uf") << "Collect model info fail HO" << std::endl;
       return false;
@@ -503,12 +500,6 @@ EqualityStatus TheoryUF::getEqualityStatus(TNode a, TNode b) {
   return EQUALITY_FALSE_IN_MODEL;
 }
 
-void TheoryUF::notifySharedTerm(TNode t)
-{
-  Debug("uf::sharing") << "TheoryUF::addSharedTerm(" << t << ")" << std::endl;
-  d_equalityEngine->addTriggerTerm(t, THEORY_UF);
-}
-
 bool TheoryUF::areCareDisequal(TNode x, TNode y){
   Assert(d_equalityEngine->hasTerm(x));
   Assert(d_equalityEngine->hasTerm(y));
@@ -674,10 +665,10 @@ void TheoryUF::computeCareGraph() {
 void TheoryUF::conflict(TNode a, TNode b) {
   std::shared_ptr<eq::EqProof> pf =
       d_proofsEnabled ? std::make_shared<eq::EqProof>() : nullptr;
-  d_conflictNode = explain(a.eqNode(b), pf.get());
+  Node conf = explain(a.eqNode(b), pf.get());
   std::unique_ptr<ProofUF> puf(d_proofsEnabled ? new ProofUF(pf) : nullptr);
-  d_out->conflict(d_conflictNode, std::move(puf));
-  d_conflict = true;
+  d_out->conflict(conf, std::move(puf));
+  d_state.notifyInConflict();
 }
 
 void TheoryUF::eqNotifyNewClass(TNode t) {
index 916c6ef6bf01a341fdbd18130a790d7999e5d099..2bfd7e16c71f466b07738a659a7919467870146d 100644 (file)
@@ -103,13 +103,6 @@ private:
   /** the higher-order solver extension (or nullptr if it does not exist) */
   std::unique_ptr<HoExtension> d_ho;
 
-  /** Are we in conflict */
-  context::CDO<bool> d_conflict;
-
-  /** The conflict node */
-  Node d_conflictNode;
-
-
   /** node for true */
   Node d_true;
 
@@ -174,17 +167,31 @@ private:
   void finishInit() override;
   //--------------------------------- end initialization
 
-  void check(Effort) override;
+  //--------------------------------- standard check
+  /** Post-check, called after the fact queue of the theory is processed. */
+  void postCheck(Effort level) override;
+  /** Pre-notify fact, return true if processed. */
+  bool preNotifyFact(TNode atom,
+                     bool pol,
+                     TNode fact,
+                     bool isPrereg,
+                     bool isInternal) override;
+  /** Notify fact */
+  void notifyFact(TNode atom, bool pol, TNode fact, bool isInternal) override;
+  //--------------------------------- end standard check
+
+  /** Collect model values in m based on the relevant terms given by termSet */
+  bool collectModelValues(TheoryModel* m,
+                          const std::set<Node>& termSet) override;
+
   TrustNode expandDefinition(Node node) override;
   void preRegisterTerm(TNode term) override;
   TrustNode explain(TNode n) override;
 
-  bool collectModelInfo(TheoryModel* m) override;
 
   void ppStaticLearn(TNode in, NodeBuilder<>& learned) override;
   void presolve() override;
 
-  void notifySharedTerm(TNode n) override;
   void computeCareGraph() override;
 
   EqualityStatus getEqualityStatus(TNode a, TNode b) override;
@@ -193,8 +200,6 @@ private:
 
   /** get a pointer to the uf with cardinality */
   CardinalityExtension* getCardinalityExtension() const { return d_thss.get(); }
-  /** are we in conflict? */
-  bool inConflict() const { return d_conflict; }
 
  private:
   bool areCareDisequal(TNode x, TNode y);