Properly handle subtypes in smt2 printer.
authorajreynol <andrew.j.reynolds@gmail.com>
Wed, 21 Jun 2017 19:21:45 +0000 (14:21 -0500)
committerajreynol <andrew.j.reynolds@gmail.com>
Wed, 21 Jun 2017 19:23:56 +0000 (14:23 -0500)
src/printer/smt2/smt2_printer.cpp
src/printer/smt2/smt2_printer.h
test/regress/regress0/arith/mult.02.smt2

index 98993dba524ca8b410f28d6ab7cae1134ef883ec..4dd6e73aa6579c2156963ea01a6353f96d4d6759 100644 (file)
@@ -61,14 +61,14 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
       theory::SubstitutionMap::const_iterator i_end = lets.end();
       for(; i != i_end; ++ i) {
         out << "(let ((";
-        toStream(out, (*i).second, toDepth, types);
+        toStream(out, (*i).second, toDepth, types, TypeNode::null());
         out << ' ';
-        toStream(out, (*i).first, toDepth, types);
+        toStream(out, (*i).first, toDepth, types, TypeNode::null());
         out << ")) ";
       }
     }
     Node body = dv.getDagifiedBody();
-    toStream(out, body, toDepth, types);
+    toStream(out, body, toDepth, types, TypeNode::null());
     if(!lets.empty()) {
       theory::SubstitutionMap::const_iterator i = lets.begin();
       theory::SubstitutionMap::const_iterator i_end = lets.end();
@@ -77,7 +77,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
       }
     }
   } else {
-    toStream(out, n, toDepth, types);
+    toStream(out, n, toDepth, types, TypeNode::null());
   }
 }
 
@@ -107,14 +107,30 @@ static bool stringifyRegexp(Node n, stringstream& ss) {
   return true;
 }
 
+// force_nt is the type that n must have
 void Smt2Printer::toStream(std::ostream& out, TNode n,
-                           int toDepth, bool types) const throw() {
+                           int toDepth, bool types, TypeNode force_nt) const throw() {
   // null
   if(n.getKind() == kind::NULL_EXPR) {
     out << "null";
     return;
   }
 
+  if( !force_nt.isNull() && n.getKind()!=kind::CONST_RATIONAL ){
+    if( n.getType()!=force_nt ){
+      if( force_nt.isReal() ){
+        out << "(" << smtKindString( force_nt.isInteger() ? kind::TO_INTEGER : kind::TO_REAL) << " ";
+        toStream(out, n, toDepth, types, TypeNode::null());
+        out << ")";
+      }else{            
+        Node nn = NodeManager::currentNM()->mkNode(kind::APPLY_TYPE_ASCRIPTION,
+                                                   NodeManager::currentNM()->mkConst(AscriptionType(force_nt.toType())), n );
+        toStream(out, nn, toDepth, types, TypeNode::null());                                 
+      }
+      return;  
+    }
+  }
+
   // variable
   if(n.isVar()) {
     string s;
@@ -216,10 +232,10 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
         if(r < 0) {
           out << "-" << -r;
         }else{
-          toStreamRational(out, r, false);
+          toStreamRational(out, r, !force_nt.isNull() && !force_nt.isInteger());
         }
       }else{
-        toStreamRational(out, r, false);
+        toStreamRational(out, r, !force_nt.isNull() && !force_nt.isInteger());
       }
       // Rational r = n.getConst<Rational>();
       // if(r < 0) {
@@ -314,8 +330,8 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
     }
     if(n.getNumChildren() != 0) {
       for(unsigned i = 0; i < n.getNumChildren(); ++i) {
-       out << ' ';
-       toStream(out, n[i], toDepth, types);
+             out << ' ';
+             toStream(out, n[i], toDepth, types, TypeNode::null());
       }
       out << ')';
     }
@@ -324,6 +340,8 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
 
   bool stillNeedToPrintParams = true;
   bool forceBinary = false; // force N-ary to binary when outputing children
