Replace conditional rewrite pass in quantifiers with the extended rewriter (#3841)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Sat, 29 Feb 2020 03:43:49 +0000 (21:43 -0600)
committerGitHub <noreply@github.com>
Sat, 29 Feb 2020 03:43:49 +0000 (21:43 -0600)
Fixes #3839.

Previously, the quantifiers rewriter had a rewriting step that was an ad-hoc version of some of the rewrites that have been incorporated into the extended rewriter. Moreover, the code for that pass was buggy.

This eliminates the previous conditional rewriting step from the "term process" rewrite pass in quantifiers. It additional adds an optional (disabled by default) rewriting pass that calls the extended rewriter on the body of quantified formulas. This subsumes the previous behavior and should not be buggy.

Notice that the indentation in computeProcessTerms changed and subsequently has been updated to the new coding standards.

This PR relies on #3840.

src/options/quantifiers_options.toml
src/theory/quantifiers/quantifiers_rewriter.cpp
src/theory/quantifiers/quantifiers_rewriter.h
test/regress/CMakeLists.txt
test/regress/regress0/quantifiers/agg-rew-test-cf.smt2
test/regress/regress0/quantifiers/agg-rew-test.smt2
test/regress/regress1/sygus/issue3839-cond-rewrite.smt2 [new file with mode: 0644]

index cb989b4333cf024217a6ab0318bfda9ee7671150..1101f70c5d4bbd600a834d4f412d43a7d41e2638 100644 (file)
@@ -189,13 +189,13 @@ header = "options/quantifiers_options.h"
   help       = "eliminate extended arithmetic symbols in quantified formulas"
 
 [[option]]
-  name       = "condRewriteQuant"
+  name       = "extRewriteQuant"
   category   = "regular"
-  long       = "cond-rewrite-quant"
+  long       = "ext-rewrite-quant"
   type       = "bool"
-  default    = "true"
+  default    = "false"
   read_only  = true
-  help       = "conditional rewriting of quantified formulas"
+  help       = "apply extended rewriting to bodies of quantified formulas"
 
 [[option]]
   name       = "globalNegate"
index ee2461c238ff688ddc905a87b28455a129e2895a..aed2ae429d879f40f2bca90406fa3f56e3159226 100644 (file)
@@ -21,6 +21,7 @@
 #include "theory/datatypes/theory_datatypes_utils.h"
 #include "theory/quantifiers/bv_inverter.h"
 #include "theory/quantifiers/ematching/trigger.h"
+#include "theory/quantifiers/extended_rewrite.h"
 #include "theory/quantifiers/quantifiers_attributes.h"
 #include "theory/quantifiers/skolemize.h"
 #include "theory/quantifiers/term_database.h"
@@ -44,6 +45,7 @@ std::ostream& operator<<(std::ostream& out, RewriteStep s)
     case COMPUTE_AGGRESSIVE_MINISCOPING:
       out << "COMPUTE_AGGRESSIVE_MINISCOPING";
       break;
+    case COMPUTE_EXT_REWRITE: out << "COMPUTE_EXT_REWRITE"; break;
     case COMPUTE_PROCESS_TERMS: out << "COMPUTE_PROCESS_TERMS"; break;
     case COMPUTE_PRENEX: out << "COMPUTE_PRENEX"; break;
     case COMPUTE_VAR_ELIMINATION: out << "COMPUTE_VAR_ELIMINATION"; break;
@@ -389,142 +391,8 @@ void QuantifiersRewriter::computeDtTesterIteSplit( Node n, std::map< Node, Node
   }
 }
 
