Add posRewriteEqual to bags rewriter (#5498)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Sat, 21 Nov 2020 01:54:40 +0000 (19:54 -0600)
committerGitHub <noreply@github.com>
Sat, 21 Nov 2020 01:54:40 +0000 (19:54 -0600)
This PR fixes #5460 by adding posRewriteEqual to bags rewriter

src/theory/bags/bags_rewriter.cpp
src/theory/bags/bags_rewriter.h
src/theory/bags/rewrites.cpp
src/theory/bags/rewrites.h
test/unit/theory/theory_bags_rewriter_white.h

index f0540e9b78aee91011051ca19379e9c002c8679a..9479d2cc287caf25c9433edf4cc7426234e3ef80 100644 (file)
@@ -51,6 +51,10 @@ RewriteResponse BagsRewriter::postRewrite(TNode n)
     // no need to rewrite n if it is already in a normal form
     response = BagsRewriteResponse(n, Rewrite::NONE);
   }
+  else if(n.getKind() == EQUAL)
+  {
+    response = postRewriteEqual(n);
+  }
   else if (NormalForm::areChildrenConstants(n))
   {
     Node value = NormalForm::evaluate(n);
@@ -98,7 +102,7 @@ RewriteResponse BagsRewriter::preRewrite(TNode n)
   Kind k = n.getKind();
   switch (k)
   {
-    case EQUAL: response = rewriteEqual(n); break;
+    case EQUAL: response = preRewriteEqual(n); break;
     case SUBBAG: response = rewriteSubBag(n); break;
     default: response = BagsRewriteResponse(n, Rewrite::NONE);
   }
@@ -117,7 +121,7 @@ RewriteResponse BagsRewriter::preRewrite(TNode n)
   return RewriteResponse(RewriteStatus::REWRITE_DONE, n);
 }
 
-BagsRewriteResponse BagsRewriter::rewriteEqual(const TNode& n) const
+BagsRewriteResponse BagsRewriter::preRewriteEqual(const TNode& n) const
 {
   Assert(n.getKind() == EQUAL);
   if (n[0] == n[1])
@@ -475,6 +479,30 @@ BagsRewriteResponse BagsRewriter::rewriteToSet(const TNode& n) const
   return BagsRewriteResponse(n, Rewrite::NONE);
 }
 
+BagsRewriteResponse BagsRewriter::postRewriteEqual(const TNode& n) const
+{
+  Assert(n.getKind() == kind::EQUAL);
+  if (n[0] == n[1])
+  {
+    Node ret = NodeManager::currentNM()->mkConst(true);
+    return BagsRewriteResponse(ret, Rewrite::EQ_REFL);
+  }
+
+  if (n[0].isConst() && n[1].isConst())
+  {
+    Node ret = NodeManager::currentNM()->mkConst(false);
+    return BagsRewriteResponse(ret, Rewrite::EQ_CONST_FALSE);
+  }
+
+  // standard ordering
+  if (n[0] > n[1])
+  {
+    Node ret = NodeManager::currentNM()->mkNode(kind::EQUAL, n[1], n[0]);
+    return BagsRewriteResponse(ret, Rewrite::EQ_SYM);
+  }
+  return BagsRewriteResponse(n, Rewrite::NONE);
+}
+
 }  // namespace bags
 }  // namespace theory
 }  // namespace CVC4
index 8be6b948acde7dfea1fef4b41242e083ab68600f..a9b3b90bb5874afe54b8f45467a556290690f813 100644 (file)
@@ -60,7 +60,7 @@ class BagsRewriter : public TheoryRewriter
    * rewrites for n include:
    * - (= A A) = true where A is a bag
    */
-  BagsRewriteResponse rewriteEqual(const TNode& n) const;
+  BagsRewriteResponse preRewriteEqual(const TNode& n) const;
 
   /**
    * rewrites for n include:
@@ -202,6 +202,14 @@ class BagsRewriter : public TheoryRewriter
    */
   BagsRewriteResponse rewriteToSet(const TNode& n) const;
 
