Fixes for nlAlgSolveSubs.
authorajreynol <andrew.j.reynolds@gmail.com>
Wed, 5 Apr 2017 18:19:21 +0000 (13:19 -0500)
committerajreynol <andrew.j.reynolds@gmail.com>
Wed, 5 Apr 2017 18:19:31 +0000 (13:19 -0500)
src/theory/arith/nonlinear_extension.cpp

index 558a07e3986600131c2eabf7822186b62481a242..d0b1748c46718ccfbc1518244174b48454b4b12a 100644 (file)
@@ -129,7 +129,8 @@ struct SubstitutionConstResult {
 }; /* struct SubstitutionConstResult */
 
 SubstitutionConstResult getSubstitutionConst(
-    Node n, const std::vector<Node>& sum, const std::vector<Node>& rep_sum,
+    Node n, Node n_rsu, Node rsu_exp,
+    const std::vector<Node>& sum, const std::vector<Node>& rep_sum,
     const std::map<Node, Node>& rep_to_const,
     const std::map<Node, Node>& rep_to_const_exp,
     const std::map<Node, Node>& rep_to_const_base) {
@@ -163,8 +164,11 @@ SubstitutionConstResult getSubstitutionConst(
     vars.push_back(sum[i]);
     subs.push_back(const_of_cr);
   }
+  if( n!=n_rsu && !rsu_exp.isNull() ){
+    result.const_exp.push_back( rsu_exp );
+  }
   Node substituted =
-      n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
+      n_rsu.substitute(vars.begin(), vars.end(), subs.begin(), subs.end());
   result.term = Rewriter::rewrite(substituted);
   return result;
 }
