Major simplifications to macros module.
authorajreynol <reynolds@larapc05.epfl.ch>
Thu, 8 May 2014 16:18:53 +0000 (18:18 +0200)
committerajreynol <reynolds@larapc05.epfl.ch>
Thu, 8 May 2014 16:18:53 +0000 (18:18 +0200)
src/theory/quantifiers/macros.cpp
src/theory/quantifiers/macros.h

index 72d42cf4bc01c46b0638cef5e4882c2793f6dd8d..24fb53d7f596d613142b7034ea2b7f85db35e95c 100644 (file)
@@ -26,48 +26,12 @@ using namespace CVC4::theory::quantifiers;
 using namespace CVC4::kind;
 using namespace CVC4::context;
 
+
 bool QuantifierMacros::simplify( std::vector< Node >& assertions, bool doRewrite ){
+  Trace("macros") << "Find macros..." << std::endl;
   //first, collect macro definitions
   for( size_t i=0; i<assertions.size(); i++ ){
-    if( assertions[i].getKind()==FORALL ){
-      std::vector< Node > args;
-      for( size_t j=0; j<assertions[i][0].getNumChildren(); j++ ){
-        args.push_back( assertions[i][0][j] );
-      }
-      //look at the body of the quantifier for macro definition
-      process( assertions[i][1], true, args, assertions[i] );
-    }
-  }
-  //create macro defs
-  for( std::map< Node, std::vector< std::pair< Node, Node > > >::iterator it = d_macro_def_cases.begin();
-       it != d_macro_def_cases.end(); ++it ){
-    //create ite based on case definitions
-    Node val;
-    for( size_t i=0; i<it->second.size(); ++i ){
-      if( it->second[i].first.isNull() ){
-        Assert( i==0 );
-        val = it->second[i].second;
-      }else{
-        //if value is null, must generate it
-        if( val.isNull() ){
-          std::stringstream ss;
-          ss << "mdo_" << it->first << "";
-          Node op = NodeManager::currentNM()->mkSkolem( ss.str(), it->first.getType(), "op created during macro definitions" );
-          //will be defined in terms of fresh operator
-          std::vector< Node > children;
-          children.push_back( op );
-          children.insert( children.end(), d_macro_basis[ it->first ].begin(), d_macro_basis[ it->first ].end() );
-          val = NodeManager::currentNM()->mkNode( APPLY_UF, children );
-        }
-        val = NodeManager::currentNM()->mkNode( ITE, it->second[i].first, it->second[i].second, val );
-      }
-    }
-    d_macro_defs[ it->first ] = val;
-    Trace("macros-def") << "* " << val << " is a macro for " << it->first << std::endl;
-  }
-  //now simplify bodies
-  for( std::map< Node, Node >::iterator it = d_macro_defs.begin(); it != d_macro_defs.end(); ++it ){
-    d_macro_defs[ it->first ] = Rewriter::rewrite( simplify( it->second ) );
+    processAssertion( assertions[i] );
   }
   bool retVal = false;
   if( doRewrite && !d_macro_defs.empty() ){
@@ -85,6 +49,21 @@ bool QuantifierMacros::simplify( std::vector< Node >& assertions, bool doRewrite
   return retVal;
 }
 
+void QuantifierMacros::processAssertion( Node n ) {
+  if( n.getKind()==AND ){
+    for( unsigned i=0; i<n.getNumChildren(); i++ ){
+      processAssertion( n[i] );
+    }
+  }else if( n.getKind()==FORALL ){
+    std::vector< Node > args;
+    for( size_t j=0; j<n[0].getNumChildren(); j++ ){
+      args.push_back( n[0][j] );
+    }
+    //look at the body of the quantifier for macro definition
+    process( n[1], true, args, n );
+  }
+}
+
 bool QuantifierMacros::contains( Node n, Node n_s ){
   if( n==n_s ){
     return true;
@@ -105,7 +84,7 @@ bool QuantifierMacros::containsBadOp( Node n, Node n_op ){
       if( op==n_op.getOperator() ){
         return true;
       }
-      if( d_macro_def_cases.find( op )!=d_macro_def_cases.end() && !d_macro_def_cases[op].empty() ){
+      if( d_macro_defs.find( op )!=d_macro_defs.end() ){
         return true;
       }
     }
@@ -122,16 +101,24 @@ bool QuantifierMacros::isMacroLiteral( Node n, bool pol ){
   return pol && ( n.getKind()==EQUAL || n.getKind()==IFF );
 }
 
-void QuantifierMacros::getMacroCandidates( Node n, std::vector< Node >& candidates ){
-  if( n.getKind()==APPLY_UF ){
-    bool allBoundVar = true;
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      if( n[i].getKind()!=BOUND_VARIABLE ){
-        allBoundVar = false;
-        break;
+bool QuantifierMacros::isBoundVarApplyUf( Node n ) {
+  Assert( n.getKind()==APPLY_UF );
+  for( unsigned i=0; i<n.getNumChildren(); i++ ){
+    if( n[i].getKind()!=BOUND_VARIABLE ){
+      return false;
+    }
+    for( unsigned j=0; j<i; j++ ){
+      if( n[j]==n[i] ){
+        return false;
       }
     }
-    if( allBoundVar ){
+  }
+  return true;
+}
+
+void QuantifierMacros::getMacroCandidates( Node n, std::vector< Node >& candidates ){
+  if( n.getKind()==APPLY_UF ){
+    if( isBoundVarApplyUf( n ) ){
       candidates.push_back( n );
     }
   }else if( n.getKind()==PLUS ){
@@ -196,14 +183,6 @@ Node QuantifierMacros::solveInEquality( Node n, Node lit ){
   return Node::null();
 }
 
-bool QuantifierMacros::isConsistentDefinition( Node op, Node cond, Node def ){
-  if( d_macro_def_cases[op].empty() || ( cond.isNull() && !d_macro_def_cases[op][0].first.isNull() ) ){
-    return true;
-  }else{
-    return false;
-  }
-}
-
 bool QuantifierMacros::getFreeVariables( Node n, std::vector< Node >& v_quant, std::vector< Node >& vars, bool retOnly ){
   if( std::find( v_quant.begin(), v_quant.end(), n )!=v_quant.end() ){
     if( std::find( vars.begin(), vars.end(), n )==vars.end() ){
@@ -247,6 +226,13 @@ void QuantifierMacros::process( Node n, bool pol, std::vector< Node >& args, Nod
     //conditional?
   }else if( n.getKind()==ITE ){
     //can not do anything
+  }else if( n.getKind()==APPLY_UF ){
+    //predicate case
+    if( isBoundVarApplyUf( n ) ){
+      Node n_def = NodeManager::currentNM()->mkConst( pol );
+      Trace("macros") << "* " << n_def << " is a macro for " << n.getOperator() << std::endl;
+      d_macro_defs[ n.getOperator() ] = n_def;
+    }
   }else{
     //literal case
     if( isMacroLiteral( n, pol ) ){
@@ -257,7 +243,7 @@ void QuantifierMacros::process( Node n, bool pol, std::vector< Node >& args, Nod
       for( size_t i=0; i<candidates.size(); i++ ){
         Node m = candidates[i];
         Node op = m.getOperator();
-        if( !containsBadOp( n, m ) ){
+        if( d_macro_defs.find( op )==d_macro_defs.end() && !containsBadOp( n, m ) ){
           std::vector< Node > fvs;
           getFreeVariables( m, args, fvs, false );
           //get definition and condition
@@ -267,7 +253,7 @@ void QuantifierMacros::process( Node n, bool pol, std::vector< Node >& args, Nod
             Node n_cond;  //condition when this definition holds
             //conditional must not contain any free variables apart from fvs
             if( n_cond.isNull() || !getFreeVariables( n_cond, args, fvs, true ) ){
-              Trace("macros") << m << " is possible macro in " << f << std::endl;
+              Trace("macros-debug") << m << " is possible macro in " << f << std::endl;
               //now we must rewrite candidates[i] to a term of form g( x1, ..., xn ) where
               // x1 ... xn are distinct variables
               if( d_macro_basis[op].empty() ){
@@ -278,74 +264,17 @@ void QuantifierMacros::process( Node n, bool pol, std::vector< Node >& args, Nod
                   d_macro_basis[op].push_back( v );
                 }
               }
-              std::vector< Node > eq;
-              for( size_t a=0; a<m.getNumChildren(); a++ ){
-                eq.push_back( m[a] );
-              }
-              //solve system of equations "d_macro_basis[op] = m" for variables in fvs
               std::map< Node, Node > solved;
-              //solve obvious cases first
-              for( size_t a=0; a<eq.size(); a++ ){
-                if( std::find( fvs.begin(), fvs.end(), eq[a] )!=fvs.end() ){
-                  if( solved[ eq[a] ].isNull() ){
-                    solved[ eq[a] ] = d_macro_basis[op][a];
-                  }
-                }
+              for( size_t a=0; a<m.getNumChildren(); a++ ){
+                solved[m[a]] = d_macro_basis[op][a];
               }
-              //now, apply substitution for obvious cases
               std::vector< Node > vars;
               std::vector< Node > subs;
-              getSubstitution( fvs, solved, vars, subs, false );
-              for( size_t a=0; a<eq.size(); a++ ){
-                eq[a] = eq[a].substitute( vars.begin(), vars.end(), subs.begin(), subs.end() );
-              }
-
-              Trace("macros-eq") << "Solve system of equations : " << std::endl;
-              for( size_t a=0; a<m.getNumChildren(); a++ ){
-                if( d_macro_basis[op][a]!=eq[a] ){
-                  Trace("macros-eq") << "   " << d_macro_basis[op][a] << " = " << eq[a] << std::endl;
-                }
-              }
-              Trace("macros-eq") << " for ";
-              for( size_t a=0; a<fvs.size(); a++ ){
-                if( solved[ fvs[a] ].isNull() ){
-                  Trace("macros-eq") << fvs[a] << " ";
-                }
-              }
-              Trace("macros-eq") << std::endl;
-              //DO_THIS
-
-
-              vars.clear();
-              subs.clear();
               if( getSubstitution( fvs, solved, vars, subs, true ) ){
-                //build condition
-                std::vector< Node > conds;
-                if( !n_cond.isNull() ){
-                  //must apply substitution obtained from solving system of equations to original condition
-                  n_cond = n_cond.substitute( vars.begin(), vars.end(), subs.begin(), subs.end() );
-                  conds.push_back( n_cond );
-                }
-                for( size_t a=0; a<eq.size(); a++ ){
-                  //collect conditions based on solving argument's system of equations
-                  if( d_macro_basis[op][a]!=eq[a] ){
-                    conds.push_back( NodeManager::currentNM()->mkNode( eq[a].getType().isBoolean() ? IFF : EQUAL, d_macro_basis[op][a], eq[a] ) );
-                  }
-                }
-                //build the condition
-                if( !conds.empty() ){
-                  n_cond = conds.size()==1 ? conds[0] : NodeManager::currentNM()->mkNode( AND, conds );
-                }
-                //apply the substitution to the
                 n_def = n_def.substitute( vars.begin(), vars.end(), subs.begin(), subs.end() );
-                //now see if definition is consistent with others
-                if( isConsistentDefinition( op, n_cond, n_def ) ){
-                  //must clear if it is a base definition
-                  if( n_cond.isNull() ){
-                    d_macro_def_cases[ op ].clear();
-                  }
-                  d_macro_def_cases[ op ].push_back( std::pair< Node, Node >( n_cond, n_def ) );
-                }
+                Trace("macros") << "* " << n_def << " is a macro for " << op << std::endl;
+                d_macro_defs[op] = n_def;
+                return;
               }
             }
           }
@@ -367,9 +296,12 @@ Node QuantifierMacros::simplify( Node n ){
   if( n.getKind()==APPLY_UF ){
     Node op = n.getOperator();
     if( d_macro_defs.find( op )!=d_macro_defs.end() && !d_macro_defs[op].isNull() ){
-      //do substitution
+      //do substitution if necessary
+      std::map< Node, std::vector< Node > >::iterator it = d_macro_basis.find( op );
       Node ret = d_macro_defs[op];
-      ret = ret.substitute( d_macro_basis[op].begin(), d_macro_basis[op].end(), children.begin(), children.end() );
+      if( it!=d_macro_basis.end() ){
+        ret = ret.substitute( it->second.begin(), it->second.end(), children.begin(), children.end() );
+      }
       return ret;
     }
   }
index ad5cd2e55c4499f922bb64a462256a0fdba08020..682e47930f8f7639f5918c989cca9418503fb434 100644 (file)
@@ -30,20 +30,19 @@ namespace quantifiers {
 
 class QuantifierMacros{
 private:
+  void processAssertion( Node n );
+  bool isBoundVarApplyUf( Node n );
   void process( Node n, bool pol, std::vector< Node >& args, Node f );
   bool contains( Node n, Node n_s );
   bool containsBadOp( Node n, Node n_op );
   bool isMacroLiteral( Node n, bool pol );
   void getMacroCandidates( Node n, std::vector< Node >& candidates );
   Node solveInEquality( Node n, Node lit );
-  bool isConsistentDefinition( Node op, Node cond, Node def );
   bool getFreeVariables( Node n, std::vector< Node >& v_quant, std::vector< Node >& vars, bool retOnly );
   bool getSubstitution( std::vector< Node >& v_quant, std::map< Node, Node >& solved,
                         std::vector< Node >& vars, std::vector< Node >& subs, bool reqComplete );
   //map from operators to macro basis terms
   std::map< Node, std::vector< Node > > d_macro_basis;
-  //map from operators to map from conditions to definition cases
-  std::map< Node, std::vector< std::pair< Node, Node > > > d_macro_def_cases;
   //map from operators to macro definition
   std::map< Node, Node > d_macro_defs;
 private: