Fix BagsRewriter::rewriteUnionDisjoint (#5840)
authormudathirmahgoub <mudathirmahgoub@gmail.com>
Mon, 1 Feb 2021 14:42:39 +0000 (08:42 -0600)
committerGitHub <noreply@github.com>
Mon, 1 Feb 2021 14:42:39 +0000 (08:42 -0600)
This PR fixes the implementation of (union_disjoint (union_max A B) (intersection_min A B)) =(union_disjoint A B).
It also skips processed bags during model building.

src/theory/bags/bags_rewriter.cpp
src/theory/bags/theory_bags.cpp
test/unit/theory/theory_bags_rewriter_white.h

index 66886bfbff1208435a83c4bb14acf30916fc6846..9f53c29ca1f8f0f4e79748fca7ca55e5c0e4d6b9 100644 (file)
@@ -246,7 +246,7 @@ BagsRewriteResponse BagsRewriter::rewriteUnionDisjoint(const TNode& n) const
     //         (union_disjoint A B) // sum(a,b) = max(a,b) + min(a,b)
     // check if the operands of union_max and intersection_min are the same
     std::set<Node> left(n[0].begin(), n[0].end());
-    std::set<Node> right(n[0].begin(), n[0].end());
+    std::set<Node> right(n[1].begin(), n[1].end());
     if (left == right)
     {
       Node rewritten = d_nm->mkNode(UNION_DISJOINT, n[0][0], n[0][1]);
index 15e8e00e7a5532010bc0d1b17ca59aaa86f1bf04..6df44295e7feb6b3e99deadfd8f7bd3dd89d4eb5 100644 (file)
@@ -144,15 +144,26 @@ bool TheoryBags::collectModelValues(TheoryModel* m,
 
   Trace("bags-model") << "Term set: " << termSet << std::endl;
 
+  std::set<Node> processedBags;
+
   // get the relevant bag equivalence classes
   for (const Node& n : termSet)
   {
     TypeNode tn = n.getType();
     if (!tn.isBag())
     {
+      // we are only concerned here about bag terms
       continue;
     }
     Node r = d_state.getRepresentative(n);
+    if (processedBags.find(r) != processedBags.end())
+    {
+      // skip bags whose representatives are already processed
+      continue;
+    }
+
+    processedBags.insert(r);
+
     std::set<Node> solverElements = d_state.getElements(r);
     std::set<Node> elements;
     // only consider terms in termSet and ignore other elements in the solver
index 98e3cf88727fd413fdeb947bdbf5d1a172cc2185..10a624238ab9dfc3b59f01d1883857cd484ddf5c 100644 (file)
@@ -280,16 +280,20 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite
   void testUnionDisjoint()
   {
     int n = 3;
-    vector<Node> elements = getNStrings(2);
+    vector<Node> elements = getNStrings(3);
     Node emptyBag =
         d_nm->mkConst(EmptyBag(d_nm->mkBagType(d_nm->stringType())));
     Node A = d_nm->mkBag(
         d_nm->stringType(), elements[0], d_nm->mkConst(Rational(n)));
     Node B = d_nm->mkBag(
         d_nm->stringType(), elements[1], d_nm->mkConst(Rational(n + 1)));
+    Node C = d_nm->mkBag(
+        d_nm->stringType(), elements[2], d_nm->mkConst(Rational(n + 2)));
+
     Node unionDisjointAB = d_nm->mkNode(UNION_DISJOINT, A, B);
     Node unionDisjointBA = d_nm->mkNode(UNION_DISJOINT, B, A);
     Node unionMaxAB = d_nm->mkNode(UNION_MAX, A, B);
+    Node unionMaxAC = d_nm->mkNode(UNION_MAX, A, C);
     Node unionMaxBA = d_nm->mkNode(UNION_MAX, B, A);
     Node intersectionAB = d_nm->mkNode(INTERSECTION_MIN, A, B);
     Node intersectionBA = d_nm->mkNode(INTERSECTION_MIN, B, A);
@@ -321,6 +325,14 @@ class BagsTypeRuleWhite : public CxxTest::TestSuite
     RewriteResponse response4 = d_rewriter->postRewrite(unionDisjoint4);
     TS_ASSERT(response4.d_node == unionDisjointBA
               && response4.d_status == REWRITE_AGAIN_FULL);
+
+    // (union_disjoint (intersection_min B A)) (union_max A B) =
+    //          (union_disjoint B A) // sum(a,b) = max(a,b) + min(a,b)
+    Node unionDisjoint5 =
+        d_nm->mkNode(UNION_DISJOINT, unionMaxAC, intersectionAB);
+    RewriteResponse response5 = d_rewriter->postRewrite(unionDisjoint5);
+    TS_ASSERT(response5.d_node == unionDisjoint5
+              && response5.d_status == REWRITE_DONE);
   }
 
   void testIntersectionMin()