@@ -295,6 +299,9 @@ class NonLinearExtentionSubstitutionSolver {
 
   const eq::EqualityEngine* d_ee;
 
+  std::map<Node, Node > d_rep_sum_unique;
+  std::map<Node, Node > d_rep_sum_unique_exp;
+
   std::map<Node, Node> d_rep_to_const;
   std::map<Node, Node> d_rep_to_const_exp;
   std::map<Node, Node> d_rep_to_const_base;
@@ -346,50 +353,102 @@ bool NonLinearExtentionSubstitutionSolver::solve(
           if (QuantArith::getMonomialSum(n, msum)) {
             int nconst_count = 0;
             bool evaluatable = true;
+            //first, collect sums of equal terms
+            std::map< Node, Node > rep_to_mon;
+            std::vector< Node > subs_rm;
+            std::vector< Node > vars_rm;
+            std::vector< Node > exp_rm;
             for (std::map<Node, Node>::iterator itm = msum.begin();
                  itm != msum.end(); ++itm) {
               if (!itm->first.isNull()) {
                 if (d_ee->hasTerm(itm->first)) {
-                  Trace("nl-subs-debug")
-                      << "      ...monomial " << itm->first << std::endl;
                   Node cr = d_ee->getRepresentative(itm->first);
-                  d_term_to_sum[n].push_back(itm->first);
-                  d_term_to_rep_sum[n].push_back(cr);
-                  if (!Contains(d_rep_to_const, cr)) {
-                    if (!IsInVector(d_reps_to_parent_terms[cr], n)) {
-                      d_reps_to_parent_terms[cr].push_back(n);
-                      nconst_count++;
-                    }
+                  std::map< Node, Node >::iterator itrm = rep_to_mon.find( cr );
+                  if( itrm==rep_to_mon.end() ){
+                    rep_to_mon[cr] = itm->first;
+                  }else{
+                    vars_rm.push_back( itm->first );
+                    subs_rm.push_back( itrm->second );
+                    exp_rm.push_back( itm->first.eqNode( itrm->second ) );
                   }
-                } else {
-                  Trace("nl-subs-debug")
-                      << "...is not evaluatable due to monomial " << itm->first
-                      << std::endl;
-                  evaluatable = false;
-                  break;
                 }
+              }else{
+                Trace("nl-subs-debug")
+                    << "...is not evaluatable due to monomial " << itm->first
+                    << std::endl;
+                evaluatable = false;
+                break;
               }
             }
-            if (evaluatable) {
-              Trace("nl-subs-debug")
-                  << "  ...term has " << nconst_count
-                  << " unique non-constant represenative children."
-                  << std::endl;
-              if (nconst_count == 0) {
-                if (r_c.isNull()) {
-                  const SubstitutionConstResult result = getSubstitutionConst(
-                      n, d_term_to_sum[n], d_term_to_rep_sum[n], d_rep_to_const,
-                      d_rep_to_const_exp, d_rep_to_const_base);
-                  r_c_exp.insert(r_c_exp.end(), result.const_exp.begin(),
-                                 result.const_exp.end());
-                  r_c = result.term;
-                  r_cb = n;
-                  Assert(result.variable_term.isNull());
-                  Assert(r_c.isConst());
+            if( evaluatable ){
+              bool success;
+              if( !vars_rm.empty() ){
+                Node ns = n.substitute( vars_rm.begin(), vars_rm.end(), subs_rm.begin(), subs_rm.end() );
+                ns = Rewriter::rewrite( ns );
+                if( ns.isConst() ){
+                  success = false;
+                  if( r_c.isNull() ){
+                    r_c = ns;
+                    r_cb = n;
+                    r_c_exp.insert( r_c_exp.end(), exp_rm.begin(), exp_rm.end() );
+                  }
+                }else{
+                  //recompute the monomial
+                  msum.clear();
+                  if (!QuantArith::getMonomialSum(ns, msum)) {
+                    success = false;
+                  }else{
+                    d_rep_sum_unique_exp[n] = exp_rm.size()==1 ? exp_rm[0] : NodeManager::currentNM()->mkNode( kind::AND, exp_rm );
+                    d_rep_sum_unique[n] = ns;
+                  }
+                }
+              }else{
+                d_rep_sum_unique[n] = n;
+              }
+              if( success ){
+                for (std::map<Node, Node>::iterator itm = msum.begin();
+                     itm != msum.end(); ++itm) {
+                  if (!itm->first.isNull()) {
+                    if (d_ee->hasTerm(itm->first)) {
+                      Trace("nl-subs-debug")
+                          << "      ...monomial " << itm->first << std::endl;
+                      Node cr = d_ee->getRepresentative(itm->first);
+                      d_term_to_sum[n].push_back(itm->first);
+                      d_term_to_rep_sum[n].push_back(cr);
+                      if (!Contains(d_rep_to_const, cr)) {
+                        if (!IsInVector(d_reps_to_parent_terms[cr], n)) {
+                          d_reps_to_parent_terms[cr].push_back(n);
+                          nconst_count++;
+                        }
+                      }
+                    } else {
+                      Assert( false );
+                    }
+                  }
+                }
+                if (evaluatable) {
+                  Trace("nl-subs-debug")
+                      << "  ...term has " << nconst_count
+                      << " unique non-constant represenative children."
+                      << std::endl;
+                  if (nconst_count == 0) {
+                    if (r_c.isNull()) {
+                      const SubstitutionConstResult result = getSubstitutionConst(
+                          n, d_rep_sum_unique[n], d_rep_sum_unique_exp[n],
+                          d_term_to_sum[n], d_term_to_rep_sum[n], d_rep_to_const,
+                          d_rep_to_const_exp, d_rep_to_const_base);
+                      r_c_exp.insert(r_c_exp.end(), result.const_exp.begin(),
+                                     result.const_exp.end());
+                      r_c = result.term;
+                      r_cb = n;
+                      Assert(result.variable_term.isNull());
+                      Assert(r_c.isConst());
+                    }
+                  } else {
+                    d_reps_to_terms[r].push_back(n);
+                    d_term_to_nconst_rep_count[n] = nconst_count;
+                  }
                 }
-              } else {
-                d_reps_to_terms[r].push_back(n);
-                d_term_to_nconst_rep_count[n] = nconst_count;
               }
             }
           } else {
@@ -471,7 +530,8 @@ bool NonLinearExtentionSubstitutionSolver::setSubstitutionConst(
                                  << " evaluates to constant." << std::endl;
           if (!Contains(new_const, m)) {
             const SubstitutionConstResult result = getSubstitutionConst(
-                m, d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const,
+                m, d_rep_sum_unique[m], d_rep_sum_unique_exp[m],
+                d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const,
                 d_rep_to_const_exp, d_rep_to_const_base);
             new_const_exp[m].insert(new_const_exp[m].end(),
                                     result.const_exp.begin(),
@@ -489,7 +549,8 @@ bool NonLinearExtentionSubstitutionSolver::setSubstitutionConst(
           Trace("nl-subs-debug") << "...parent term " << m
                                  << " is univariate solved." << std::endl;
           const SubstitutionConstResult result = getSubstitutionConst(
-              m, d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const,
+              m, d_rep_sum_unique[m], d_rep_sum_unique_exp[m],
+              d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const,
               d_rep_to_const_exp, d_rep_to_const_base);
           Node eq = (result.term).eqNode(d_rep_to_const[r]);
           Node v_c = QuantArith::solveEqualityFor(eq, result.variable_term);
@@ -520,7 +581,8 @@ bool NonLinearExtentionSubstitutionSolver::setSubstitutionConst(
         Trace("nl-subs-debug")
             << "...term " << m << " is univariate solved." << std::endl;
         const SubstitutionConstResult result = getSubstitutionConst(
-            m, d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const,
+            m, d_rep_sum_unique[m], d_rep_sum_unique_exp[m],
+            d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const,
             d_rep_to_const_exp, d_rep_to_const_base);
         Node v = result.variable_term;
         Node m_t = result.term;