[proof-new] Adding MACRO_RESOLUTION rule and updating proof checker (#5612)
authorHaniel Barbosa <hanielbbarbosa@gmail.com>
Tue, 8 Dec 2020 03:54:11 +0000 (00:54 -0300)
committerGitHub <noreply@github.com>
Tue, 8 Dec 2020 03:54:11 +0000 (21:54 -0600)
Previously CHAIN_RESOLUTION's definition and checker were not properly capturing its intended behavior as merely an n-ary RESOLUTION rule (i.e., no factoring nor reordering). A new rule, MACRO_RESOLUTION, now captures this behavior: it combines CHAIN_RESOLUTION, FACTORING, and REORDERING.

This commit also adds a proof checker for the new rule and updates the proof checker of CHAIN_RESOLUTION.

src/expr/proof_rule.h
src/theory/booleans/proof_checker.cpp

index ada10cb9f253aea81f61bb5e2ecf7087d349e237..9754735e6fe2a0063ac0e11aaf7d9e0417a8c46e 100644 (file)
@@ -251,16 +251,17 @@ enum class PfRule : uint32_t
   // ======== Resolution
   // Children:
   //  (P1:C1, P2:C2)
-  // Arguments: (id, L)
+  // Arguments: (pol, L)
   // ---------------------
   // Conclusion: C
   // where
   //   - C1 and C2 are nodes viewed as clauses, i.e., either an OR node with
   //     each children viewed as a literal or a node viewed as a literal. Note
   //     that an OR node could also be a literal.
-  //   - id is either true or false
+  //   - pol is either true or false, representing the polarity of the pivot on
+  //     the first clause
   //   - L is the pivot of the resolution, which occurs as is (resp. under a
-  //     NOT) in C1 and negatively (as is) in C2 if id = true (id = false).
+  //     NOT) in C1 and negatively (as is) in C2 if pol = true (pol = false).
   //   C is a clause resulting from collecting all the literals in C1, minus the
   //   first occurrence of the pivot or its negation, and C2, minus the first
   //   occurrence of the pivot or its negation, according to the policy above.
@@ -273,15 +274,18 @@ enum class PfRule : uint32_t
   //   to resolution but rather to a weakening of the clause that did not have a
   //   literal eliminated.
   RESOLUTION,
-  // ======== Chain Resolution
-  // Children: (P1:(or F_{1,1} ... F_{1,n1}), ..., Pm:(or F_{m,1} ... F_{m,nm}))
-  // Arguments: (L_1, ..., L_{m-1})
+  // ======== N-ary Resolution
+  // Children: (P1:C_1, ..., Pm:C_n)
+  // Arguments: (pol_1, L_1, ..., pol_{n-1}, L_{n-1})
   // ---------------------
-  // Conclusion: C_m'
+  // Conclusion: C
   // where
-  //   let "C_1 <>_l C_2" represent the resolution of C_1 with C_2 with pivot l,
-  //   let C_1' = C_1 (from P_1),
-  //   for each i > 1, C_i' = C_i <>_L_i C_{i-1}'
+  //   - let C_1 ... C_n be nodes viewed as clauses, as defined above
+  //   - let "C_1 <>_{L,pol} C_2" represent the resolution of C_1 with C_2 with
+  //     pivot L and polarity pol, as defined above
+  //   - let C_1' = C_1 (from P1),
+  //   - for each i > 1, let C_i' = C_{i-1} <>_{L_{i-1}, pol_{i-1}} C_i'
+  //   The result of the chain resolution is C = C_n'
   CHAIN_RESOLUTION,
   // ======== Factoring
   // Children: (P:C1)
@@ -301,6 +305,21 @@ enum class PfRule : uint32_t
   //  Set representations of C1 and C2 is the same but the number of literals in
   //  C2 is the same of that of C1
   REORDERING,
+  // ======== N-ary Resolution + Factoring + Reordering
+  // Children: (P1:C_1, ..., Pm:C_n)
+  // Arguments: (C, pol_1, L_1, ..., pol_{n-1}, L_{n-1})
+  // ---------------------
+  // Conclusion: C
+  // where
+  //   - let C_1 ... C_n be nodes viewed as clauses, as defined in RESOLUTION
+  //   - let "C_1 <>_{L,pol} C_2" represent the resolution of C_1 with C_2 with
+  //     pivot L and polarity pol, as defined in RESOLUTION
+  //   - let C_1' be equal, in its set representation, to C_1 (from P1),
+  //   - for each i > 1, let C_i' be equal, it its set representation, to
+  //     C_{i-1} <>_{L_{i-1}, pol_{i-1}} C_i'
+  //   The result of the chain resolution is C, which is equal, in its set
+  //   representation, to C_n'
+  MACRO_RESOLUTION,
 
   // ======== Split
   // Children: none
index 3eb523ab5c670f1dd119c811dee97856b55b301c..2cbf6a2e8256a9c11a97999a16288c66878229a6 100644 (file)
@@ -25,6 +25,7 @@ void BoolProofRuleChecker::registerTo(ProofChecker* pc)
   pc->registerChecker(PfRule::SPLIT, this);
   pc->registerChecker(PfRule::RESOLUTION, this);
   pc->registerChecker(PfRule::CHAIN_RESOLUTION, this);
+  pc->registerChecker(PfRule::MACRO_RESOLUTION, this);
   pc->registerChecker(PfRule::FACTORING, this);
   pc->registerChecker(PfRule::REORDERING, this);
   pc->registerChecker(PfRule::EQ_RESOLVE, this);
@@ -97,9 +98,9 @@ Node BoolProofRuleChecker::checkInternal(PfRule id,
     }
     for (unsigned i = 0; i < 2; ++i)
     {
-      // determine whether the clause is unit for effects of resolution, which
-      // is the case if it's not an OR node or it is an OR node but it is equal
-      // to the pivot
+      // determine whether the clause is a singleton for effects of resolution,
+      // which is the case if it's not an OR node or it is an OR node but it is
+      // equal to the pivot
       std::vector<Node> lits;
       if (children[i].getKind() == kind::OR && pivots[i] != children[i])
       {
@@ -190,29 +191,167 @@ Node BoolProofRuleChecker::checkInternal(PfRule id,
   if (id == PfRule::CHAIN_RESOLUTION)
   {
     Assert(children.size() > 1);
-    Assert(args.size() == children.size() - 1);
+    Assert(args.size() == 2 * (children.size() - 1));
     Trace("bool-pfcheck") << "chain_res:\n" << push;
+    NodeManager* nm = NodeManager::currentNM();
+    Node trueNode = nm->mkConst(true);
+    Node falseNode = nm->mkConst(false);
     std::vector<Node> clauseNodes;
-    for (unsigned i = 0, childrenSize = children.size(); i < childrenSize; ++i)
+    // literals to be removed from the virtual lhs clause of the resolution
+    std::unordered_map<Node, unsigned, NodeHashFunction> lhsElim;
+    for (std::size_t i = 0, argsSize = args.size(); i < argsSize; i = i + 2)
+    {
+      // whether pivot should occur as is or negated depends on the polarity of
+      // each step in the chain
+      if (args[i] == trueNode)
+      {
+        lhsElim[args[i + 1]]++;
+      }
+      else
+      {
+        Assert(args[i] == falseNode);
+        lhsElim[args[i + 1].notNode()]++;
+      }
+    }
+    if (Trace.isOn("bool-pfcheck"))
+    {
+      Trace("bool-pfcheck")
+          << "Original elimination multiset for lhs clause:\n";
+      for (const auto& pair : lhsElim)
+      {
+        Trace("bool-pfcheck")
+            << "\t- " << pair.first << " {" << pair.second << "}\n";
+      }
+    }
+    for (std::size_t i = 0, childrenSize = children.size(); i < childrenSize;
+         ++i)
+    {
+      // literal to be removed from rhs clause. They will be negated
+      Node rhsElim = Node::null();
+      if (Trace.isOn("bool-pfcheck"))
+      {
+        Trace("bool-pfcheck") << i << ": current lhsElim:\n";
+        for (const auto& pair : lhsElim)
+        {
+          Trace("bool-pfcheck")
+              << "\t- " << pair.first << " {" << pair.second << "}\n";
+        }
+      }
+      if (i > 0)
+      {
+        std::size_t index = 2 * (i - 1);
+        rhsElim = args[index] == trueNode ? args[index + 1].notNode()
+                                          : args[index + 1];
+        Trace("bool-pfcheck") << i << ": rhs elim: " << rhsElim << "\n";
+      }
+      // Only add to conclusion nodes that are not in elimination set. First get
+      // the nodes.
+      //
+      // Since a Node cannot hold an OR with a single child we need to
+      // disambiguate singleton clauses that are OR nodes from non-singleton
+      // clauses (i.e. unit clauses in the SAT solver).
+      //
+      // If the child is not an OR, it is a singleton clause and we take the
+      // child itself as the clause. Otherwise the child can only be a singleton
+      // clause if the child itself is used as a resolution literal, i.e. if the
+      // child is in lhsElim or is equal to rhsElim (which means that the
+      // negation of the child is in lhsElim).
+      std::vector<Node> lits;
+      if (children[i].getKind() == kind::OR && lhsElim.count(children[i]) == 0
+          && children[i] != rhsElim)
+      {
+        lits.insert(lits.end(), children[i].begin(), children[i].end());
+      }
+      else
+      {
+        lits.push_back(children[i]);
+      }
+      Trace("bool-pfcheck") << i << ": clause lits: " << lits << "\n";
+      std::vector<Node> added;
+      for (std::size_t j = 0, size = lits.size(); j < size; ++j)
+      {
+        if (lits[j] == rhsElim)
+        {
+          rhsElim == Node::null();
+          continue;
+        }
+        auto it = lhsElim.find(lits[j]);
+        if (it == lhsElim.end())
+        {
+          clauseNodes.push_back(lits[j]);
+          added.push_back(lits[j]);
+        }
+        else
+        {
+          // remove occurrence
+          it->second--;
+          if (it->second == 0)
+          {
+            lhsElim.erase(it);
+          }
+        }
+      }
+      Trace("bool-pfcheck") << i << ": added lits: " << added << "\n\n";
+    }
+    Trace("bool-pfcheck") << "clause: " << clauseNodes << "\n" << pop;
+    return clauseNodes.empty()
+               ? nm->mkConst(false)
+               : clauseNodes.size() == 1 ? clauseNodes[0]
+                                         : nm->mkNode(kind::OR, clauseNodes);
+  }
+  if (id == PfRule::MACRO_RESOLUTION)
+  {
+    Assert(children.size() > 1);
+    Assert(args.size() == 2 * (children.size() - 1) + 1);
+    Trace("bool-pfcheck") << "macro_res: " << args[0] << "\n" << push;
+    NodeManager* nm = NodeManager::currentNM();
+    Node trueNode = nm->mkConst(true);
+    Node falseNode = nm->mkConst(false);
+    std::vector<Node> clauseNodes;
+    for (std::size_t i = 0, childrenSize = children.size(); i < childrenSize;
+         ++i)
     {
       std::unordered_set<Node, NodeHashFunction> elim;
       // literals to be removed from "first" clause
       if (i < childrenSize - 1)
       {
-        elim.insert(args.begin() + i, args.end());
+        for (std::size_t j = (2 * i) + 1, argsSize = args.size(); j < argsSize;
+             j = j + 2)
+        {
+          // whether pivot should occur as is or negated depends on the polarity
+          // of each step in the macro
+          if (args[j] == trueNode)
+          {
+            elim.insert(args[j + 1]);
+          }
+          else
+          {
+            Assert(args[j] == falseNode);
+            elim.insert(args[j + 1].notNode());
+          }
+        }
       }
       // literal to be removed from "second" clause. They will be negated
       if (i > 0)
       {
-        elim.insert(args[i - 1].negate());
+        std::size_t index = 2 * (i - 1) + 1;
+        Node pivot = args[index] == trueNode ? args[index + 1].notNode()
+                                             : args[index + 1];
+        elim.insert(pivot);
       }
       Trace("bool-pfcheck") << i << ": elimination set: " << elim << "\n";
       // only add to conclusion nodes that are not in elimination set. First get
       // the nodes.
       //
-      // Since unit clauses can also be OR nodes, we rely on the invariant that
-      // non-unit clauses will not occur themselves in their elimination sets.
-      // If they do then they must be unit.
+      // Since a Node cannot hold an OR with a single child we need to
+      // disambiguate singleton clauses that are OR nodes from non-singleton
+      // clauses (i.e. unit clauses in the SAT solver).
+      //
+      // If the child is not an OR, it is a singleton clause and we take the
+      // child itself as the clause. Otherwise the child can only be a singleton
+      // clause if the child itself is used as a resolution literal, i.e. if the
+      // child is in lhsElim or is equal to rhsElim (which means that the
+      // negation of the child is in lhsElim).
       std::vector<Node> lits;
       if (children[i].getKind() == kind::OR && elim.count(children[i]) == 0)
       {
@@ -224,22 +363,52 @@ Node BoolProofRuleChecker::checkInternal(PfRule id,
       }
       Trace("bool-pfcheck") << i << ": clause lits: " << lits << "\n";
       std::vector<Node> added;
-      for (unsigned j = 0, size = lits.size(); j < size; ++j)
+      for (std::size_t j = 0, size = lits.size(); j < size; ++j)
       {
+        // only add if literal does not occur in elimination set
         if (elim.count(lits[j]) == 0)
         {
           clauseNodes.push_back(lits[j]);
           added.push_back(lits[j]);
+          // eliminate duplicates
+          elim.insert(lits[j]);
         }
       }
       Trace("bool-pfcheck") << i << ": added lits: " << added << "\n\n";
     }
-    Trace("bool-pfcheck") << "clause: " << clauseNodes << "\n" << pop;
-    NodeManager* nm = NodeManager::currentNM();
-    return clauseNodes.empty()
-               ? nm->mkConst<bool>(false)
-               : clauseNodes.size() == 1 ? clauseNodes[0]
-                                         : nm->mkNode(kind::OR, clauseNodes);
+    Trace("bool-pfcheck") << "clause: " << clauseNodes << "\n";
+    // check that set representation is the same as of the given conclusion
+    std::unordered_set<Node, NodeHashFunction> clauseComputed{
+        clauseNodes.begin(), clauseNodes.end()};
+    Trace("bool-pfcheck") << "clauseSet: " << clauseComputed << "\n" << pop;
+    if (clauseComputed.empty())
+    {
+      // conclusion differ
+      if (args[0] != falseNode)
+      {
+        return Node::null();
+      }
+      return args[0];
+    }
+    if (clauseComputed.size() == 1)
+    {
+      // conclusion differ
+      if (args[0] != *clauseComputed.begin())
+      {
+        return Node::null();
+      }
+      return args[0];
+    }
+    // At this point, should amount to them differing only on order. So the
+    // original result can't be a singleton clause
+    if (args[0].getKind() != kind::OR
+        || clauseComputed.size() != args[0].getNumChildren())
+    {
+      return Node::null();
+    }
+    std::unordered_set<Node, NodeHashFunction> clauseGiven{args[0].begin(),
+                                                           args[0].end()};
+    return clauseComputed == clauseGiven ? args[0] : Node::null();
   }
   if (id == PfRule::SPLIT)
   {
@@ -508,7 +677,8 @@ Node BoolProofRuleChecker::checkInternal(PfRule id,
       return Node::null();
     }
     std::vector<Node> disjuncts;
-    for (unsigned i = 0, size = children[0][0].getNumChildren(); i < size; ++i)
+    for (std::size_t i = 0, size = children[0][0].getNumChildren(); i < size;
+         ++i)
     {
       disjuncts.push_back(children[0][0][i].notNode());
     }