SyGuS grammar refactor (#3100)
authoryoni206 <yoni206@users.noreply.github.com>
Fri, 19 Jul 2019 21:19:17 +0000 (14:19 -0700)
committerAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 19 Jul 2019 21:19:17 +0000 (17:19 -0400)
src/theory/quantifiers/sygus/cegis_unif.cpp
src/theory/quantifiers/sygus/sygus_grammar_cons.cpp
src/theory/quantifiers/sygus/sygus_grammar_cons.h

index e34669425430693616e0bcb0237d420f52ed52f9..b72e5048458f9e762bca74334dcfe662a8119532 100644 (file)
@@ -455,16 +455,20 @@ Node CegisUnifEnumDecisionStrategy::mkLiteral(unsigned n)
       std::stringstream ss;
       ss << "_virtual_enum_grammar";
       std::string virtualEnumName(ss.str());
-      std::map<TypeNode, std::vector<Node>> extra_cons;
-      std::map<TypeNode, std::vector<Node>> exclude_cons;
+      std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
+      std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>
+          exclude_cons;
+      std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>
+          include_cons;
       // do not include "-", which is included by default for integers
-      exclude_cons[intTn].push_back(nm->operatorOf(MINUS));
+      exclude_cons[intTn].insert(nm->operatorOf(MINUS));
       std::unordered_set<Node, NodeHashFunction> term_irrelevant;
       TypeNode vtn = CegGrammarConstructor::mkSygusDefaultType(intTn,
                                                                bvl,
                                                                virtualEnumName,
                                                                extra_cons,
                                                                exclude_cons,
+                                                               include_cons,
                                                                term_irrelevant);
       d_virtual_enum = nm->mkSkolem("_ve", vtn);
       d_tds->registerEnumerator(
index 263c88d158dd9c678e31b7aa340e4806662d3d3d..bcd0d709ef6eb0229103ab03a4deefca1457c829 100644 (file)
@@ -57,7 +57,10 @@ bool CegGrammarConstructor::hasSyntaxRestrictions(Node q)
   return false;
 }
 
-void CegGrammarConstructor::collectTerms( Node n, std::map< TypeNode, std::vector< Node > >& consts ){
+void CegGrammarConstructor::collectTerms(
+    Node n,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& consts)
+{
   std::unordered_map<TNode, bool, TNodeHashFunction> visited;
   std::unordered_map<TNode, bool, TNodeHashFunction>::iterator it;
   std::stack<TNode> visit;
@@ -78,7 +81,7 @@ void CegGrammarConstructor::collectTerms( Node n, std::map< TypeNode, std::vecto
         }
         if( std::find( consts[tn].begin(), consts[tn].end(), c )==consts[tn].end() ){
           Trace("cegqi-debug") << "...consider const : " << c << std::endl;
-          consts[tn].push_back( c );
+          consts[tn].insert(c);
         }
       }
       // recurse
@@ -97,12 +100,13 @@ Node CegGrammarConstructor::process(Node q,
   // now, construct the grammar
   Trace("cegqi") << "SynthConjecture : convert to deep embedding..."
                  << std::endl;
-  std::map< TypeNode, std::vector< Node > > extra_cons;
+  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
   if( options::sygusAddConstGrammar() ){
     Trace("cegqi") << "SynthConjecture : collect constants..." << std::endl;
     collectTerms( q[1], extra_cons );
   }
-  std::map<TypeNode, std::vector<Node>> exc_cons;
+  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exc_cons;
+  std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> inc_cons;
 
   NodeManager* nm = NodeManager::currentNM();
 
@@ -150,8 +154,13 @@ Node CegGrammarConstructor::process(Node q,
       }
 
       // make the default grammar
-      tn = mkSygusDefaultType(
-          preGrammarType, sfvl, ss.str(), extra_cons, exc_cons, term_irlv);
+      tn = mkSygusDefaultType(preGrammarType,
+                              sfvl,
+                              ss.str(),
+                              extra_cons,
+                              exc_cons,
+                              inc_cons,
+                              term_irlv);
     }
     // sfvl may be null for constant synthesis functions
     Trace("cegqi-debug") << "...sygus var list associated with " << sf << " is "
@@ -419,8 +428,10 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
     TypeNode range,
     Node bvl,
     const std::string& fun,
-    std::map<TypeNode, std::vector<Node>>& extra_cons,
-    std::map<TypeNode, std::vector<Node>>& exc_cons,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& exc_cons,
+    const std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+        inc_cons,
     std::unordered_set<Node, NodeHashFunction>& term_irrelevant,
     std::vector<CVC4::Datatype>& datatypes,
     std::set<Type>& unres)
@@ -516,13 +527,17 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
     //add constants
     std::vector< Node > consts;
     mkSygusConstantsForType( types[i], consts );
-    std::map< TypeNode, std::vector< Node > >::iterator itec = extra_cons.find( types[i] );
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>::iterator
+        itec = extra_cons.find(types[i]);
     if( itec!=extra_cons.end() ){
-      //consts.insert( consts.end(), itec->second.begin(), itec->second.end() );
-      for (unsigned j = 0, size_j = itec->second.size(); j < size_j; ++j)
+      for (std::unordered_set<Node, NodeHashFunction>::iterator set_it =
+               itec->second.begin();
+           set_it != itec->second.end();
+           set_it++)
       {
-        if( std::find( consts.begin(), consts.end(), itec->second[j] )==consts.end() ){
-          consts.push_back( itec->second[j] );
+        if (std::find(consts.begin(), consts.end(), *set_it) == consts.end())
+        {
+          consts.push_back(*set_it);
         }
       }
     }
@@ -551,9 +566,10 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
 
     if (types[i].isReal())
     {
-      for (unsigned j = 0; j < 2; j++)
+      // Add PLUS, MINUS
+      Kind kinds[2] = {PLUS, MINUS};
+      for (const Kind k : kinds)
       {
-        Kind k = j == 0 ? PLUS : MINUS;
         Trace("sygus-grammar-def") << "...add for " << k << std::endl;
         ops[i].push_back(nm->operatorOf(k).toExpr());
         cnames[i].push_back(kindToString(k));
@@ -760,18 +776,35 @@ void CegGrammarConstructor::mkSygusDefaultGrammar(
   {
     Trace("sygus-grammar-def") << "...make datatype " << datatypes[i] << std::endl;
     datatypes[i].setSygus( types[i].toType(), bvl.toExpr(), true, true );
-    std::map<TypeNode, std::vector<Node>>::iterator itexc =
-        exc_cons.find(types[i]);
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>::iterator
+        itexc = exc_cons.find(types[i]);
+    std::map<TypeNode,
+             std::unordered_set<Node, NodeHashFunction>>::const_iterator itinc =
+        inc_cons.find(types[i]);
     for (unsigned j = 0, size = ops[i].size(); j < size; ++j)
     {
-      // add the constructor if it is not excluded
+      // add the constructor if it is not excluded,
+      // and it is in inc_cons, in case it is not empty
       Node opn = Node::fromExpr(ops[i][j]);
+      Trace("sygus-grammar-def")
+          << "...considering " << opn.toString() << " of kind " << opn.getKind()
+          << " and of type " << opn.getType() << " and of kind of type "
+          << opn.getType().getKind() << " of metakind " << opn.getMetaKind()
+          << std::endl;
       if (itexc == exc_cons.end()
           || std::find(itexc->second.begin(), itexc->second.end(), opn)
                  == itexc->second.end())
       {
-        datatypes[i].addSygusConstructor(
-            ops[i][j], cnames[i][j], cargs[i][j], pcs[i][j], weights[i][j]);
+        Trace("sygus-grammar-def") << "......not excluded " << std::endl;
+        if ((opn.isVar()) || (opn.getType().getKind() != kind::TYPE_CONSTANT)
+            || (itinc == inc_cons.end())
+            || (std::find(itinc->second.begin(), itinc->second.end(), opn)
+                != itinc->second.end()))
+        {
+          Trace("sygus-grammar-def") << "......included " << std::endl;
+          datatypes[i].addSygusConstructor(
+              ops[i][j], cnames[i][j], cargs[i][j], pcs[i][j], weights[i][j]);
+        }
       }
     }
     Trace("sygus-grammar-def") << "...built datatype " << datatypes[i] << " ";
@@ -936,12 +969,19 @@ TypeNode CegGrammarConstructor::mkSygusDefaultType(
     TypeNode range,
     Node bvl,
     const std::string& fun,
-    std::map<TypeNode, std::vector<Node>>& extra_cons,
-    std::map<TypeNode, std::vector<Node>>& exclude_cons,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+        exclude_cons,
+    std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+        include_cons,
     std::unordered_set<Node, NodeHashFunction>& term_irrelevant)
 {
   Trace("sygus-grammar-def") << "*** Make sygus default type " << range << ", make datatypes..." << std::endl;
-  for( std::map< TypeNode, std::vector< Node > >::iterator it = extra_cons.begin(); it != extra_cons.end(); ++it ){
+  for (std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>::iterator
+           it = extra_cons.begin();
+       it != extra_cons.end();
+       ++it)
+  {
     Trace("sygus-grammar-def") << "    ...using " << it->second.size() << " extra constants for " << it->first << std::endl;
   }
   std::set<Type> unres;
@@ -951,6 +991,7 @@ TypeNode CegGrammarConstructor::mkSygusDefaultType(
                         fun,
                         extra_cons,
                         exclude_cons,
+                        include_cons,
                         term_irrelevant,
                         datatypes,
                         unres);
index 7dfa9b4787c82697043bac324a9ce404d22bc57c..c01731d1bd159a00f06d84c4e7d392635b268357 100644 (file)
@@ -63,30 +63,44 @@ public:
  /** is the syntax restricted? */
  bool isSyntaxRestricted() { return d_is_syntax_restricted; }
  /** make the default sygus datatype type corresponding to builtin type range
- *   bvl is the set of free variables to include in the grammar
- *   fun is for naming
- *   extra_cons is a set of extra constant symbols to include in the grammar,
- *   exclude_cons is used to exclude operators from the grammar,
- *   term_irrelevant is a set of terms that should not be included in the
- *      grammar.
- */
+  * arguments:
+  *   - bvl: the set of free variables to include in the grammar
+  *   - fun: used for naming
+  *   - extra_cons: a set of extra constant symbols to include in the grammar,
+  *     regardless of their inclusion in the default grammar.
+  *   - exclude_cons: used to exclude operators from the grammar,
+  *   - term_irrelevant: a set of terms that should not be included in the
+  *      grammar.
+  *   - include_cons: a set of operators such that if this set is not empty,
+  *     its elements that are in the default grammar (and only them) 
+  *     will be included.
+  */
  static TypeNode mkSygusDefaultType(
      TypeNode range,
      Node bvl,
      const std::string& fun,
-     std::map<TypeNode, std::vector<Node> >& extra_cons,
-     std::map<TypeNode, std::vector<Node> >& exclude_cons,
+     std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& extra_cons,
+     std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+         exclude_cons,
+     std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+         include_cons,
      std::unordered_set<Node, NodeHashFunction>& term_irrelevant);
  /** make the default sygus datatype type corresponding to builtin type range */
  static TypeNode mkSygusDefaultType(TypeNode range,
                                     Node bvl,
                                     const std::string& fun)
  {
-   std::map<TypeNode, std::vector<Node> > extra_cons;
-   std::map<TypeNode, std::vector<Node> > exclude_cons;
+   std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> extra_cons;
+   std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> exclude_cons;
+   std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>> include_cons;
    std::unordered_set<Node, NodeHashFunction> term_irrelevant;
-   return mkSygusDefaultType(
-       range, bvl, fun, extra_cons, exclude_cons, term_irrelevant);
+   return mkSygusDefaultType(range,
+                             bvl,
+                             fun,
+                             extra_cons,
+                             exclude_cons,
+                             include_cons,
+                             term_irrelevant);
   }
   /** make the sygus datatype type that encodes the solution space (lambda
   * templ_arg. templ[templ_arg]) where templ_arg
@@ -139,7 +153,9 @@ public:
   /** is the syntax restricted? */
   bool d_is_syntax_restricted;
   /** collect terms */
-  void collectTerms( Node n, std::map< TypeNode, std::vector< Node > >& consts );
+  void collectTerms(
+      Node n,
+      std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>& consts);
   //---------------- grammar construction
   // helper for mkSygusDefaultGrammar (makes unresolved type for mutually recursive datatype construction)
   static TypeNode mkUnresolvedType(const std::string& name, std::set<Type>& unres);
@@ -155,11 +171,16 @@ public:
       TypeNode range,
       Node bvl,
       const std::string& fun,
-      std::map<TypeNode, std::vector<Node> >& extra_cons,
-      std::map<TypeNode, std::vector<Node> >& exclude_cons,
+      std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+          extra_cons,
+      std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+          exclude_cons,
+      const std::map<TypeNode, std::unordered_set<Node, NodeHashFunction>>&
+          include_cons,
       std::unordered_set<Node, NodeHashFunction>& term_irrelevant,
       std::vector<CVC4::Datatype>& datatypes,
       std::set<Type>& unres);
+
   // helper function for mkSygusTemplateType
   static TypeNode mkSygusTemplateTypeRec( Node templ, Node templ_arg, TypeNode templ_arg_sygus_type, Node bvl, 
                                           const std::string& fun, unsigned& tcount );