-int getEntailedCond( Node n, std::map< Node, bool >& currCond ){
-  std::map< Node, bool >::iterator it = currCond.find( n );
-  if( it!=currCond.end() ){
-    return it->second ? 1 : -1;
-  }else if( n.getKind()==NOT ){
-    return -getEntailedCond( n[0], currCond );
-  }else if( n.getKind()==AND || n.getKind()==OR ){
-    bool hasZero = false;
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      int res = getEntailedCond( n[i], currCond );
-      if( res==0 ){
-        hasZero = true;
-      }else if( n.getKind()==AND && res==-1 ){
-        return -1;
-      }else if( n.getKind()==OR && res==1 ){
-        return 1;
-      }
-    }
-    return hasZero ? 0 : ( n.getKind()==AND ? 1 : -1 );
-  }else if( n.getKind()==ITE ){
-    int res = getEntailedCond( n[0], currCond );
-    if( res==1 ){
-      return getEntailedCond( n[1], currCond );
-    }else if( res==-1 ){
-      return getEntailedCond( n[2], currCond );
-    }
-  }else if( ( n.getKind()==EQUAL && n[0].getType().isBoolean() ) || n.getKind()==ITE ){
-    unsigned start = n.getKind()==EQUAL ? 0 : 1;
-    int res1 = 0;
-    for( unsigned j=start; j<=(start+1); j++ ){
-      int res = getEntailedCond( n[j], currCond );
-      if( res==0 ){
-        return 0;
-      }else if( j==start ){
-        res1 = res;
-      }else{
-        Assert(res != 0);
-        if( n.getKind()==ITE ){
-          return res1==res ? res : 0;
-        }else if( n.getKind()==EQUAL ){
-          return res1==res ? 1 : -1;
-        }
-      }
-    }
-  }
-  else if (n.isConst())
-  {
-    return n.getConst<bool>() ? 1 : -1;
-  }
-  return 0;
-}
-
-bool addEntailedCond( Node n, bool pol, std::map< Node, bool >& currCond, std::vector< Node >& new_cond, bool& conflict ) {
-  if (n.isConst())
-  {
-    Trace("quantifiers-rewrite-term-debug")
-        << "constant cond : " << n << " -> " << pol << std::endl;
-    if (n.getConst<bool>() != pol)
-    {
-      conflict = true;
-    }
-    return false;
-  }
-  std::map< Node, bool >::iterator it = currCond.find( n );
-  if( it==currCond.end() ){
-    Trace("quantifiers-rewrite-term-debug") << "cond : " << n << " -> " << pol << std::endl;
-    new_cond.push_back( n );
-    currCond[n] = pol;
-    return true;
-  }
-  else if (it->second != pol)
-  {
-    Trace("quantifiers-rewrite-term-debug")
-        << "CONFLICTING cond : " << n << " -> " << pol << std::endl;
-    conflict = true;
-  }
-  return false;
-}
-
-void setEntailedCond( Node n, bool pol, std::map< Node, bool >& currCond, std::vector< Node >& new_cond, bool& conflict ) {
-  if( ( n.getKind()==AND && pol ) || ( n.getKind()==OR && !pol ) ){
-    for( unsigned i=0; i<n.getNumChildren(); i++ ){
-      setEntailedCond( n[i], pol, currCond, new_cond, conflict );
-      if( conflict ){
-        break;
-      }
-    }
-  }else if( n.getKind()==NOT ){
-    setEntailedCond( n[0], !pol, currCond, new_cond, conflict );
-    return;
-  }else if( n.getKind()==ITE ){
-    int pol = getEntailedCond( n, currCond );
-    if( pol==1 ){
-      setEntailedCond( n[1], pol, currCond, new_cond, conflict );
-    }else if( pol==-1 ){
-      setEntailedCond( n[2], pol, currCond, new_cond, conflict );
-    }
-  }
-  if( addEntailedCond( n, pol, currCond, new_cond, conflict ) ){
-    if( n.getKind()==APPLY_TESTER ){
-      NodeManager* nm = NodeManager::currentNM();
-      const DType& dt = datatypes::utils::datatypeOf(n.getOperator());
-      unsigned index = datatypes::utils::indexOf(n.getOperator());
-      Assert(dt.getNumConstructors() > 1);
-      if( pol ){
-        for( unsigned i=0; i<dt.getNumConstructors(); i++ ){
-          if( i!=index ){
-            Node t = nm->mkNode(APPLY_TESTER, dt[i].getTester(), n[0]);
-            addEntailedCond( t, false, currCond, new_cond, conflict );
-          }
-        }
-      }else{
-        if( dt.getNumConstructors()==2 ){
-          int oindex = 1-index;
-          Node t = nm->mkNode(APPLY_TESTER, dt[oindex].getTester(), n[0]);
-          addEntailedCond( t, true, currCond, new_cond, conflict );
-        }
-      }
-    }
-  }
-}
-
-void removeEntailedCond( std::map< Node, bool >& currCond, std::vector< Node >& new_cond, std::map< Node, Node >& cache ) {
-  if( !new_cond.empty() ){
-    for( unsigned j=0; j<new_cond.size(); j++ ){
-      currCond.erase( new_cond[j] );
-    }
-    new_cond.clear();
-    cache.clear();
-  }
-}
-
 Node QuantifiersRewriter::computeProcessTerms( Node body, std::vector< Node >& new_vars, std::vector< Node >& new_conds, Node q, QAttributes& qa ){
-  std::map< Node, bool > curr_cond;
   std::map< Node, Node > cache;
-  std::map< Node, Node > icache;
   if( qa.isFunDef() ){
     Node h = QuantAttributes::getFunDefHead( q );
     Assert(!h.isNull());
@@ -534,12 +402,7 @@ Node QuantifiersRewriter::computeProcessTerms( Node body, std::vector< Node >& n
     if (!fbody.isNull())
     {
       Node r = computeProcessTerms2(fbody,
-                                    true,
-                                    true,
-                                    curr_cond,
-                                    0,
                                     cache,
-                                    icache,
                                     new_vars,
                                     new_conds,
                                     false);
@@ -551,241 +414,205 @@ Node QuantifiersRewriter::computeProcessTerms( Node body, std::vector< Node >& n
     // forall xy. false.
   }
   return computeProcessTerms2(body,
-                              true,
-                              true,
-                              curr_cond,
-                              0,
                               cache,
-                              icache,
                               new_vars,
                               new_conds,
                               options::elimExtArithQuant());
 }
 
-Node QuantifiersRewriter::computeProcessTerms2( Node body, bool hasPol, bool pol, std::map< Node, bool >& currCond, int nCurrCond,
-                                                std::map< Node, Node >& cache, std::map< Node, Node >& icache,
-                                                std::vector< Node >& new_vars, std::vector< Node >& new_conds, bool elimExtArith ) {
+Node QuantifiersRewriter::computeProcessTerms2(Node body,
+                                               std::map<Node, Node>& cache,
+                                               std::vector<Node>& new_vars,
+                                               std::vector<Node>& new_conds,
+                                               bool elimExtArith)
+{
   NodeManager* nm = NodeManager::currentNM();
-  Trace("quantifiers-rewrite-term-debug2") << "computeProcessTerms " << body << " " << hasPol << " " << pol << std::endl;
-  Node ret;
+  Trace("quantifiers-rewrite-term-debug2")
+      << "computeProcessTerms " << body << std::endl;
   std::map< Node, Node >::iterator iti = cache.find( body );
   if( iti!=cache.end() ){
-    ret = iti->second;
-    Trace("quantifiers-rewrite-term-debug2") << "Return (cached) " << ret << " for " << body << std::endl;
-  }else{
-    //only do context dependent processing up to depth 8
-    bool doCD = options::condRewriteQuant() && nCurrCond < 8;
-    bool changed = false;
-    std::vector< Node > children;
-    //set entailed conditions based on OR/AND
-    std::map< int, std::vector< Node > > new_cond_children;
-    if( doCD && ( body.getKind()==OR || body.getKind()==AND ) ){
-      nCurrCond = nCurrCond + 1;
-      bool conflict = false;
-      bool use_pol = body.getKind()==AND;
-      for( unsigned j=0; j<body.getNumChildren(); j++ ){
-        setEntailedCond( body[j], use_pol, currCond, new_cond_children[j], conflict );
-      }
-      if( conflict ){
-        Trace("quantifiers-rewrite-term-debug") << "-------conflict, return " << !use_pol << std::endl;
-        ret = NodeManager::currentNM()->mkConst( !use_pol );
-      }
+    return iti->second;
+  }
+  bool changed = false;
+  std::vector<Node> children;
+  for (size_t i = 0; i < body.getNumChildren(); i++)
+  {
+    // do the recursive call on children
+    Node nn =
+        computeProcessTerms2(body[i], cache, new_vars, new_conds, elimExtArith);
+    children.push_back(nn);
+    changed = changed || nn != body[i];
+  }
+
+  // make return value
+  Node ret;
+  if (changed)
+  {
+    if (body.getMetaKind() == kind::metakind::PARAMETERIZED)
+    {
+      children.insert(children.begin(), body.getOperator());
     }
-    if( ret.isNull() ){
-      for( size_t i=0; i<body.getNumChildren(); i++ ){
-      
-        //set/update entailed conditions
-        std::vector< Node > new_cond;
-        bool conflict = false;
-        if( doCD ){
-          if( Trace.isOn("quantifiers-rewrite-term-debug") ){
-            if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){
-              Trace("quantifiers-rewrite-term-debug") << "---rewrite " << body[i] << " under conditions:----" << std::endl;
-            }
-          }
-          if( body.getKind()==ITE && i>0 ){
-            if( i==1 ){
-              nCurrCond = nCurrCond + 1;
-            }
-            setEntailedCond( children[0], i==1, currCond, new_cond, conflict );
-            // should not conflict (entailment check failed)
-            Assert(!conflict);
-          }
-          if( body.getKind()==OR || body.getKind()==AND ){
-            bool use_pol = body.getKind()==AND;
-            //remove the current condition
-            removeEntailedCond( currCond, new_cond_children[i], cache );
-            if( i>0 ){
-              //add the previous condition
-              setEntailedCond( children[i-1], use_pol, currCond, new_cond_children[i-1], conflict );
-            }
-            if( conflict ){
-              Trace("quantifiers-rewrite-term-debug") << "-------conflict, return " << !use_pol << std::endl;
-              ret = NodeManager::currentNM()->mkConst( !use_pol );
-            }
-          }
-          if( !new_cond.empty() ){
-            cache.clear();
-          }
-          if( Trace.isOn("quantifiers-rewrite-term-debug") ){
-            if( ( body.getKind()==ITE && i>0 ) || body.getKind()==OR || body.getKind()==AND ){      
-              Trace("quantifiers-rewrite-term-debug") << "-------" << std::endl;
-            }
-          }
-        }
-        
-        //do the recursive call on children
-        if( !conflict ){
-          bool newHasPol;
-          bool newPol;
-          QuantPhaseReq::getPolarity( body, i, hasPol, pol, newHasPol, newPol );
-          Node nn = computeProcessTerms2( body[i], newHasPol, newPol, currCond, nCurrCond, cache, icache, new_vars, new_conds, elimExtArith );
-          if( body.getKind()==ITE && i==0 ){
-            int res = getEntailedCond( nn, currCond );
-            Trace("quantifiers-rewrite-term-debug") << "Condition for " << body << " is " << nn << ", entailment check=" << res << std::endl;
-            if( res==1 ){
-              ret = computeProcessTerms2( body[1], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds, elimExtArith );
-            }else if( res==-1 ){
-              ret = computeProcessTerms2( body[2], hasPol, pol, currCond, nCurrCond, cache, icache, new_vars, new_conds, elimExtArith );
+    ret = nm->mkNode(body.getKind(), children);
+  }
+  else
+  {
+    ret = body;
+  }
+
+  Trace("quantifiers-rewrite-term-debug2")
+      << "Returning " << ret << " for " << body << std::endl;
+  // do context-independent rewriting
+  if (ret.getKind() == EQUAL
+      && options::iteLiftQuant() != options::IteLiftQuantMode::NONE)
+  {
+    for (size_t i = 0; i < 2; i++)
+    {
+      if (ret[i].getKind() == ITE)
+      {
+        Node no = i == 0 ? ret[1] : ret[0];
+        if (no.getKind() != ITE)
+        {
+          bool doRewrite =
+              options::iteLiftQuant() == options::IteLiftQuantMode::ALL;
+          std::vector<Node> children;
+          children.push_back(ret[i][0]);
+          for (size_t j = 1; j <= 2; j++)
+          {
+            // check if it rewrites to a constant
+            Node nn = nm->mkNode(EQUAL, no, ret[i][j]);
+            nn = Rewriter::rewrite(nn);
+            children.push_back(nn);
+            if (nn.isConst())
+            {
+              doRewrite = true;
             }
           }
-          children.push_back( nn );
-          changed = changed || nn!=body[i];
-        }
-        
-        //clean up entailed conditions
-        removeEntailedCond( currCond, new_cond, cache );
-        
-        if( !ret.isNull() ){
-          break;
-        }
-      }
-      
-      //make return value
-      if( ret.isNull() ){
-        if( changed ){
-          if( body.getMetaKind() == kind::metakind::PARAMETERIZED ){
-            children.insert( children.begin(), body.getOperator() );
+          if (doRewrite)
+          {
+            ret = nm->mkNode(ITE, children);
+            break;
           }
-          ret = NodeManager::currentNM()->mkNode( body.getKind(), children );
-        }else{
-          ret = body;
         }
       }
     }
-    
-    //clean up entailed conditions
-    if( body.getKind()==OR || body.getKind()==AND ){
-      for( unsigned j=0; j<body.getNumChildren(); j++ ){
-        removeEntailedCond( currCond, new_cond_children[j], cache );
-      }
+  }
+  else if (ret.getKind() == SELECT && ret[0].getKind() == STORE)
+  {
+    Node st = ret[0];
+    Node index = ret[1];
+    std::vector<Node> iconds;
+    std::vector<Node> elements;
+    while (st.getKind() == STORE)
+    {
+      iconds.push_back(index.eqNode(st[1]));
+      elements.push_back(st[2]);
+      st = st[0];
+    }
+    ret = nm->mkNode(SELECT, st, index);
+    // conditions
+    for (int i = (iconds.size() - 1); i >= 0; i--)
+    {
+      ret = nm->mkNode(ITE, iconds[i], elements[i], ret);
     }
-    
-    Trace("quantifiers-rewrite-term-debug2") << "Returning " << ret << " for " << body << std::endl;
-    cache[body] = ret;
   }
-
-  //do context-independent rewriting
-  iti = icache.find( ret );
-  if( iti!=icache.end() ){
-    return iti->second;
-  }else{
-    Node prev = ret;
-    if (ret.getKind() == EQUAL
-        && options::iteLiftQuant() != options::IteLiftQuantMode::NONE)
+  else if (elimExtArith)
+  {
+    if (ret.getKind() == INTS_DIVISION_TOTAL
+        || ret.getKind() == INTS_MODULUS_TOTAL)
     {
-      for( size_t i=0; i<2; i++ ){
-        if( ret[i].getKind()==ITE ){
-          Node no = i==0 ? ret[1] : ret[0];
-          if( no.getKind()!=ITE ){
-            bool doRewrite =
-                options::iteLiftQuant() == options::IteLiftQuantMode::ALL;
-            std::vector< Node > children;
-            children.push_back( ret[i][0] );
-            for( size_t j=1; j<=2; j++ ){
-              //check if it rewrites to a constant
-              Node nn = NodeManager::currentNM()->mkNode( EQUAL, no, ret[i][j] );
-              nn = Rewriter::rewrite( nn );
-              children.push_back( nn );
-              if( nn.isConst() ){
-                doRewrite = true;
-              }
-            }
-            if( doRewrite ){
-              ret = NodeManager::currentNM()->mkNode( ITE, children );
-              break;
-            }
+      Node num = ret[0];
+      Node den = ret[1];
+      if (den.isConst())
+      {
+        const Rational& rat = den.getConst<Rational>();
+        Assert(!num.isConst());
+        if (rat != 0)
+        {
+          Node intVar = nm->mkBoundVar(nm->integerType());
+          new_vars.push_back(intVar);
+          Node cond;
+          if (rat > 0)
+          {
+            cond = nm->mkNode(
+                AND,
+                nm->mkNode(LEQ, nm->mkNode(MULT, den, intVar), num),
+                nm->mkNode(
+                    LT,
+                    num,
+                    nm->mkNode(
+                        MULT,
+                        den,
+                        nm->mkNode(PLUS, intVar, nm->mkConst(Rational(1))))));
+          }
+          else
+          {
+            cond = nm->mkNode(
+                AND,
+                nm->mkNode(LEQ, nm->mkNode(MULT, den, intVar), num),
+                nm->mkNode(
+                    LT,
+                    num,
+                    nm->mkNode(
+                        MULT,
+                        den,
+                        nm->mkNode(PLUS, intVar, nm->mkConst(Rational(-1))))));
+          }
+          new_conds.push_back(cond.negate());
+          if (ret.getKind() == INTS_DIVISION_TOTAL)
+          {
+            ret = intVar;
+          }
+          else
+          {
+            ret = nm->mkNode(MINUS, num, nm->mkNode(MULT, den, intVar));
           }
         }
       }
     }
-    else if (ret.getKind() == SELECT && ret[0].getKind() == STORE)
+    else if (ret.getKind() == TO_INTEGER || ret.getKind() == IS_INTEGER)
     {
-      Node st = ret[0];
-      Node index = ret[1];
-      std::vector<Node> iconds;
-      std::vector<Node> elements;
-      while (st.getKind() == STORE)
+      Node intVar = nm->mkBoundVar(nm->integerType());
+      new_vars.push_back(intVar);
+      new_conds.push_back(
+          nm->mkNode(
+                AND,
+                nm->mkNode(LT,
+                           nm->mkNode(MINUS, ret[0], nm->mkConst(Rational(1))),
+                           intVar),
+                nm->mkNode(LEQ, intVar, ret[0]))
+              .negate());
+      if (ret.getKind() == TO_INTEGER)
       {
-        iconds.push_back(index.eqNode(st[1]));
-        elements.push_back(st[2]);
-        st = st[0];
+        ret = intVar;
       }
-      ret = nm->mkNode(SELECT, st, index);
-      // conditions
-      for (int i = (iconds.size() - 1); i >= 0; i--)
+      else
       {
-        ret = nm->mkNode(ITE, iconds[i], elements[i], ret);
+        ret = ret[0].eqNode(intVar);
       }
     }
-    else if( elimExtArith )
+  }
+  cache[body] = ret;
+  return ret;
+}
+
+Node QuantifiersRewriter::computeExtendedRewrite(Node q)
+{
+  Node body = q[1];
+  // apply extended rewriter
+  ExtendedRewriter er;
+  Node bodyr = er.extendedRewrite(body);
+  if (body != bodyr)
+  {
+    std::vector<Node> children;
+    children.push_back(q[0]);
+    children.push_back(bodyr);
+    if (q.getNumChildren() == 3)
     {
-      if( ret.getKind()==INTS_DIVISION_TOTAL || ret.getKind()==INTS_MODULUS_TOTAL ){
-        Node num = ret[0];
-        Node den = ret[1];
-        if(den.isConst()) {
-          const Rational& rat = den.getConst<Rational>();
-          Assert(!num.isConst());
-          if(rat != 0) {
-            Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType());
-            new_vars.push_back( intVar );
-            Node cond;
-            if(rat > 0) {
-              cond = NodeManager::currentNM()->mkNode(kind::AND,
-                       NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num),
-                       NodeManager::currentNM()->mkNode(kind::LT, num,
-                         NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(1))))));
-            } else {
-              cond = NodeManager::currentNM()->mkNode(kind::AND,
-                       NodeManager::currentNM()->mkNode(kind::LEQ, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar), num),
-                       NodeManager::currentNM()->mkNode(kind::LT, num,
-                         NodeManager::currentNM()->mkNode(kind::MULT, den, NodeManager::currentNM()->mkNode(kind::PLUS, intVar, NodeManager::currentNM()->mkConst(Rational(-1))))));
-            }
-            new_conds.push_back( cond.negate() );
-            if( ret.getKind()==INTS_DIVISION_TOTAL ){
-              ret = intVar;
-            }else{
-              ret = NodeManager::currentNM()->mkNode(kind::MINUS, num, NodeManager::currentNM()->mkNode(kind::MULT, den, intVar));
-            }
-          }
-        }
-      }else if( ret.getKind()==TO_INTEGER || ret.getKind()==IS_INTEGER ){
-        Node intVar = NodeManager::currentNM()->mkBoundVar(NodeManager::currentNM()->integerType());
-        new_vars.push_back( intVar );
-        new_conds.push_back(NodeManager::currentNM()->mkNode(kind::AND,
-                              NodeManager::currentNM()->mkNode(kind::LT,
-                                NodeManager::currentNM()->mkNode(kind::MINUS, ret[0], NodeManager::currentNM()->mkConst(Rational(1))), intVar),
-                              NodeManager::currentNM()->mkNode(kind::LEQ, intVar, ret[0])).negate());
-        if( ret.getKind()==TO_INTEGER ){
-          ret = intVar;
-        }else{
-          ret = ret[0].eqNode( intVar );
-        }
-      }
+      children.push_back(q[2]);
     }
-    icache[prev] = ret;
-    return ret;
+    return NodeManager::currentNM()->mkNode(FORALL, children);
   }
+  return q;
 }
 
 Node QuantifiersRewriter::computeCondSplit(Node body,
@@ -2018,9 +1845,13 @@ bool QuantifiersRewriter::doOperation(Node q,
   {
     return options::aggressiveMiniscopeQuant() && is_std;
   }
+  else if (computeOption == COMPUTE_EXT_REWRITE)
+  {
+    return options::extRewriteQuant();
+  }
   else if (computeOption == COMPUTE_PROCESS_TERMS)
   {
-    return options::condRewriteQuant() || options::elimExtArithQuant()
+    return options::elimExtArithQuant()
            || options::iteLiftQuant() != options::IteLiftQuantMode::NONE;
   }
   else if (computeOption == COMPUTE_COND_SPLIT)
@@ -2069,16 +1900,26 @@ Node QuantifiersRewriter::computeOperation(Node f,
     return computeMiniscoping( args, n, qa );
   }else if( computeOption==COMPUTE_AGGRESSIVE_MINISCOPING ){
     return computeAggressiveMiniscoping( args, n );
-  }else if( computeOption==COMPUTE_PROCESS_TERMS ){
+  }
+  else if (computeOption == COMPUTE_EXT_REWRITE)
+  {
+    return computeExtendedRewrite(f);
+  }
+  else if (computeOption == COMPUTE_PROCESS_TERMS)
+  {
     std::vector< Node > new_conds;
     n = computeProcessTerms( n, args, new_conds, f, qa );
     if( !new_conds.empty() ){
       new_conds.push_back( n );
       n = NodeManager::currentNM()->mkNode( OR, new_conds );
     }
-  }else if( computeOption==COMPUTE_COND_SPLIT ){
+  }
+  else if (computeOption == COMPUTE_COND_SPLIT)
+  {
     n = computeCondSplit(n, args, qa);
-  }else if( computeOption==COMPUTE_PRENEX ){
+  }
+  else if (computeOption == COMPUTE_PRENEX)
+  {
     if (options::prenexQuant() == options::PrenexQuantMode::DISJ_NORMAL
         || options::prenexQuant() == options::PrenexQuantMode::NORMAL)
     {
@@ -2091,7 +1932,9 @@ Node QuantifiersRewriter::computeOperation(Node f,
       n = computePrenex( n, args, nargs, true, false );
       Assert(nargs.empty());
     }
-  }else if( computeOption==COMPUTE_VAR_ELIMINATION ){
+  }
+  else if (computeOption == COMPUTE_VAR_ELIMINATION)
+  {
     n = computeVarElimination( n, args, qa );
   }
   Trace("quantifiers-rewrite-debug") << "Compute Operation: return " << n << ", " << args.size() << std::endl;
index ac87f944ceeaf41fce6dea5abf1118c66df3e895..2a3180e781e38b2adb4190158a75a1a3484948b6 100644 (file)
@@ -39,8 +39,10 @@ enum RewriteStep
   COMPUTE_MINISCOPING,
   /** Aggressive miniscoping */
   COMPUTE_AGGRESSIVE_MINISCOPING,
+  /** Apply the extended rewriter to quantified formula bodies */
+  COMPUTE_EXT_REWRITE,
   /**
-   * Term processing (e.g. simplifying terms based on ITE conditions,
+   * Term processing (e.g. simplifying terms based on ITE lifting,
    * eliminating extended arithmetic symbols).
    */
   COMPUTE_PROCESS_TERMS,
@@ -150,12 +152,7 @@ class QuantifiersRewriter : public TheoryRewriter
                              Node n,
                              Node ipl);
   static Node computeProcessTerms2(Node body,
-                                   bool hasPol,
-                                   bool pol,
-                                   std::map<Node, bool>& currCond,
-                                   int nCurrCond,
                                    std::map<Node, Node>& cache,
-                                   std::map<Node, Node>& icache,
                                    std::vector<Node>& new_vars,
                                    std::vector<Node>& new_conds,
                                    bool elimExtArith);
@@ -201,12 +198,42 @@ class QuantifiersRewriter : public TheoryRewriter
                                const std::vector<Node>& args,
                                QAttributes& qa);
   //-------------------------------------end conditional splitting
+  //------------------------------------- process terms
+  /** compute process terms
+   *
+   * This takes as input a quantified formula q with attributes qa whose
+   * body is body.
+   *
+   * This rewrite eliminates problematic terms from the bodies of
+   * quantified formulas, which includes performing:
+   * - Certain cases of ITE lifting,
+   * - Elimination of extended arithmetic functions like to_int/is_int/div/mod,
+   * - Elimination of select over store.
+   *
+   * It may introduce new variables V into new_vars and new conditions C into
+   * new_conds. It returns a node retBody such that q of the form
+   *   forall X. body
+   * is equivalent to:
+   *   forall X, V. ( C => retBody )
+   */
+  static Node computeProcessTerms(Node body,
+                                  std::vector<Node>& new_vars,
+                                  std::vector<Node>& new_conds,
+                                  Node q,
+                                  QAttributes& qa);
+  //------------------------------------- end process terms
+  //------------------------------------- extended rewrite
+  /** compute extended rewrite
+   *
+   * This returns the result of applying the extended rewriter on the body
+   * of quantified formula q.
+   */
+  static Node computeExtendedRewrite(Node q);
+  //------------------------------------- end extended rewrite
  public:
   static Node computeElimSymbols( Node body );
   static Node computeMiniscoping( std::vector< Node >& args, Node body, QAttributes& qa );
   static Node computeAggressiveMiniscoping( std::vector< Node >& args, Node body );
-  //cache is dependent upon currCond, icache is not, new_conds are negated conditions
-  static Node computeProcessTerms( Node body, std::vector< Node >& new_vars, std::vector< Node >& new_conds, Node q, QAttributes& qa );
   static Node computePrenex( Node body, std::vector< Node >& args, std::vector< Node >& nargs, bool pol, bool prenexAgg );
   static Node computePrenexAgg( Node n, bool topLevel, std::map< unsigned, std::map< Node, Node > >& visited );
   static Node computeSplit( std::vector< Node >& args, Node body, QAttributes& qa );
index 332b703e8b915d170eb94de320c048fc4b48826d..32ee2a744996099ffad745b1ed8a6191dd51f4ad 100644 (file)
@@ -1830,6 +1830,7 @@ set(regress_1_tests
   regress1/sygus/issue3648.smt2
   regress1/sygus/issue3649.sy
   regress1/sygus/issue3802-default-consts.sy
+  regress1/sygus/issue3839-cond-rewrite.smt2
   regress1/sygus/large-const-simp.sy
   regress1/sygus/let-bug-simp.sy
   regress1/sygus/list-head-x.sy
index 44f475d838ef59ead997ee10f37558b59c7f4940..f46147d7b74cd14ec599373f58b93a7efbcce9f1 100644 (file)
@@ -1,3 +1,5 @@
+; COMMAND-LINE: --ext-rewrite-quant
+; EXPECT: sat
 (set-logic UFLIA)
 (set-info :status sat)
 (declare-fun Q (Int Int) Bool)
index d1159278eb53ee05df55096a209d23e2fbf1c580..7dfb1430efb5152ef49986b5536a6410e1223a0a 100644 (file)
@@ -1,3 +1,5 @@
+; COMMAND-LINE: --ext-rewrite-quant
+; EXPECT: sat
 (set-logic UFLIA)
 (set-info :status sat)
 (declare-fun Q (Int Int) Bool)
diff --git a/test/regress/regress1/sygus/issue3839-cond-rewrite.smt2 b/test/regress/regress1/sygus/issue3839-cond-rewrite.smt2
new file mode 100644 (file)
index 0000000..cbe8f08
--- /dev/null
@@ -0,0 +1,10 @@
+; EXPECT: sat
+; COMMAND-LINE: --sygus-inference
+(set-logic ALL)
+(declare-fun a () Int)
+(declare-fun b () Int)
+(assert (xor (> a 0) (not (and (ite (= a b) (> (* 4 a b) 1) true) (> (* a a) 0)))))
+(assert (= a b))
+(assert (> (* a b) 0))
+(check-sat)
+