Fix Boolean terms w.r.t. parametric datatypes (e.g., (Pair Bool Bool) now works).
authorMorgan Deters <mdeters@cs.nyu.edu>
Thu, 5 Dec 2013 19:07:47 +0000 (14:07 -0500)
committerMorgan Deters <mdeters@cs.nyu.edu>
Thu, 5 Dec 2013 20:41:21 +0000 (15:41 -0500)
src/smt/boolean_terms.cpp
src/smt/model_postprocessor.cpp
test/regress/regress0/datatypes/Makefile.am
test/regress/regress0/datatypes/pair-bool-bool.cvc [new file with mode: 0644]

index 108c888291086a897cc8332ec60dfa47b08afb49..30aa79acae3040ffcd324bd6b00c29262af84bfd 100644 (file)
@@ -152,6 +152,40 @@ Node BooleanTermConverter::rewriteAs(TNode in, TypeNode as) throw() {
     }
     return out;
   }
+  if(in.getType().isParametricDatatype() &&
+     in.getType().isInstantiatedDatatype()) {
+    // We have something here like (Pair Bool Bool)---need to dig inside
+    // and make it (Pair BV1 BV1)
+    Assert(as.isParametricDatatype() && as.isInstantiatedDatatype());
+    const Datatype* dt2 = &as[0].getDatatype();
+    std::vector<TypeNode> fromParams, toParams;
+    for(unsigned i = 0; i < dt2->getNumParameters(); ++i) {
+      fromParams.push_back(TypeNode::fromType(dt2->getParameter(i)));
+      toParams.push_back(as[i + 1]);
+    }
+    const Datatype* dt1 = d_datatypeCache[dt2];
+    Assert(dt1 != NULL, "expected datatype in cache");
+    Assert(*dt1 == in.getType()[0].getDatatype(), "improper rewriteAs() between datatypes");
+    Node out;
+    for(size_t i = 0; i < dt1->getNumConstructors(); ++i) {
+      DatatypeConstructor ctor = (*dt1)[i];
+      NodeBuilder<> appctorb(kind::APPLY_CONSTRUCTOR);
+      appctorb << (*dt2)[i].getConstructor();
+      for(size_t j = 0; j < ctor.getNumArgs(); ++j) {
+        TypeNode asType = TypeNode::fromType(SelectorType((*dt2)[i][j].getSelector().getType()).getRangeType());
+        asType = asType.substitute(fromParams.begin(), fromParams.end(), toParams.begin(), toParams.end());
+        appctorb << rewriteAs(NodeManager::currentNM()->mkNode(kind::APPLY_SELECTOR, ctor[j].getSelector(), in), asType);
+      }
+      Node appctor = appctorb;
+      if(i == 0) {
+        out = appctor;
+      } else {
+        Node newOut = NodeManager::currentNM()->mkNode(kind::ITE, ctor.getTester(), appctor, out);
+        out = newOut;
+      }
+    }
+    return out;
+  }
 
   Unhandled(in);
 }
index 686ecbbe67e79e1d8b725c993c90498264796adb..a66e02778b6b69b901f5bd7338e81986a8e1842e 100644 (file)
@@ -88,6 +88,31 @@ Node ModelPostprocessor::rewriteAs(TNode n, TypeNode asType) {
     Node val = rewriteAs(asa.getExpr(), asType[1]);
     return NodeManager::currentNM()->mkConst(ArrayStoreAll(asType.toType(), val.toExpr()));
   }
+  if(n.getType().isParametricDatatype() &&
+     n.getType().isInstantiatedDatatype() &&
+     asType.isParametricDatatype() &&
+     asType.isInstantiatedDatatype() &&
+     n.getType()[0] == asType[0]) {
+    // Here, we're doing something like rewriting a (Pair BV1 BV1) as a
+    // (Pair Bool Bool).
+    const Datatype* dt2 = &asType[0].getDatatype();
+    std::vector<TypeNode> fromParams, toParams;
+    for(unsigned i = 0; i < dt2->getNumParameters(); ++i) {
+      fromParams.push_back(TypeNode::fromType(dt2->getParameter(i)));
+      toParams.push_back(asType[i + 1]);
+    }
+    Assert(dt2 == &Datatype::datatypeOf(n.getOperator().toExpr()));
+    size_t ctor_ix = Datatype::indexOf(n.getOperator().toExpr());
+    NodeBuilder<> appctorb(kind::APPLY_CONSTRUCTOR);
+    appctorb << (*dt2)[ctor_ix].getConstructor();
+    for(size_t j = 0; j < n.getNumChildren(); ++j) {
+      TypeNode asType = TypeNode::fromType(SelectorType((*dt2)[ctor_ix][j].getSelector().getType()).getRangeType());
+      asType = asType.substitute(fromParams.begin(), fromParams.end(), toParams.begin(), toParams.end());
+      appctorb << rewriteAs(n[j], asType);
+    }
+    Node out = appctorb;
+    return out;
+  }
   if(asType.getNumChildren() != n.getNumChildren() ||
      n.getMetaKind() == kind::metakind::CONSTANT) {
     return n;
index 31999b203aa1c895c94bbb64fd7289872f421e77..67b97add3b26bda7b1bcb6fd9522905c110d2199 100644 (file)
@@ -36,6 +36,8 @@ TESTS =       \
        datatype13.cvc \
        empty_tuprec.cvc \
        mutually-recursive.cvc \
+       pair-real-bool.smt2 \
+       pair-bool-bool.cvc \
        rewriter.cvc \
        typed_v10l30054.cvc \
        typed_v1l80005.cvc \
@@ -60,7 +62,6 @@ TESTS =       \
        wrong-sel-simp.cvc
 
 FAILING_TESTS = \
-       pair-real-bool.smt2 \
        datatype-dump.cvc
 
 EXTRA_DIST = $(TESTS)
diff --git a/test/regress/regress0/datatypes/pair-bool-bool.cvc b/test/regress/regress0/datatypes/pair-bool-bool.cvc
new file mode 100644 (file)
index 0000000..7525e2d
--- /dev/null
@@ -0,0 +1,10 @@
+% EXPECT: sat
+
+DATATYPE
+  pair[T1,T2] = mkpair(first:T1, second:T2)
+END;
+
+x : pair[BOOLEAN,BOOLEAN];
+
+ASSERT x = mkpair(TRUE,TRUE);
+CHECKSAT;