Cleaning up the printing of theory model representative sets. (#1538)
authorTim King <taking@cs.nyu.edu>
Mon, 5 Feb 2018 23:17:45 +0000 (15:17 -0800)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Mon, 5 Feb 2018 23:17:45 +0000 (17:17 -0600)
src/printer/cvc/cvc_printer.cpp
src/printer/smt2/smt2_printer.cpp
src/theory/rep_set.cpp
src/theory/rep_set.h

index f20cb7cce66c2300b1322b354447b6c49be4a805..27105c3b4693df70cad9a809c86d577bcc58b01d 100644 (file)
@@ -1010,121 +1010,121 @@ void CvcPrinter::toStream(std::ostream& out, const CommandStatus* s) const
 
 }/* CvcPrinter::toStream(CommandStatus*) */
 
-void CvcPrinter::toStream(std::ostream& out,
-                          const Model& m,
-                          const Command* c) const
+namespace {
+
+void DeclareTypeCommandToStream(std::ostream& out,
+                                const theory::TheoryModel& model,
+                                const DeclareTypeCommand& command)
 {
-  const theory::TheoryModel& tm = (const theory::TheoryModel&) m;
-  if(dynamic_cast<const DeclareTypeCommand*>(c) != NULL) {
-    TypeNode tn = TypeNode::fromType( ((const DeclareTypeCommand*)c)->getType() );
-    if (options::modelUninterpDtEnum() && tn.isSort())
+  TypeNode type_node = TypeNode::fromType(command.getType());
+  const std::vector<Node>* type_reps =
+      model.getRepSet()->getTypeRepsOrNull(type_node);
+  if (options::modelUninterpDtEnum() && type_node.isSort()
+      && type_reps != nullptr)
+  {
+    out << "DATATYPE" << std::endl;
+    out << "  " << command.getSymbol() << " = ";
+    for (size_t i = 0; i < type_reps->size(); i++)
     {
-      const theory::RepSet* rs = tm.getRepSet();
-      if (rs->d_type_reps.find(tn) != rs->d_type_reps.end())
+      if (i > 0)
       {
-        out << "DATATYPE" << std::endl;
-        out << "  " << dynamic_cast<const DeclareTypeCommand*>(c)->getSymbol()
-            << " = ";
-        for (size_t i = 0; i < (*rs->d_type_reps.find(tn)).second.size(); i++)
-        {
-          if (i > 0)
-          {
-            out << "| ";
-          }
-          out << (*rs->d_type_reps.find(tn)).second[i] << " ";
-        }
-        out << std::endl << "END;" << std::endl;
+        out << "| ";
       }
-      else
+      out << (*type_reps)[i] << " ";
+    }
+    out << std::endl << "END;" << std::endl;
+  }
+  else if (type_node.isSort() && type_reps != nullptr)
+  {
+    out << "% cardinality of " << type_node << " is " << type_reps->size()
+        << std::endl;
+    out << command << std::endl;
+    for (Node type_rep : *type_reps)
+    {
+      if (type_rep.isVar())
       {
-        if (tn.isSort())
-        {
-          // print the cardinality
-          if (rs->d_type_reps.find(tn) != rs->d_type_reps.end())
-          {
-            out << "% cardinality of " << tn << " is "
-                << (*rs->d_type_reps.find(tn)).second.size() << std::endl;
-          }
-        }
-        out << c << std::endl;
-        if (tn.isSort())
-        {
-          // print the representatives
-          if (rs->d_type_reps.find(tn) != rs->d_type_reps.end())
-          {
-            for (size_t i = 0; i < (*rs->d_type_reps.find(tn)).second.size();
-                 i++)
-            {
-              if ((*rs->d_type_reps.find(tn)).second[i].isVar())
-              {
-                out << (*rs->d_type_reps.find(tn)).second[i] << " : " << tn
-                    << ";" << std::endl;
-              }
-              else
-              {
-                out << "% rep: " << (*rs->d_type_reps.find(tn)).second[i]
-                    << std::endl;
-              }
-            }
-          }
-        }
+        out << type_rep << " : " << type_node << ";" << std::endl;
       }
-    }
-  } else if(dynamic_cast<const DeclareFunctionCommand*>(c) != NULL) {
-    Node n = Node::fromExpr( ((const DeclareFunctionCommand*)c)->getFunction() );
-    if(n.getKind() == kind::SKOLEM) {
-      // don't print out internal stuff
-      return;
-    }
-    TypeNode tn = n.getType();
-    out << n << " : ";
-    if( tn.isFunction() || tn.isPredicate() ){
-      out << "(";
-      for( size_t i=0; i<tn.getNumChildren()-1; i++ ){
-        if( i>0 ) out << ", ";
-        out << tn[i];
+      else
+      {
+        out << "% rep: " << type_rep << std::endl;
       }
-      out << ") -> " << tn.getRangeType();
-    }else{
-      out << tn;
     }
-    Node val = Node::fromExpr(tm.getSmtEngine()->getValue(n.toExpr()));
-    if( options::modelUninterpDtEnum() && val.getKind() == kind::STORE ) {
-      const theory::RepSet* rs = tm.getRepSet();
-      TypeNode tn = val[1].getType();
-      if (tn.isSort() && rs->d_type_reps.find(tn) != rs->d_type_reps.end())
+  }
+  else
+  {
+    out << command << std::endl;
+  }
+}
+
+void DeclareFunctionCommandToStream(std::ostream& out,
+                                    const theory::TheoryModel& model,
+                                    const DeclareFunctionCommand& command)
+{
+  Node n = Node::fromExpr(command.getFunction());
+  if (n.getKind() == kind::SKOLEM)
+  {
+    // don't print out internal stuff
+    return;
+  }
+  TypeNode tn = n.getType();
+  out << n << " : ";
+  if (tn.isFunction() || tn.isPredicate())
+  {
+    out << "(";
+    for (size_t i = 0; i < tn.getNumChildren() - 1; i++)
+    {
+      if (i > 0)
       {
-        Cardinality indexCard((*rs->d_type_reps.find(tn)).second.size());
-        val = theory::arrays::TheoryArraysRewriter::normalizeConstant( val, indexCard );
+        out << ", ";
       }
+      out << tn[i];
     }
-    out << " = " << val << ";" << std::endl;
-
-/*
-    //for table format (work in progress)
-    bool printedModel = false;
-    if( tn.isFunction() ){
-      if( options::modelFormatMode()==MODEL_FORMAT_MODE_TABLE ){
-        //specialized table format for functions
-        RepSetIterator riter( &d_rep_set );
-        riter.setFunctionDomain( n );
-        while( !riter.isFinished() ){
-          std::vector< Node > children;
-          children.push_back( n );
-          for( int i=0; i<riter.getNumTerms(); i++ ){
-            children.push_back( riter.getTerm( i ) );
-          }
-          Node nn = NodeManager::currentNM()->mkNode( APPLY_UF, children );
-          Node val = getValue( nn );
-          out << val << " ";
-          riter.increment();
-        }
-        printedModel = true;
+    out << ") -> " << tn.getRangeType();
+  }
+  else
+  {
+    out << tn;
+  }
+  Node val = Node::fromExpr(model.getSmtEngine()->getValue(n.toExpr()));
+  if (options::modelUninterpDtEnum() && val.getKind() == kind::STORE)
+  {
+    TypeNode type_node = val[1].getType();
+    if (tn.isSort())
+    {
+      if (const std::vector<Node>* type_reps =
+              model.getRepSet()->getTypeRepsOrNull(type_node))
+      {
+        Cardinality indexCard(type_reps->size());
+        val = theory::arrays::TheoryArraysRewriter::normalizeConstant(
+            val, indexCard);
       }
     }
-*/
-  }else{
-    out << c << std::endl;
+  }
+  out << " = " << val << ";" << std::endl;
+}
+
+}  // namespace
+
+void CvcPrinter::toStream(std::ostream& out,
+                          const Model& model,
+                          const Command* command) const
+{
+  const auto* theory_model = dynamic_cast<const theory::TheoryModel*>(&model);
+  AlwaysAssert(theory_model != nullptr);
+  if (const auto* declare_type_command =
+          dynamic_cast<const DeclareTypeCommand*>(command))
+  {
+    DeclareTypeCommandToStream(out, *theory_model, *declare_type_command);
+  }
+  else if (const auto* dfc =
+               dynamic_cast<const DeclareFunctionCommand*>(command))
+  {
+    DeclareFunctionCommandToStream(out, *theory_model, *dfc);
+  }
+  else
+  {
+    out << command << std::endl;
   }
 }
 
index 54fc107199e5e4c2721a86a63e01f1e4de0953ad..24fd0924f97bf1ae930fbfb573abaa374bbfd8dd 100644 (file)
@@ -1277,116 +1277,140 @@ void Smt2Printer::toStream(std::ostream& out, const Model& m) const
   }
 }
 
-void Smt2Printer::toStream(std::ostream& out,
-                           const Model& m,
-                           const Command* c) const
+namespace {
+
+void DeclareTypeCommandToStream(std::ostream& out,
+                                const theory::TheoryModel& model,
+                                const DeclareTypeCommand& command,
+                                Variant variant)
 {
-  const theory::TheoryModel& tm = (const theory::TheoryModel&) m;
-  if(dynamic_cast<const DeclareTypeCommand*>(c) != NULL) {
-    TypeNode tn = TypeNode::fromType( ((const DeclareTypeCommand*)c)->getType() );
-    const theory::RepSet* rs = tm.getRepSet();
-    const std::map<TypeNode, std::vector<Node> >& type_reps = rs->d_type_reps;
-
-    std::map< TypeNode, std::vector< Node > >::const_iterator tn_iterator = type_reps.find( tn );
-    if( options::modelUninterpDtEnum() && tn.isSort() && tn_iterator != type_reps.end() ){
-      if(d_variant == smt2_6_variant) {
-        out << "(declare-datatypes ((" << dynamic_cast<const DeclareTypeCommand*>(c)->getSymbol() << " 0)) (";
-      }else{
-        out << "(declare-datatypes () ((" << dynamic_cast<const DeclareTypeCommand*>(c)->getSymbol() << " ";
-      }
-      for( size_t i=0, N = tn_iterator->second.size(); i < N; i++ ){
-        out << "(" << (*tn_iterator).second[i] << ")";
-      }
-      out << ")))" << endl;
-    } else {
-      if( tn.isSort() ){
-        //print the cardinality
-        if( tn_iterator != type_reps.end() ) {
-          out << "; cardinality of " << tn << " is " << tn_iterator->second.size() << endl;
-        }
+  TypeNode tn = TypeNode::fromType(command.getType());
+  const std::vector<Node>* type_refs = model.getRepSet()->getTypeRepsOrNull(tn);
+  if (options::modelUninterpDtEnum() && tn.isSort() && type_refs != nullptr)
+  {
+    if (variant == smt2_6_variant)
+    {
+      out << "(declare-datatypes ((" << command.getSymbol() << " 0)) (";
+    }
+    else
+    {
+      out << "(declare-datatypes () ((" << command.getSymbol() << " ";
+    }
+    for (Node type_ref : *type_refs)
+    {
+      out << "(" << type_ref << ")";
+    }
+    out << ")))" << endl;
+  }
+  else if (tn.isSort() && type_refs != nullptr)
+  {
+    // print the cardinality
+    out << "; cardinality of " << tn << " is " << type_refs->size() << endl;
+    out << command << endl;
+    // print the representatives
+    for (Node type_ref : *type_refs)
+    {
+      if (type_ref.isVar())
+      {
+        out << "(declare-fun " << quoteSymbol(type_ref) << " () " << tn << ")"
+            << endl;
       }
-      out << c << endl;
-      if( tn.isSort() ){
-        //print the representatives
-        if( tn_iterator != type_reps.end() ){
-          for( size_t i = 0, N = (*tn_iterator).second.size(); i < N; i++ ){
-            TNode current = (*tn_iterator).second[i];
-            if( current.isVar() ){
-              out << "(declare-fun " << quoteSymbol(current) << " () " << tn << ")" << endl;
-            }else{
-              out << "; rep: " << current << endl;
-            }
-          }
-        }
+      else
+      {
+        out << "; rep: " << type_ref << endl;
       }
     }
-  } else if(dynamic_cast<const DeclareFunctionCommand*>(c) != NULL) {
-    const DeclareFunctionCommand* dfc = (const DeclareFunctionCommand*)c;
-    Node n = Node::fromExpr( dfc->getFunction() );
-    if(dfc->getPrintInModelSetByUser()) {
-      if(!dfc->getPrintInModel()) {
-        return;
-      }
-    } else if(n.getKind() == kind::SKOLEM) {
-      // don't print out internal stuff
+  }
+  else
+  {
+    out << command << endl;
+  }
+}
+
+void DeclareFunctionCommandToStream(std::ostream& out,
+                                    const theory::TheoryModel& model,
+                                    const DeclareFunctionCommand& command)
+{
+  Node n = Node::fromExpr(command.getFunction());
+  if (command.getPrintInModelSetByUser())
+  {
+    if (!command.getPrintInModel())
+    {
       return;
     }
-    Node val = Node::fromExpr(tm.getSmtEngine()->getValue(n.toExpr()));
-    if(val.getKind() == kind::LAMBDA) {
-      out << "(define-fun " << n << " " << val[0]
-          << " " << n.getType().getRangeType()
-          << " " << val[1] << ")" << endl;
-    } else {
-      if( options::modelUninterpDtEnum() && val.getKind() == kind::STORE ) {
-        TypeNode tn = val[1].getType();
-        const theory::RepSet* rs = tm.getRepSet();
-        if (tn.isSort() && rs->d_type_reps.find(tn) != rs->d_type_reps.end())
-        {
-          Cardinality indexCard((*rs->d_type_reps.find(tn)).second.size());
-          val = theory::arrays::TheoryArraysRewriter::normalizeConstant( val, indexCard );
-        }
-      }
-      out << "(define-fun " << n << " () "
-          << n.getType() << " ";
-      if(val.getType().isInteger() && n.getType().isReal() && !n.getType().isInteger()) {
-        //toStreamReal(out, val, true);
-        toStreamRational(out, val.getConst<Rational>(), true);
-        //out << val << ".0";
-      } else {
-        out << val;
+  }
+  else if (n.getKind() == kind::SKOLEM)
+  {
+    // don't print out internal stuff
+    return;
+  }
+  Node val = Node::fromExpr(model.getSmtEngine()->getValue(n.toExpr()));
+  if (val.getKind() == kind::LAMBDA)
+  {
+    out << "(define-fun " << n << " " << val[0] << " "
+        << n.getType().getRangeType() << " " << val[1] << ")" << endl;
+  }
+  else
+  {
+    if (options::modelUninterpDtEnum() && val.getKind() == kind::STORE)
+    {
+      TypeNode tn = val[1].getType();
+      const std::vector<Node>* type_refs =
+          model.getRepSet()->getTypeRepsOrNull(tn);
+      if (tn.isSort() && type_refs != nullptr)
+      {
+        Cardinality indexCard(type_refs->size());
+        val = theory::arrays::TheoryArraysRewriter::normalizeConstant(
+            val, indexCard);
       }
-      out << ")" << endl;
     }
-/*
-    //for table format (work in progress)
-    bool printedModel = false;
-    if( tn.isFunction() ){
-      if( options::modelFormatMode()==MODEL_FORMAT_MODE_TABLE ){
-        //specialized table format for functions
-        RepSetIterator riter( &d_rep_set );
-        riter.setFunctionDomain( n );
-        while( !riter.isFinished() ){
-          std::vector< Node > children;
-          children.push_back( n );
-          for( int i=0; i<riter.getNumTerms(); i++ ){
-            children.push_back( riter.getTerm( i ) );
-          }
-          Node nn = NodeManager::currentNM()->mkNode( APPLY_UF, children );
-          Node val = getValue( nn );
-          out << val << " ";
-          riter.increment();
-        }
-        printedModel = true;
-      }
+    out << "(define-fun " << n << " () " << n.getType() << " ";
+    if (val.getType().isInteger() && n.getType().isReal()
+        && !n.getType().isInteger())
+    {
+      // toStreamReal(out, val, true);
+      toStreamRational(out, val.getConst<Rational>(), true);
+      // out << val << ".0";
     }
-*/
-  } else {
-    DatatypeDeclarationCommand* c1 = (DatatypeDeclarationCommand*)c;
-    const vector<DatatypeType>& datatypes = c1->getDatatypes();
-    if (!datatypes[0].isTuple()) {
-      out << c << endl;
+    else
+    {
+      out << val;
+    }
+    out << ")" << endl;
+  }
+}
+
+}  // namespace
+
+void Smt2Printer::toStream(std::ostream& out,
+                           const Model& model,
+                           const Command* command) const
+{
+  const theory::TheoryModel* theory_model =
+      dynamic_cast<const theory::TheoryModel*>(&model);
+  AlwaysAssert(theory_model != nullptr);
+  if (const DeclareTypeCommand* dtc =
+          dynamic_cast<const DeclareTypeCommand*>(command))
+  {
+    DeclareTypeCommandToStream(out, *theory_model, *dtc, d_variant);
+  }
+  else if (const DeclareFunctionCommand* dfc =
+               dynamic_cast<const DeclareFunctionCommand*>(command))
+  {
+    DeclareFunctionCommandToStream(out, *theory_model, *dfc);
+  }
+  else if (const DatatypeDeclarationCommand* datatype_declaration_command =
+               dynamic_cast<const DatatypeDeclarationCommand*>(command))
+  {
+    if (!datatype_declaration_command->getDatatypes()[0].isTuple())
+    {
+      out << command << endl;
     }
   }
+  else
+  {
+    Unreachable();
+  }
 }
 
 void Smt2Printer::toStreamSygus(std::ostream& out, TNode n) const
index bff5e36cdd82d7612830a4a23189522b0a37aa88..04c39c8976b17b89900157982fa9f4e2bf005ac8 100644 (file)
@@ -43,12 +43,8 @@ bool RepSet::hasRep(TypeNode tn, Node n) const
 
 unsigned RepSet::getNumRepresentatives(TypeNode tn) const
 {
-  std::map< TypeNode, std::vector< Node > >::const_iterator it = d_type_reps.find( tn );
-  if( it!=d_type_reps.end() ){
-    return it->second.size();
-  }else{
-    return 0;
-  }
+  const std::vector<Node>* reps = getTypeRepsOrNull(tn);
+  return (reps != nullptr) ? reps->size() : 0;
 }
 
 Node RepSet::getRepresentative(TypeNode tn, unsigned i) const
@@ -60,14 +56,18 @@ Node RepSet::getRepresentative(TypeNode tn, unsigned i) const
   return it->second[i];
 }
 
-void RepSet::getRepresentatives(TypeNode tn, std::vector<Node>& reps) const
+const std::vector<Node>* RepSet::getTypeRepsOrNull(TypeNode tn) const
 {
-  std::map<TypeNode, std::vector<Node> >::const_iterator it =
-      d_type_reps.find(tn);
-  Assert(it != d_type_reps.end());
-  reps.insert(reps.end(), it->second.begin(), it->second.end());
+  auto it = d_type_reps.find(tn);
+  if (it == d_type_reps.end())
+  {
+    return nullptr;
+  }
+  return &(it->second);
 }
 
+namespace {
+
 bool containsStoreAll(Node n, std::unordered_set<Node, NodeHashFunction>& cache)
 {
   if( std::find( cache.begin(), cache.end(), n )==cache.end() ){
@@ -85,6 +85,8 @@ bool containsStoreAll(Node n, std::unordered_set<Node, NodeHashFunction>& cache)
   return false;
 }
 
+}  // namespace
+
 void RepSet::add( TypeNode tn, Node n ){
   //for now, do not add array constants FIXME
   if( tn.isArray() ){
@@ -264,7 +266,12 @@ bool RepSetIterator::initialize()
       if (d_rs->hasType(tn))
       {
         d_enum_type.push_back( ENUM_DEFAULT );
-        d_rs->getRepresentatives(tn, d_domain_elements[v]);
+        if (const auto* type_reps = d_rs->getTypeRepsOrNull(tn))
+        {
+          std::vector<Node>& v_domain_elements = d_domain_elements[v];
+          v_domain_elements.insert(v_domain_elements.end(),
+                                   type_reps->begin(), type_reps->end());
+        }
       }else{
         Assert( d_incomplete );
         return false;
index 5b75fa943ba99ecf2ecb17c7e22a4dc00f93ba73..a75918b5a204ac2b5ea184c1603bb36185e9b347 100644 (file)
@@ -57,9 +57,9 @@ class QuantifiersEngine;
  * finite types.
  */
 class RepSet {
-public:
+ public:
   RepSet(){}
-  ~RepSet(){}
+
   /** map from types to the list of representatives
    * TODO : as part of #1199, encapsulate this
    */
@@ -67,15 +67,19 @@ public:
   /** clear the set */
   void clear();
   /** does this set have representatives of type tn? */
-  bool hasType( TypeNode tn ) const { return d_type_reps.find( tn )!=d_type_reps.end(); }
+  bool hasType(TypeNode tn) const { return d_type_reps.count(tn) > 0; }
   /** does this set have representative n of type tn? */
   bool hasRep(TypeNode tn, Node n) const;
   /** get the number of representatives for type */
   unsigned getNumRepresentatives(TypeNode tn) const;
   /** get representative at index */
   Node getRepresentative(TypeNode tn, unsigned i) const;
-  /** get representatives of type tn, appends them to reps */
-  void getRepresentatives(TypeNode tn, std::vector<Node>& reps) const;
+  /**
+   * Returns the representatives of a type for a `type_node` if one exists.
+   * Otherwise, returns nullptr.
+   */
+  const std::vector<Node>* getTypeRepsOrNull(TypeNode type_node) const;
+
   /** add representative n for type tn, where n has type tn */
   void add( TypeNode tn, Node n );
   /** returns index in d_type_reps for node n */