(new theory) Update TheoryFP to the new interface (#4953)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 28 Aug 2020 18:01:55 +0000 (13:01 -0500)
committerGitHub <noreply@github.com>
Fri, 28 Aug 2020 18:01:55 +0000 (13:01 -0500)
This updates the theory of floating points to the new interface (see #4929).

Notice that TheoryFP was not adding trigger terms to its equality engine (which should be done during notifySharedTerm), and thus was not propagating equalities between shared terms in combined theories. This PR updates its notifySharedTerm method to the default one.

FYI @martin-cs

src/theory/fp/theory_fp.cpp
src/theory/fp/theory_fp.h

index 0c5a9257248e839f5ad5f5fc9913ac594b10d5e5..4c59b1c0602828f78d1697134657403b971e0e63 100644 (file)
@@ -110,7 +110,6 @@ TheoryFp::TheoryFp(context::Context* c,
       d_registeredTerms(u),
       d_conv(u),
       d_expansionRequested(false),
-      d_conflict(c, false),
       d_conflictNode(c, Node::null()),
       d_minMap(u),
       d_maxMap(u),
@@ -908,15 +907,6 @@ void TheoryFp::preRegisterTerm(TNode node)
   return;
 }
 
-void TheoryFp::notifySharedTerm(TNode node)
-{
-  Trace("fp-addSharedTerm")
-      << "TheoryFp::notifySharedTerm(): " << node << std::endl;
-  // A system-wide invariant; terms must be registered before they are shared
-  Assert(isRegistered(node));
-  return;
-}
-
 void TheoryFp::handleLemma(Node node) {
   Trace("fp") << "TheoryFp::handleLemma(): asserting " << node << std::endl;
   // Preprocess has to be true because it contains embedded ITEs
@@ -926,83 +916,41 @@ void TheoryFp::handleLemma(Node node) {
   return;
 }
 
-bool TheoryFp::handlePropagation(TNode node) {
-  Trace("fp") << "TheoryFp::handlePropagation(): propagate " << node
-              << std::endl;
+bool TheoryFp::propagateLit(TNode node)
+{
+  Trace("fp") << "TheoryFp::propagateLit(): propagate " << node << std::endl;
 
   bool stat = d_out->propagate(node);
 
   if (!stat)
   {
-    d_conflict = true;
+    d_state.notifyInConflict();
   }
   return stat;
 }
 
-void TheoryFp::handleConflict(TNode node) {
-  Trace("fp") << "TheoryFp::handleConflict(): conflict detected " << node
-              << std::endl;
+void TheoryFp::conflictEqConstantMerge(TNode t1, TNode t2)
+{
+  std::vector<TNode> assumptions;
+  d_equalityEngine->explainEquality(t1, t2, true, assumptions);
+
+  Node conflict = helper::buildConjunct(assumptions);
+  Trace("fp") << "TheoryFp::conflictEqConstantMerge(): conflict detected "
+              << conflict << std::endl;
 
-  d_conflictNode = node;
-  d_conflict = true;
-  d_out->conflict(node);
+  d_conflictNode = conflict;
+  d_state.notifyInConflict();
+  d_out->conflict(conflict);
   return;
 }
 
-void TheoryFp::check(Effort level) {
-  Trace("fp") << "TheoryFp::check(): started at effort level " << level
-              << std::endl;
-
-  while (!done() && !d_conflict) {
-    // Get all the assertions
-    Assertion assertion = get();
-    TNode fact = assertion.d_assertion;
-
-    Debug("fp") << "TheoryFp::check(): processing " << fact << std::endl;
-
-    // Only handle equalities; the rest should be handled by
-    // the bit-vector theory
-
-    bool negated = fact.getKind() == kind::NOT;
-    TNode predicate = negated ? fact[0] : fact;
-
-    if (predicate.getKind() == kind::EQUAL) {
-      Assert(!(predicate[0].getType().isFloatingPoint()
-               || predicate[0].getType().isRoundingMode())
-             || isRegistered(predicate[0]));
-      Assert(!(predicate[1].getType().isFloatingPoint()
-               || predicate[1].getType().isRoundingMode())
-             || isRegistered(predicate[1]));
-      registerTerm(predicate);  // Needed for float equalities
-
-      if (negated) {
-        Debug("fp-eq") << "TheoryFp::check(): adding dis-equality " << fact[0]
-                       << std::endl;
-        d_equalityEngine->assertEquality(predicate, false, fact);
-      } else {
-        Debug("fp-eq") << "TheoryFp::check(): adding equality " << fact
-                       << std::endl;
-        d_equalityEngine->assertEquality(predicate, true, fact);
-      }
-    } else {
-      // A system-wide invariant; predicates are registered before they are
-      // asserted
-      Assert(isRegistered(predicate));
-
-      if (d_equalityEngine->isFunctionKind(predicate.getKind()))
-      {
-        Debug("fp-eq") << "TheoryFp::check(): adding predicate " << predicate
-                       << " is " << !negated << std::endl;
-        d_equalityEngine->assertPredicate(predicate, !negated, fact);
-      }
-    }
-  }
-
+void TheoryFp::postCheck(Effort level)
+{
   // Resolve the abstractions for the conversion lemmas
   if (level == EFFORT_LAST_CALL)
   {
     Trace("fp") << "TheoryFp::check(): checking abstractions" << std::endl;
-    TheoryModel *m = getValuation().getModel();
+    TheoryModelm = getValuation().getModel();
     bool lemmaAdded = false;
 
     for (abstractionMapType::const_iterator i = abstractionMap.begin();
@@ -1017,11 +965,35 @@ void TheoryFp::check(Effort level) {
   }
 
   Trace("fp") << "TheoryFp::check(): completed" << std::endl;
-
   /* Checking should be handled by the bit-vector engine */
-  return;
+}
 
-} /* TheoryFp::check() */
+bool TheoryFp::preNotifyFact(
+    TNode atom, bool pol, TNode fact, bool isPrereg, bool isInternal)
+{
+  if (atom.getKind() == kind::EQUAL)
+  {
+    Assert(!(atom[0].getType().isFloatingPoint()
+             || atom[0].getType().isRoundingMode())
+           || isRegistered(atom[0]));
+    Assert(!(atom[1].getType().isFloatingPoint()
+             || atom[1].getType().isRoundingMode())
+           || isRegistered(atom[1]));
+    registerTerm(atom);  // Needed for float equalities
+  }
+  else
+  {
+    // A system-wide invariant; predicates are registered before they are
+    // asserted
+    Assert(isRegistered(atom));
+
+    if (!d_equalityEngine->isFunctionKind(atom.getKind()))
+    {
+      return true;
+    }
+  }
+  return false;
+}
 
 TrustNode TheoryFp::explain(TNode n)
 {
@@ -1047,16 +1019,20 @@ Node TheoryFp::getModelValue(TNode var) {
   return d_conv.getValue(d_valuation, var);
 }
 
-bool TheoryFp::collectModelInfo(TheoryModel *m)
+bool TheoryFp::collectModelInfo(TheoryModelm)
 {
   std::set<Node> relevantTerms;
-
-  Trace("fp-collectModelInfo")
-      << "TheoryFp::collectModelInfo(): begin" << std::endl;
-
   // Work out which variables are needed
   computeRelevantTerms(relevantTerms);
+  // this override behavior to not assert equality engine
+  return collectModelValues(m, relevantTerms);
+}
 
+bool TheoryFp::collectModelValues(TheoryModel* m,
+                                  const std::set<Node>& relevantTerms)
+{
+  Trace("fp-collectModelInfo")
+      << "TheoryFp::collectModelInfo(): begin" << std::endl;
   if (Trace.isOn("fp-collectModelInfo")) {
     for (std::set<Node>::const_iterator i(relevantTerms.begin());
          i != relevantTerms.end(); ++i) {
@@ -1153,9 +1129,9 @@ bool TheoryFp::NotifyClass::eqNotifyTriggerPredicate(TNode predicate,
       << predicate << " is " << value << std::endl;
 
   if (value) {
-    return d_theorySolver.handlePropagation(predicate);
+    return d_theorySolver.propagateLit(predicate);
   }
-  return d_theorySolver.handlePropagation(predicate.notNode());
+  return d_theorySolver.propagateLit(predicate.notNode());
 }
 
 bool TheoryFp::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1,
@@ -1164,22 +1140,15 @@ bool TheoryFp::NotifyClass::eqNotifyTriggerTermEquality(TheoryId tag, TNode t1,
                  << t1 << (value ? " = " : " != ") << t2 << std::endl;
 
   if (value) {
-    return d_theorySolver.handlePropagation(t1.eqNode(t2));
-  } else {
-    return d_theorySolver.handlePropagation(t1.eqNode(t2).notNode());
+    return d_theorySolver.propagateLit(t1.eqNode(t2));
   }
+  return d_theorySolver.propagateLit(t1.eqNode(t2).notNode());
 }
 
 void TheoryFp::NotifyClass::eqNotifyConstantTermMerge(TNode t1, TNode t2) {
   Debug("fp-eq") << "TheoryFp::eqNotifyConstantTermMerge(): call back as " << t1
                  << " = " << t2 << std::endl;
-
-  std::vector<TNode> assumptions;
-  d_theorySolver.d_equalityEngine->explainEquality(t1, t2, true, assumptions);
-
-  Node conflict = helper::buildConjunct(assumptions);
-
-  d_theorySolver.handleConflict(conflict);
+  d_theorySolver.conflictEqConstantMerge(t1, t2);
 }
 
 }  // namespace fp
index 79ece7bce5ecf31a1f2da833e176d973d6b0c595..2ef3b3f3547009493a5411d80719537f6fe5d649 100644 (file)
@@ -58,15 +58,28 @@ class TheoryFp : public Theory {
   TrustNode expandDefinition(Node node) override;
 
   void preRegisterTerm(TNode node) override;
-  void notifySharedTerm(TNode node) override;
 
   TrustNode ppRewrite(TNode node) override;
 
-  void check(Effort) override;
-
+  //--------------------------------- standard check
+  /** Do we need a check call at last call effort? */
   bool needsCheckLastEffort() override { return true; }
+  /** Post-check, called after the fact queue of the theory is processed. */
+  void postCheck(Effort level) override;
+  /** Pre-notify fact, return true if processed. */
+  bool preNotifyFact(TNode atom,
+                     bool pol,
+                     TNode fact,
+                     bool isPrereg,
+                     bool isInternal) override;
+  //--------------------------------- end standard check
+
   Node getModelValue(TNode var) override;
   bool collectModelInfo(TheoryModel* m) override;
+  /** Collect model values in m based on the relevant terms given by
+   * relevantTerms */
+  bool collectModelValues(TheoryModel* m,
+                          const std::set<Node>& relevantTerms) override;
 
   std::string identify() const override { return "THEORY_FP"; }
 
@@ -108,10 +121,17 @@ class TheoryFp : public Theory {
 
   /** Interaction with the rest of the solver **/
   void handleLemma(Node node);
-  bool handlePropagation(TNode node);
-  void handleConflict(TNode node);
+  /**
+   * Called when literal node is inferred by the equality engine. This
+   * propagates node on the output channel.
+   */
+  bool propagateLit(TNode node);
+  /**
+   * Called when two constants t1 and t2 merge in the equality engine. This
+   * sends a conflict on the output channel.
+   */
+  void conflictEqConstantMerge(TNode t1, TNode t2);
 
-  context::CDO<bool> d_conflict;
   context::CDO<Node> d_conflictNode;
 
   typedef context::CDHashMap<TypeNode, Node, TypeNodeHashFunction>