(proof-new) Improving proof-production in Equality Engine (#4871)
authorHaniel Barbosa <hanielbbarbosa@gmail.com>
Wed, 12 Aug 2020 14:31:30 +0000 (11:31 -0300)
committerGitHub <noreply@github.com>
Wed, 12 Aug 2020 14:31:30 +0000 (09:31 -0500)
This commit improves functionalities of the equality engine so that it is easier to produce proofs for its reasoning. They are:

avoiding assertion of already entailed predicates/equalities.
better EqProof of disequalities with constants
correct EqProof involving n-ary congruence kinds

src/theory/uf/equality_engine.cpp
src/theory/uf/equality_engine.h

index 3abc565535aa0dfdcbe7081bf3e50309a1c35118..172b2407c172d5747aae7c0aee990f9acba11a6c 100644 (file)
@@ -456,19 +456,33 @@ void EqualityEngine::assertEqualityInternal(TNode t1, TNode t2, TNode reason, un
   enqueue(MergeCandidate(t1Id, t2Id, pid, reason));
 }
 
-void EqualityEngine::assertPredicate(TNode t, bool polarity, TNode reason, unsigned pid) {
+bool EqualityEngine::assertPredicate(TNode t,
+                                     bool polarity,
+                                     TNode reason,
+                                     unsigned pid)
+{
   Debug("equality") << d_name << "::eq::addPredicate(" << t << "," << (polarity ? "true" : "false") << ")" << std::endl;
   Assert(t.getKind() != kind::EQUAL) << "Use assertEquality instead";
-  assertEqualityInternal(t, polarity ? d_true : d_false, reason, pid);
+  TNode b = polarity ? d_true : d_false;
+  if (hasTerm(t) && areEqual(t, b))
+  {
+    return false;
+  }
+  assertEqualityInternal(t, b, reason, pid);
   propagate();
+  return true;
 }
 
-void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason, unsigned pid) {
+bool EqualityEngine::assertEquality(TNode eq,
+                                    bool polarity,
+                                    TNode reason,
+                                    unsigned pid)
+{
   Debug("equality") << d_name << "::eq::addEquality(" << eq << "," << (polarity ? "true" : "false") << ")" << std::endl;
   if (polarity) {
     // If two terms are already equal, don't assert anything
     if (hasTerm(eq[0]) && hasTerm(eq[1]) && areEqual(eq[0], eq[1])) {
-      return;
+      return false;
     }
     // Add equality between terms
     assertEqualityInternal(eq[0], eq[1], reason, pid);
@@ -476,7 +490,7 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason, unsig
   } else {
     // If two terms are already dis-equal, don't assert anything
     if (hasTerm(eq[0]) && hasTerm(eq[1]) && areDisequal(eq[0], eq[1], false)) {
-      return;
+      return false;
     }
 
     // notify the theory
@@ -490,7 +504,7 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason, unsig
     propagate();
 
     if (d_done) {
-      return;
+      return true;
     }
 
     // If both have constant representatives, we don't notify anyone
@@ -499,7 +513,7 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason, unsig
     EqualityNodeId aClassId = getEqualityNode(a).getFind();
     EqualityNodeId bClassId = getEqualityNode(b).getFind();
     if (d_isConstant[aClassId] && d_isConstant[bClassId]) {
-      return;
+      return true;
     }
 
     // If we are adding a disequality, notify of the shared term representatives
@@ -551,6 +565,7 @@ void EqualityEngine::assertEquality(TNode eq, bool polarity, TNode reason, unsig
       }
     }
   }
+  return true;
 }
 
 TNode EqualityEngine::getRepresentative(TNode t) const {
@@ -980,6 +995,77 @@ std::string EqualityEngine::edgesToString(EqualityEdgeId edgeId) const {
   return out.str();
 }
 
+void EqualityEngine::buildEqConclusion(EqualityNodeId id1,
+                                       EqualityNodeId id2,
+                                       EqProof* eqp) const
+{
+  Kind k1 = d_nodes[id1].getKind();
+  Kind k2 = d_nodes[id2].getKind();
+  // only try to build if ids do not correspond to internal nodes. If they do,
+  // only try to build build if full applications corresponding to the given ids
+  // have the same congruence n-ary non-APPLY_UF kind, since the internal nodes
+  // may be full nodes.
+  if ((d_isInternal[id1] || d_isInternal[id2])
+      && (k1 != k2 || k1 == kind::APPLY_UF || !ExprManager::isNAryKind(k1)))
+  {
+    return;
+  }
+  Node eq[2];
+  NodeManager* nm = NodeManager::currentNM();
+  for (unsigned i = 0; i < 2; ++i)
+  {
+    EqualityNodeId equalityNodeId = i == 0 ? id1 : id2;
+    Node equalityNode = d_nodes[equalityNodeId];
+    // if not an internal node, just retrieve it
+    if (!d_isInternal[equalityNodeId])
+    {
+      eq[i] = equalityNode;
+      continue;
+    }
+    // build node relative to partial application of this
+    // n-ary kind. We get the full application, then we get
+    // the arguments relative to how partial the internal
+    // node is, and build the application
+
+    // get number of children of partial app:
+    // #children of full app - (id of full app - id of
+    // partial app)
+    EqualityNodeId fullAppId = getNodeId(equalityNode);
+    EqualityNodeId curr = fullAppId;
+    unsigned separation = 0;
+    Assert(fullAppId >= equalityNodeId);
+    while (curr != equalityNodeId)
+    {
+      separation = separation + (d_nodes[curr--] == equalityNode ? 1 : 0);
+    }
+    // compute separation, which is how many ids with the
+    // same fullappnode exist between equalityNodeId and
+    // fullAppId
+    unsigned numChildren = equalityNode.getNumChildren() - separation;
+    Assert(numChildren < equalityNode.getNumChildren())
+        << "broke for numChildren " << numChildren << ", fullAppId "
+        << fullAppId << ", equalityNodeId " << equalityNodeId << ", node "
+        << equalityNode << ", cong: {" << id1 << "} " << d_nodes[id1] << " = {"
+        << id2 << "} " << d_nodes[id2] << "\n";
+    // if has at least as many children as the minimal
+    // number of children of the n-ary kind, build the node
+    if (numChildren >= ExprManager::minArity(k1))
+    {
+      std::vector<Node> children;
+      for (unsigned j = 0; j < numChildren; ++j)
+      {
+        children.push_back(equalityNode[j]);
+      }
+      eq[i] = nm->mkNode(k1, children);
+    }
+  }
+  // if built equality, add it as eqp's conclusion
+  if (!eq[0].isNull() && !eq[1].isNull())
+  {
+    eqp->d_node = eq[0].eqNode(eq[1]);
+  }
+}
+
 void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
                                      std::vector<TNode>& equalities,
                                      EqProof* eqp) const {
@@ -1088,10 +1174,48 @@ void EqualityEngine::explainEquality(TNode t1, TNode t2, bool polarity,
         Assert(eqp->d_node[0][1].isConst());
         eqp->d_id = MERGED_THROUGH_CONSTANTS;
       } else if (eqp->d_children.size() == 1) {
-        // The transitivity proof has just one child. Simplify.
-        std::shared_ptr<EqProof> temp = eqp->d_children[0];
-        eqp->d_children.clear();
-        *eqp = *temp;
+        Node cnode = eqp->d_children[0]->d_node;
+        Debug("pf::ee") << "Simplifying " << cnode << " from " << eqp->d_node
+                        << std::endl;
+        bool simpTrans = true;
+        if (cnode.getKind() == kind::EQUAL)
+        {
+          // It may be the case that we have a proof of x = c2 and we want to
+          // conclude x != c1. If this is the case, below we construct:
+          //
+          //          -------- MERGED_THROUGH_EQUALITY
+          // x = c2   c1 != c2
+          // ----------------- TRANS
+          //     x != c1
+          TNode c1 = t1.isConst() ? t1 : (t2.isConst() ? t2 : TNode::null());
+          TNode nc = t1.isConst() ? t2 : (t2.isConst() ? t1 : TNode::null());
+          Node c2;
+          // merge constants transitivity
+          for (unsigned i = 0; i < 2; i++)
+          {
+            if (cnode[i].isConst() && cnode[1 - i] == nc)
+            {
+              c2 = cnode[i];
+              break;
+            }
+          }
+          if (!c1.isNull() && !c2.isNull())
+          {
+            simpTrans = false;
+            Assert(c1.getType().isComparableTo(c2.getType()));
+            std::shared_ptr<EqProof> eqpmc = std::make_shared<EqProof>();
+            eqpmc->d_id = MERGED_THROUGH_CONSTANTS;
+            eqpmc->d_node = c1.eqNode(c2).eqNode(d_false);
+            eqp->d_children.push_back(eqpmc);
+          }
+        }
+        if (simpTrans)
+        {
+          // The transitivity proof has just one child. Simplify.
+          std::shared_ptr<EqProof> temp = eqp->d_children[0];
+          eqp->d_children.clear();
+          *eqp = *temp;
+        }
       }
 
       if (Debug.isOn("pf::ee"))
@@ -1168,7 +1292,7 @@ void EqualityEngine::getExplanation(
         // We may have cached null in its place, create the trivial proof now.
         Assert(d_nodes[t1Id] == d_nodes[t2Id]);
         Assert(eqp->d_id == MERGED_THROUGH_REFLEXIVITY);
-        eqp->d_node = d_nodes[t1Id];
+        eqp->d_node = d_nodes[t1Id].eqNode(d_nodes[t1Id]);
       }
       return;
     }
@@ -1191,10 +1315,32 @@ void EqualityEngine::getExplanation(
   // If the nodes are the same, we're done
   if (t1Id == t2Id){
     if( eqp ) {
-      if ((d_nodes[t1Id].getKind() == kind::BUILTIN) && (d_nodes[t1Id].getConst<Kind>() == kind::SELECT)) {
+      if (options::proofNew())
+      {
+        // ignore equalities between function symbols, i.e. internal nullary
+        // non-constant nodes.
+        //
+        // Note that this is robust for HOL because in that case function
+        // symbols are not internal nodes
+        if (d_isInternal[t1Id] && d_nodes[t1Id].getNumChildren() == 0
+            && !d_isConstant[t1Id])
+        {
+          eqp->d_node = Node::null();
+        }
+        else
+        {
+          Assert(d_nodes[t1Id].getKind() != kind::BUILTIN);
+          eqp->d_node = d_nodes[t1Id].eqNode(d_nodes[t1Id]);
+        }
+      }
+      else if ((d_nodes[t1Id].getKind() == kind::BUILTIN)
+               && (d_nodes[t1Id].getConst<Kind>() == kind::SELECT))
+      {
         std::vector<Node> no_children;
         eqp->d_node = NodeManager::currentNM()->mkNode(kind::PARTIAL_SELECT_0, no_children);
-      } else {
+      }
+      else
+      {
         eqp->d_node = ProofManager::currentPM()->mkOp(d_nodes[t1Id]);
       }
     }
@@ -1299,13 +1445,36 @@ void EqualityEngine::getExplanation(
               std::shared_ptr<EqProof> eqpc2 =
                   eqpc ? std::make_shared<EqProof>() : nullptr;
               getExplanation(f1.d_b, f2.d_b, equalities, cache, eqpc2.get());
-              if( eqpc ){
-                eqpc->d_children.push_back( eqpc1 );
-                eqpc->d_children.push_back( eqpc2 );
-                if( d_nodes[currentNode].getKind()==kind::EQUAL ){
+              if (eqpc)
+              {
+                eqpc->d_children.push_back(eqpc1);
+                eqpc->d_children.push_back(eqpc2);
+                if (options::proofNew())
+                {
+                  // build conclusion if ids correspond to non-internal nodes or
+                  // if non-internal nodes can be retrieved from them (in the
+                  // case of n-ary applications), otherwise leave conclusion as
+                  // null. This is only done for congruence kinds, since
+                  // congruence is not used otherwise.
+                  Kind k = d_nodes[currentNode].getKind();
+                  if (d_congruenceKinds[k])
+                  {
+                    buildEqConclusion(currentNode, edgeNode, eqpc.get());
+                  }
+                  else
+                  {
+                    Assert(k == kind::EQUAL)
+                        << "not an internal node " << d_nodes[currentNode]
+                        << " with non-congruence with " << k << "\n";
+                  }
+                }
+                else if (d_nodes[currentNode].getKind() == kind::EQUAL)
+                {
                   //leave node null for now
                   eqpc->d_node = Node::null();
-                } else {
+                }
+                else
+                {
                   if (d_nodes[f1.d_a].getKind() == kind::APPLY_UF
                       || d_nodes[f1.d_a].getKind() == kind::SELECT
                       || d_nodes[f1.d_a].getKind() == kind::STORE)
@@ -1369,9 +1538,26 @@ void EqualityEngine::getExplanation(
               Debug("equality") << push;
 
               // Get the node we interpreted
-              TNode interpreted = d_nodes[currentNode];
-              if (interpreted.isConst()) {
-                interpreted = d_nodes[edgeNode];
+              TNode interpreted;
+              if (eqpc && options::proofNew())
+              {
+                // build the conclusion f(c1, ..., cn) = c
+                if (d_nodes[currentNode].isConst())
+                {
+                  interpreted = d_nodes[edgeNode];
+                  eqpc->d_node = d_nodes[edgeNode].eqNode(d_nodes[currentNode]);
+                }
+                else
+                {
+                  interpreted = d_nodes[currentNode];
+                  eqpc->d_node = d_nodes[currentNode].eqNode(d_nodes[edgeNode]);
+                }
+              }
+              else
+              {
+                interpreted = d_nodes[currentNode].isConst()
+                                  ? d_nodes[edgeNode]
+                                  : d_nodes[currentNode];
               }
 
               // Explain why a is a constant by explaining each argument
@@ -1419,7 +1605,11 @@ void EqualityEngine::getExplanation(
                                        eqpc.get());
                 }
                 if (reasonType == MERGED_THROUGH_EQUALITY) {
-                  eqpc->d_node = reason;
+                  // in the new proof infrastructure we can assume that "theory
+                  // assumptions", which are a consequence of theory reasoning
+                  // on other assumptions, are externally justified. In this
+                  // case we can use (= a b) directly as the conclusion here.
+                  eqpc->d_node = !options::proofNew() ? reason : b.eqNode(a);
                 } else {
                   // The LFSC translator prefers (not (= a b)) over (= (= a b) false)
 
@@ -1463,7 +1653,20 @@ void EqualityEngine::getExplanation(
             } else {
               eqp->d_id = MERGED_THROUGH_TRANS;
               eqp->d_children.insert( eqp->d_children.end(), eqp_trans.begin(), eqp_trans.end() );
-              eqp->d_node = NodeManager::currentNM()->mkNode(kind::EQUAL, d_nodes[t1Id], d_nodes[t2Id]);
+              if (options::proofNew())
+              {
+                // build conclusion in case of equality between non-internal
+                // nodes or of n-ary congruence kinds, otherwise leave as
+                // null. The latter is necessary for the overall handling of
+                // congruence proofs involving n-ary kinds, see
+                // EqProof::reduceNestedCongruence for more details.
+                buildEqConclusion(t1Id, t2Id, eqp);
+              }
+              else
+              {
+                eqp->d_node = NodeManager::currentNM()->mkNode(
+                    kind::EQUAL, d_nodes[t1Id], d_nodes[t2Id]);
+              }
             }
             if (Debug.isOn("pf::ee"))
             {
@@ -1954,6 +2157,8 @@ unsigned EqualityEngine::getFreshMergeReasonType() {
   return d_freshMergeReasonType++;
 }
 
+std::string EqualityEngine::identify() const { return d_name; }
+
 void EqualityEngine::addTriggerTerm(TNode t, TheoryId tag)
 {
   Debug("equality::trigger") << d_name << "::eq::addTriggerTerm(" << t << ", " << tag << ")" << std::endl;
index 42ae3437d0c6c656c1439e1d43c35a4fca9bb039..9d1fc6165d196107a0739b6ced4f0fd3739b6e37 100644 (file)
@@ -415,6 +415,23 @@ private:
   /** Are we in propagate */
   bool d_inPropagate;
 
+  /** Proof-new specific construction of equality conclusions for EqProofs
+   *
+   * Given two equality node ids, build an equality between the nodes they
+   * correspond to and add it as a conclusion to the given EqProof.
+   *
+   * The equality is only built if the nodes the ids correspond to are not
+   * internal nodes in the equality engine, i.e., they correspond to full
+   * applications of the respective kinds. Since the equality engine also
+   * applies congruence over n-ary kinds, internal nodes, i.e., partial
+   * applications, may still correspond to "full applications" in the
+   * first-order sense. Therefore this method also checks, in the case of n-ary
+   * congruence kinds, if an equality between "full applications" can be built.
+   */
+  void buildEqConclusion(EqualityNodeId id1,
+                         EqualityNodeId id2,
+                         EqProof* eqp) const;
+
   /**
    * Get an explanation of the equality t1 = t2. Returns the asserted equalities
    * that imply t1 = t2. Returns TNodes as the assertion equalities should be
@@ -695,8 +712,12 @@ public:
    * @param polarity true if asserting the predicate, false if
    *                 asserting the negated predicate
    * @param reason the reason to keep for building explanations
+   * @return true if a new fact was asserted, false if this call was a no-op.
    */
-  void assertPredicate(TNode p, bool polarity, TNode reason, unsigned pid = MERGED_THROUGH_EQUALITY);
+  bool assertPredicate(TNode p,
+                       bool polarity,
+                       TNode reason,
+                       unsigned pid = MERGED_THROUGH_EQUALITY);
 
   /**
    * Adds an equality eq with the given polarity to the database.
@@ -705,8 +726,12 @@ public:
    * @param polarity true if asserting the equality, false if
    *                 asserting the negated equality
    * @param reason the reason to keep for building explanations
+   * @return true if a new fact was asserted, false if this call was a no-op.
    */
-  void assertEquality(TNode eq, bool polarity, TNode reason, unsigned pid = MERGED_THROUGH_EQUALITY);
+  bool assertEquality(TNode eq,
+                      bool polarity,
+                      TNode reason,
+                      unsigned pid = MERGED_THROUGH_EQUALITY);
 
   /**
    * Returns the current representative of the term t.
@@ -807,6 +832,9 @@ public:
    * Returns a fresh merge reason type tag for the client to use.
    */
   unsigned getFreshMergeReasonType();
+
+  /** Identify this equality engine (for debugging, etc..) */
+  std::string identify() const;
 };
 
 } // Namespace eq