+  bool parametricTypeChildren = false;   // parametric operators that are (op t1 ... tn) where t1...tn must have same type
+  bool typeChildren = false;  // operators (op t1...tn) where at least one of t1...tn may require a type cast e.g. Int -> Real
   // operator
   if(n.getNumChildren() != 0 &&
      n.getKind() != kind::INST_PATTERN_LIST &&
@@ -334,7 +352,10 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
     // builtin theory
   case kind::APPLY: break;
   case kind::EQUAL:
-  case kind::DISTINCT: out << smtKindString(k) << " "; break;
+  case kind::DISTINCT: 
+    out << smtKindString(k) << " "; 
+    parametricTypeChildren = true;
+    break;
   case kind::CHAIN: break;
   case kind::FUNCTION_TYPE:
     for(size_t i = 0; i < n.getNumChildren() - 1; ++i) {
@@ -356,7 +377,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
   case kind::ITE: out << smtKindString(k) << " "; break;
 
     // uf theory
-  case kind::APPLY_UF: break;
+  case kind::APPLY_UF: typeChildren = true; break;
 
     // arith theory
   case kind::PLUS:
@@ -378,7 +399,10 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
   case kind::IS_INTEGER:
   case kind::TO_INTEGER:
   case kind::TO_REAL:
-  case kind::POW: out << smtKindString(k) << " "; break;
+  case kind::POW: 
+    parametricTypeChildren = true;
+    out << smtKindString(k) << " "; 
+    break;
 
   case kind::DIVISIBLE:
     out << "(_ divisible " << n.getOperator().getConst<Divisible>().k << ")";
@@ -387,7 +411,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
 
     // arrays theory
   case kind::SELECT:
-  case kind::STORE:
+  case kind::STORE: typeChildren = true;
   case kind::PARTIAL_SELECT_0:
   case kind::PARTIAL_SELECT_1:
   case kind::ARRAY_TYPE: out << smtKindString(k) << " "; break;
@@ -397,7 +421,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
     if(d_variant == z3str_variant) {
       out << "Concat ";
       for(unsigned i = 0; i < n.getNumChildren(); ++i) {
-        toStream(out, n[i], -1, types);
+        toStream(out, n[i], -1, types, TypeNode::null());
         if(i + 1 < n.getNumChildren()) {
           out << ' ';
         }
@@ -416,10 +440,10 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
     stringstream ss;
     if(d_variant == z3str_variant && stringifyRegexp(n[1], ss)) {
       out << "= ";
-      toStream(out, n[0], -1, types);
+      toStream(out, n[0], -1, types, TypeNode::null());
       out << " ";
       Node str = NodeManager::currentNM()->mkConst(String(ss.str()));
-      toStream(out, str, -1, types);
+      toStream(out, str, -1, types, TypeNode::null());
       out << ")";
       return;
     }
@@ -508,7 +532,10 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
   case kind::INTERSECTION:
   case kind::SETMINUS:
   case kind::SUBSET:
-  case kind::MEMBER:
+    parametricTypeChildren = true;
+    out << smtKindString(k) << " "; 
+    break;
+  case kind::MEMBER: typeChildren = true;
   case kind::SET_TYPE:
   case kind::SINGLETON:
   case kind::COMPLEMENT:out << smtKindString(k) << " "; break;
@@ -575,12 +602,12 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
         return;
       }
       out << "(as ";
-      toStream(out, n[0], toDepth < 0 ? toDepth : toDepth - 1, types);
+      toStream(out, n[0], toDepth < 0 ? toDepth : toDepth - 1, types, TypeNode::null());
       out << ' ' << (t.isFunctionLike() ? t.getRangeType() : t) << ')';
       return;
     }
     break;
-  case kind::APPLY_CONSTRUCTOR:
+  case kind::APPLY_CONSTRUCTOR: typeChildren = true;
   case kind::APPLY_TESTER:
   case kind::APPLY_SELECTOR:
   case kind::APPLY_SELECTOR_TOTAL:
@@ -660,7 +687,7 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
     if(toDepth != 0) {
       if( d_variant==sygus_variant && n.getKind()==kind::APPLY_CONSTRUCTOR ){
         std::stringstream ss;
-        toStream(ss, n.getOperator(), toDepth < 0 ? toDepth : toDepth - 1, types);
+        toStream(ss, n.getOperator(), toDepth < 0 ? toDepth : toDepth - 1, types, TypeNode::null());
         std::string tmp = ss.str();
         size_t pos = 0;
         if((pos = tmp.find("__Enum__", pos)) != std::string::npos){
@@ -672,14 +699,14 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
         const Datatype& dt = Datatype::datatypeOf(n.getOperator().toExpr());
         if( d_variant==smt2_6_variant ){
           out << "(_ is ";
-          toStream(out, Node::fromExpr(dt[cindex].getConstructor()), toDepth < 0 ? toDepth : toDepth - 1, types);
+          toStream(out, Node::fromExpr(dt[cindex].getConstructor()), toDepth < 0 ? toDepth : toDepth - 1, types, TypeNode::null());
           out << ")";
         }else{
           out << "is-";
-          toStream(out, Node::fromExpr(dt[cindex].getConstructor()), toDepth < 0 ? toDepth : toDepth - 1, types);
+          toStream(out, Node::fromExpr(dt[cindex].getConstructor()), toDepth < 0 ? toDepth : toDepth - 1, types, TypeNode::null());
         }
       }else{
-        toStream(out, n.getOperator(), toDepth < 0 ? toDepth : toDepth - 1, types);
+        toStream(out, n.getOperator(), toDepth < 0 ? toDepth : toDepth - 1, types, TypeNode::null());
       }
     } else {
       out << "(...)";
@@ -689,9 +716,63 @@ void Smt2Printer::toStream(std::ostream& out, TNode n,
     }
   }
   stringstream parens;
+  
+  // calculate the child type casts
+  std::map< unsigned, TypeNode > force_child_type;
+  if( parametricTypeChildren ){
+    if( n.getNumChildren()>1 ){
+      TypeNode force_ct = n[0].getType();
+      bool do_force = false;
+      for(size_t i = 1; i < n.getNumChildren(); ++i ) {
+        TypeNode ct = n[i].getType();
+        if( ct!=force_ct ){
+          force_ct = TypeNode::leastCommonTypeNode( force_ct, ct );
+          do_force = true;
+        }
+      }
+      if( do_force ){
+        for(size_t i = 0; i < n.getNumChildren(); ++i ) {
+          force_child_type[i] = force_ct;
+        }
+      }
+    }
+  // operators that may require type casting
+  }else if( typeChildren ){
+    if(n.getKind()==kind::SELECT){
+      TypeNode indexType = TypeNode::leastCommonTypeNode( n[0].getType().getArrayIndexType(), n[1].getType() );
+      TypeNode elemType = n[0].getType().getArrayConstituentType();
+      force_child_type[0] = NodeManager::currentNM()->mkArrayType( indexType, elemType );
+      force_child_type[1] = indexType;
+    }else if(n.getKind()==kind::STORE){
+      TypeNode indexType = TypeNode::leastCommonTypeNode( n[0].getType().getArrayIndexType(), n[1].getType() );
+      TypeNode elemType = TypeNode::leastCommonTypeNode( n[0].getType().getArrayConstituentType(), n[2].getType() );
+      force_child_type[0] = NodeManager::currentNM()->mkArrayType( indexType, elemType );
+      force_child_type[1] = indexType;
+      force_child_type[2] = elemType;
+    }else if(n.getKind()==kind::MEMBER){
+      TypeNode elemType = TypeNode::leastCommonTypeNode( n[0].getType(), n[1].getType().getSetElementType() );
+      force_child_type[0] = elemType;
+      force_child_type[1] = NodeManager::currentNM()->mkSetType( elemType );
+    }else{
+      // APPLY_UF, APPLY_CONSTRUCTOR, etc.
+      Assert( n.hasOperator() );
+      TypeNode opt = n.getOperator().getType();
+      Assert( opt.getNumChildren() == n.getNumChildren() + 1 );
+      for(size_t i = 0; i < n.getNumChildren(); ++i ) {
+        force_child_type[i] = opt[i];
+      }
+    }
+  }
+  
   for(size_t i = 0, c = 1; i < n.getNumChildren(); ) {
     if(toDepth != 0) {
-      toStream(out, n[i], toDepth < 0 ? toDepth : toDepth - c, types);
+      Node cn = n[i];
+      std::map< unsigned, TypeNode >::iterator itfc = force_child_type.find( i );
+      if( itfc!=force_child_type.end() ){
+        toStream(out, cn, toDepth < 0 ? toDepth : toDepth - c, types, itfc->second);
+      }else{
+        toStream(out, cn, toDepth < 0 ? toDepth : toDepth - c, types, TypeNode::null());
+      }
     } else {
       out << "(...)";
     }
index d73f11b59a5618a5b3d87cf4072f75338776f824..eeba0b12a91859be4e4c3a08af44c907b6095b1a 100644 (file)
@@ -38,7 +38,7 @@ enum Variant {
 class Smt2Printer : public CVC4::Printer {
   Variant d_variant;
 
-  void toStream(std::ostream& out, TNode n, int toDepth, bool types) const throw();
+  void toStream(std::ostream& out, TNode n, int toDepth, bool types, TypeNode nt) const throw();
   void toStream(std::ostream& out, const Model& m, const Command* c) const throw();
 public:
   Smt2Printer(Variant variant = no_variant) : d_variant(variant) { }
index 54b876d383c38413dbce5fe6a0921108db0837de..57167fc76d0024d97c453146a6cf81361f1d0787 100644 (file)
@@ -1,5 +1,5 @@
 ; EXPECT: (error "A non-linear fact was asserted to arithmetic in a linear logic.
-; EXPECT: The fact in question: (>= (* (- 1) (* n n)) (- 1))
+; EXPECT: The fact in question: (>= (* (- 1.0) (* n n)) (- 1.0))
 ; EXPECT: ")
 ; EXIT: 1
 (set-logic QF_LRA)