Move tuple/record update elimination from ppRewrite to expandDefinition (#2839)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Tue, 12 Mar 2019 19:43:42 +0000 (14:43 -0500)
committerGitHub <noreply@github.com>
Tue, 12 Mar 2019 19:43:42 +0000 (14:43 -0500)
src/theory/datatypes/theory_datatypes.cpp
test/regress/CMakeLists.txt
test/regress/regress0/datatypes/issue2838.cvc [new file with mode: 0644]

index 5ed623190727681e3d648ed2a95c54e63e31e32c..97b20d989eab76ff0310afa258c14bb3be3a57d8 100644 (file)
@@ -552,100 +552,130 @@ void TheoryDatatypes::finishInit() {
 }
 
 Node TheoryDatatypes::expandDefinition(LogicRequest &logicRequest, Node n) {
-  switch( n.getKind() ){
-  case kind::APPLY_SELECTOR: {
-    Trace("dt-expand") << "Dt Expand definition : " << n << std::endl;
-    Node selector = n.getOperator();
-    Expr selectorExpr = selector.toExpr();
-    // APPLY_SELECTOR always applies to an external selector, cindexOf is legal here
-    size_t cindex = Datatype::cindexOf(selectorExpr);
-    const Datatype& dt = Datatype::datatypeOf(selectorExpr);
-    const DatatypeConstructor& c = dt[cindex];
-    Node selector_use;
-    TypeNode ndt = n[0].getType();
-    if( options::dtSharedSelectors() ){
-      size_t selectorIndex = DatatypesRewriter::indexOf(selector);
-      Trace("dt-expand") << "...selector index = " << selectorIndex << std::endl;
-      Assert( selectorIndex<c.getNumArgs() );
-      selector_use = Node::fromExpr( c.getSelectorInternal( ndt.toType(), selectorIndex ) );
-    }else{
-      selector_use = selector;
-    }
-    Node sel = NodeManager::currentNM()->mkNode( kind::APPLY_SELECTOR_TOTAL, selector_use, n[0] );
-    if( options::dtRewriteErrorSel() ){
-      return sel;
-    }else{
-      Expr tester = c.getTester();
-      Node tst = NodeManager::currentNM()->mkNode( kind::APPLY_TESTER, Node::fromExpr( tester ), n[0] );
-      tst = Rewriter::rewrite( tst );
-      Node n_ret;
-      if( tst==d_true ){
-        n_ret = sel;
+  NodeManager* nm = NodeManager::currentNM();
+  switch (n.getKind())
+  {
+    case kind::APPLY_SELECTOR:
+    {
+      Trace("dt-expand") << "Dt Expand definition : " << n << std::endl;
+      Node selector = n.getOperator();
+      Expr selectorExpr = selector.toExpr();
+      // APPLY_SELECTOR always applies to an external selector, cindexOf is
+      // legal here
+      size_t cindex = Datatype::cindexOf(selectorExpr);
+      const Datatype& dt = Datatype::datatypeOf(selectorExpr);
+      const DatatypeConstructor& c = dt[cindex];
+      Node selector_use;
+      TypeNode ndt = n[0].getType();
+      if (options::dtSharedSelectors())
+      {
+        size_t selectorIndex = DatatypesRewriter::indexOf(selector);
+        Trace("dt-expand") << "...selector index = " << selectorIndex
+                           << std::endl;
+        Assert(selectorIndex < c.getNumArgs());
+        selector_use =
+            Node::fromExpr(c.getSelectorInternal(ndt.toType(), selectorIndex));
       }else{
-        mkExpDefSkolem( selector, ndt, n.getType() );
-        Node sk = NodeManager::currentNM()->mkNode( kind::APPLY_UF, d_exp_def_skolem[ndt][ selector ], n[0]  );
-        if( tst==NodeManager::currentNM()->mkConst( false ) ){
-          n_ret = sk;
+        selector_use = selector;
+      }
+      Node sel = nm->mkNode(kind::APPLY_SELECTOR_TOTAL, selector_use, n[0]);
+      if (options::dtRewriteErrorSel())
+      {
+        return sel;
+      }
+      else
+      {
+        Expr tester = c.getTester();
+        Node tst = nm->mkNode(kind::APPLY_TESTER, Node::fromExpr(tester), n[0]);
+        tst = Rewriter::rewrite(tst);
+        Node n_ret;
+        if (tst == d_true)
+        {
+          n_ret = sel;
         }else{
-          n_ret = NodeManager::currentNM()->mkNode( kind::ITE, tst, sel, sk );
+          mkExpDefSkolem(selector, ndt, n.getType());
+          Node sk =
+              nm->mkNode(kind::APPLY_UF, d_exp_def_skolem[ndt][selector], n[0]);
+          if (tst == nm->mkConst(false))
+          {
+            n_ret = sk;
+          }
+          else
+          {
+            n_ret = nm->mkNode(kind::ITE, tst, sel, sk);
+          }
         }
+        // n_ret = Rewriter::rewrite( n_ret );
+        Trace("dt-expand") << "Expand def : " << n << " to " << n_ret
+                           << std::endl;
+        return n_ret;
       }
-      //n_ret = Rewriter::rewrite( n_ret );
-      Trace("dt-expand") << "Expand def : " << n << " to " << n_ret << std::endl;
-      return n_ret;
     }
-  }
     break;
-  default:
-    return n;
+    case TUPLE_UPDATE:
+    case RECORD_UPDATE:
+    {
+      TypeNode t = n.getType();
+      Assert(t.isDatatype());
+      const Datatype& dt = DatatypeType(t.toType()).getDatatype();
+      NodeBuilder<> b(APPLY_CONSTRUCTOR);
+      b << Node::fromExpr(dt[0].getConstructor());
+      size_t size, updateIndex;
+      if (n.getKind() == TUPLE_UPDATE)
+      {
+        Assert(t.isTuple());
+        size = t.getTupleLength();
+        updateIndex = n.getOperator().getConst<TupleUpdate>().getIndex();
+      }
+      else
+      {
+        Assert(t.isRecord());
+        const Record& record = t.getRecord();
+        size = record.getNumFields();
+        updateIndex = record.getIndex(
+            n.getOperator().getConst<RecordUpdate>().getField());
+      }
+      Debug("tuprec") << "expr is " << n << std::endl;
+      Debug("tuprec") << "updateIndex is " << updateIndex << std::endl;
+      Debug("tuprec") << "t is " << t << std::endl;
+      Debug("tuprec") << "t has arity " << size << std::endl;
+      for (size_t i = 0; i < size; ++i)
+      {
+        if (i == updateIndex)
+        {
+          b << n[1];
+          Debug("tuprec") << "arg " << i << " gets updated to " << n[1]
+                          << std::endl;
+        }
+        else
+        {
+          b << nm->mkNode(
+              APPLY_SELECTOR_TOTAL,
+              Node::fromExpr(dt[0].getSelectorInternal(t.toType(), i)),
+              n[0]);
+          Debug("tuprec") << "arg " << i << " copies "
+                          << b[b.getNumChildren() - 1] << std::endl;
+        }
+      }
+      Node n_ret = b;
+      Debug("tuprec") << "return " << n_ret << std::endl;
+      return n_ret;
+    }
     break;
+    default: return n; break;
   }
   Unreachable();
 }
 
-void TheoryDatatypes::presolve() {
+void TheoryDatatypes::presolve()
+{
   Debug("datatypes") << "TheoryDatatypes::presolve()" << endl;
 }
 
-Node TheoryDatatypes::ppRewrite(TNode in) {
+Node TheoryDatatypes::ppRewrite(TNode in)
+{
   Debug("tuprec") << "TheoryDatatypes::ppRewrite(" << in << ")" << endl;
 
-  TypeNode t = in.getType();
-
-  if(in.getKind() == kind::TUPLE_UPDATE || in.getKind() == kind::RECORD_UPDATE) {
-    Assert( t.isDatatype() );
-    const Datatype& dt = DatatypeType(t.toType()).getDatatype();
-    NodeBuilder<> b(kind::APPLY_CONSTRUCTOR);
-    b << Node::fromExpr(dt[0].getConstructor());
-    size_t size, updateIndex;
-    if(in.getKind() == kind::TUPLE_UPDATE) {
-      Assert( t.isTuple() );
-      size = t.getTupleLength();
-      updateIndex = in.getOperator().getConst<TupleUpdate>().getIndex();
-    } else { // kind::RECORD_UPDATE
-      Assert( t.isRecord() );
-      const Record& record = t.getRecord();
-      size = record.getNumFields();
-      updateIndex = record.getIndex(in.getOperator().getConst<RecordUpdate>().getField());
-    }
-    Debug("tuprec") << "expr is " << in << std::endl;
-    Debug("tuprec") << "updateIndex is " << updateIndex << std::endl;
-    Debug("tuprec") << "t is " << t << std::endl;
-    Debug("tuprec") << "t has arity " << size << std::endl;
-    for(size_t i = 0; i < size; ++i) {
-      if(i == updateIndex) {
-        b << in[1];
-        Debug("tuprec") << "arg " << i << " gets updated to " << in[1] << std::endl;
-      } else {
-        b << NodeManager::currentNM()->mkNode(kind::APPLY_SELECTOR_TOTAL, Node::fromExpr(dt[0].getSelectorInternal( t.toType(), i )), in[0]);
-        Debug("tuprec") << "arg " << i << " copies " << b[b.getNumChildren() - 1] << std::endl;
-      }
-    }
-    Debug("tuprec") << "builder says " << b << std::endl;
-    Node n = b;
-    return n;
-  }
-
   if( in.getKind()==EQUAL ){
     Node nn;
     std::vector< Node > rew;
index 6f147db3c7ca16aeff33c0d69b3381f92234157c..abec884c2a7d5b9a8c974133aca66a64c525f57b 100644 (file)
@@ -363,6 +363,7 @@ set(regress_0_tests
   regress0/datatypes/example-dailler-min.smt2
   regress0/datatypes/is_test.smt2
   regress0/datatypes/issue1433.smt2
+  regress0/datatypes/issue2838.cvc
   regress0/datatypes/jsat-2.6.smt2
   regress0/datatypes/model-subterms-min.smt2
   regress0/datatypes/mutually-recursive.cvc
diff --git a/test/regress/regress0/datatypes/issue2838.cvc b/test/regress/regress0/datatypes/issue2838.cvc
new file mode 100644 (file)
index 0000000..95e1c89
--- /dev/null
@@ -0,0 +1,14 @@
+% EXPECT: sat
+Ints_0 : ARRAY INT OF INT;
+C : TYPE = [# i : INT #];
+CType : TYPE = ARRAY INT OF C;
+C_0 : CType;
+x : INT;
+C_1 : CType = C_0 WITH [0].i := 2;
+
+ASSERT C_0[0].i = 0;
+ASSERT C_0[1].i = 1;
+ASSERT Ints_0[2] = Ints_0[0];
+ASSERT x = Ints_0[C_1[0].i];
+ASSERT x /= Ints_0[C_1[1].i];
+CHECKSAT;