+  /**
+   *  rewrites for n include:
+   *  - (= A A) = true
+   *  - (= A B) = false if A and B are different bag constants
+   *  - (= B A) = (= A B) if A < B and at least one of A or B is not a constant
+   */
+  BagsRewriteResponse postRewriteEqual(const TNode& n) const;
+
  private:
   /** Reference to the rewriter statistics. */
   NodeManager* d_nm;
index d640bcdce49110eeea104041daf7f4b99391e630..85d0820afd196fb1f51dbbb302bdde3ee5792ad1 100644 (file)
@@ -32,6 +32,9 @@ const char* toString(Rewrite r)
     case Rewrite::COUNT_EMPTY: return "COUNT_EMPTY";
     case Rewrite::COUNT_MK_BAG: return "COUNT_MK_BAG";
     case Rewrite::DUPLICATE_REMOVAL_MK_BAG: return "DUPLICATE_REMOVAL_MK_BAG";
+    case Rewrite::EQ_CONST_FALSE: return "EQ_CONST_FALSE";
+    case Rewrite::EQ_REFL: return "EQ_REFL";
+    case Rewrite::EQ_SYM: return "EQ_SYM";
     case Rewrite::FROM_SINGLETON: return "FROM_SINGLETON";
     case Rewrite::IDENTICAL_NODES: return "IDENTICAL_NODES";
     case Rewrite::INTERSECTION_EMPTY_LEFT: return "INTERSECTION_EMPTY_LEFT";
index 36e30ca688467ccc2a43b078715e34f1f23e733d..5574aa080aea05a8b13b06599e449c37d1f9d9db 100644 (file)
@@ -37,6 +37,9 @@ enum class Rewrite : uint32_t
   COUNT_EMPTY,
   COUNT_MK_BAG,
   DUPLICATE_REMOVAL_MK_BAG,
+  EQ_CONST_FALSE,
+  EQ_REFL,
+  EQ_SYM,
   FROM_SINGLETON,
   IDENTICAL_NODES,
   INTERSECTION_EMPTY_LEFT,
index f2cc09240f0ea89ec048b7b9f575ada5f05377c7..e47e3278455da4e6847a547ab2cdc641c2cb7216 100644 (file)
@@ -74,16 +74,36 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite
     Node y = elements[1];
     Node c = d_nm->mkSkolem("c", d_nm->integerType());
     Node d = d_nm->mkSkolem("d", d_nm->integerType());
-    Node bagX = d_nm->mkBag(d_nm->stringType(), x, c);
-    Node bagY = d_nm->mkBag(d_nm->stringType(), y, d);
+    Node A = d_nm->mkBag(d_nm->stringType(), x, c);
+    Node B = d_nm->mkBag(d_nm->stringType(), y, d);
     Node emptyBag =
         d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
+    Node emptyString = d_nm->mkConst(String(""));
+    Node constantBag = d_nm->mkBag(
+        d_nm->stringType(), emptyString, d_nm->mkConst(Rational(1)));
 
     // (= A A) = true where A is a bag
-    Node n1 = emptyBag.eqNode(emptyBag);
+    Node n1 = A.eqNode(A);
     RewriteResponse response1 = d_rewriter->preRewrite(n1);
     TS_ASSERT(response1.d_node == d_nm->mkConst(true)
               && response1.d_status == REWRITE_AGAIN_FULL);
+
+    // (= A B) = false if A and B are different bag constants
+    Node n2 = constantBag.eqNode(emptyBag);
+    RewriteResponse response2 = d_rewriter->postRewrite(n2);
+    TS_ASSERT(response2.d_node == d_nm->mkConst(false)
+              && response2.d_status == REWRITE_AGAIN_FULL);
+
+    // (= B A) = (= A B) if A < B and at least one of A or B is not a constant
+    Node n3 = B.eqNode(A);
+    RewriteResponse response3 = d_rewriter->postRewrite(n3);
+    TS_ASSERT(response3.d_node == A.eqNode(B)
+              && response3.d_status == REWRITE_AGAIN_FULL);
+
+    // (= A B) = (= A B) no rewrite
+    Node n4 = A.eqNode(B);
+    RewriteResponse response4 = d_rewriter->postRewrite(n4);
+    TS_ASSERT(response4.d_node == n4 && response4.d_status == REWRITE_DONE);
   }
 
   void testMkBagConstantElement()