From d14cb15be7eb3482afeec277d56d6ed2e9cdc76a Mon Sep 17 00:00:00 2001 From: ajreynol Date: Wed, 5 Apr 2017 13:19:21 -0500 Subject: [PATCH] Fixes for nlAlgSolveSubs. --- src/theory/arith/nonlinear_extension.cpp | 140 ++++++++++++++++------- 1 file changed, 101 insertions(+), 39 deletions(-) diff --git a/src/theory/arith/nonlinear_extension.cpp b/src/theory/arith/nonlinear_extension.cpp index 558a07e39..d0b1748c4 100644 --- a/src/theory/arith/nonlinear_extension.cpp +++ b/src/theory/arith/nonlinear_extension.cpp @@ -129,7 +129,8 @@ struct SubstitutionConstResult { }; /* struct SubstitutionConstResult */ SubstitutionConstResult getSubstitutionConst( - Node n, const std::vector& sum, const std::vector& rep_sum, + Node n, Node n_rsu, Node rsu_exp, + const std::vector& sum, const std::vector& rep_sum, const std::map& rep_to_const, const std::map& rep_to_const_exp, const std::map& rep_to_const_base) { @@ -163,8 +164,11 @@ SubstitutionConstResult getSubstitutionConst( vars.push_back(sum[i]); subs.push_back(const_of_cr); } + if( n!=n_rsu && !rsu_exp.isNull() ){ + result.const_exp.push_back( rsu_exp ); + } Node substituted = - n.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); + n_rsu.substitute(vars.begin(), vars.end(), subs.begin(), subs.end()); result.term = Rewriter::rewrite(substituted); return result; } @@ -295,6 +299,9 @@ class NonLinearExtentionSubstitutionSolver { const eq::EqualityEngine* d_ee; + std::map d_rep_sum_unique; + std::map d_rep_sum_unique_exp; + std::map d_rep_to_const; std::map d_rep_to_const_exp; std::map d_rep_to_const_base; @@ -346,50 +353,102 @@ bool NonLinearExtentionSubstitutionSolver::solve( if (QuantArith::getMonomialSum(n, msum)) { int nconst_count = 0; bool evaluatable = true; + //first, collect sums of equal terms + std::map< Node, Node > rep_to_mon; + std::vector< Node > subs_rm; + std::vector< Node > vars_rm; + std::vector< Node > exp_rm; for (std::map::iterator itm = msum.begin(); itm != msum.end(); ++itm) { if (!itm->first.isNull()) { if (d_ee->hasTerm(itm->first)) { - Trace("nl-subs-debug") - << " ...monomial " << itm->first << std::endl; Node cr = d_ee->getRepresentative(itm->first); - d_term_to_sum[n].push_back(itm->first); - d_term_to_rep_sum[n].push_back(cr); - if (!Contains(d_rep_to_const, cr)) { - if (!IsInVector(d_reps_to_parent_terms[cr], n)) { - d_reps_to_parent_terms[cr].push_back(n); - nconst_count++; - } + std::map< Node, Node >::iterator itrm = rep_to_mon.find( cr ); + if( itrm==rep_to_mon.end() ){ + rep_to_mon[cr] = itm->first; + }else{ + vars_rm.push_back( itm->first ); + subs_rm.push_back( itrm->second ); + exp_rm.push_back( itm->first.eqNode( itrm->second ) ); } - } else { - Trace("nl-subs-debug") - << "...is not evaluatable due to monomial " << itm->first - << std::endl; - evaluatable = false; - break; } + }else{ + Trace("nl-subs-debug") + << "...is not evaluatable due to monomial " << itm->first + << std::endl; + evaluatable = false; + break; } } - if (evaluatable) { - Trace("nl-subs-debug") - << " ...term has " << nconst_count - << " unique non-constant represenative children." - << std::endl; - if (nconst_count == 0) { - if (r_c.isNull()) { - const SubstitutionConstResult result = getSubstitutionConst( - n, d_term_to_sum[n], d_term_to_rep_sum[n], d_rep_to_const, - d_rep_to_const_exp, d_rep_to_const_base); - r_c_exp.insert(r_c_exp.end(), result.const_exp.begin(), - result.const_exp.end()); - r_c = result.term; - r_cb = n; - Assert(result.variable_term.isNull()); - Assert(r_c.isConst()); + if( evaluatable ){ + bool success; + if( !vars_rm.empty() ){ + Node ns = n.substitute( vars_rm.begin(), vars_rm.end(), subs_rm.begin(), subs_rm.end() ); + ns = Rewriter::rewrite( ns ); + if( ns.isConst() ){ + success = false; + if( r_c.isNull() ){ + r_c = ns; + r_cb = n; + r_c_exp.insert( r_c_exp.end(), exp_rm.begin(), exp_rm.end() ); + } + }else{ + //recompute the monomial + msum.clear(); + if (!QuantArith::getMonomialSum(ns, msum)) { + success = false; + }else{ + d_rep_sum_unique_exp[n] = exp_rm.size()==1 ? exp_rm[0] : NodeManager::currentNM()->mkNode( kind::AND, exp_rm ); + d_rep_sum_unique[n] = ns; + } + } + }else{ + d_rep_sum_unique[n] = n; + } + if( success ){ + for (std::map::iterator itm = msum.begin(); + itm != msum.end(); ++itm) { + if (!itm->first.isNull()) { + if (d_ee->hasTerm(itm->first)) { + Trace("nl-subs-debug") + << " ...monomial " << itm->first << std::endl; + Node cr = d_ee->getRepresentative(itm->first); + d_term_to_sum[n].push_back(itm->first); + d_term_to_rep_sum[n].push_back(cr); + if (!Contains(d_rep_to_const, cr)) { + if (!IsInVector(d_reps_to_parent_terms[cr], n)) { + d_reps_to_parent_terms[cr].push_back(n); + nconst_count++; + } + } + } else { + Assert( false ); + } + } + } + if (evaluatable) { + Trace("nl-subs-debug") + << " ...term has " << nconst_count + << " unique non-constant represenative children." + << std::endl; + if (nconst_count == 0) { + if (r_c.isNull()) { + const SubstitutionConstResult result = getSubstitutionConst( + n, d_rep_sum_unique[n], d_rep_sum_unique_exp[n], + d_term_to_sum[n], d_term_to_rep_sum[n], d_rep_to_const, + d_rep_to_const_exp, d_rep_to_const_base); + r_c_exp.insert(r_c_exp.end(), result.const_exp.begin(), + result.const_exp.end()); + r_c = result.term; + r_cb = n; + Assert(result.variable_term.isNull()); + Assert(r_c.isConst()); + } + } else { + d_reps_to_terms[r].push_back(n); + d_term_to_nconst_rep_count[n] = nconst_count; + } } - } else { - d_reps_to_terms[r].push_back(n); - d_term_to_nconst_rep_count[n] = nconst_count; } } } else { @@ -471,7 +530,8 @@ bool NonLinearExtentionSubstitutionSolver::setSubstitutionConst( << " evaluates to constant." << std::endl; if (!Contains(new_const, m)) { const SubstitutionConstResult result = getSubstitutionConst( - m, d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const, + m, d_rep_sum_unique[m], d_rep_sum_unique_exp[m], + d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const, d_rep_to_const_exp, d_rep_to_const_base); new_const_exp[m].insert(new_const_exp[m].end(), result.const_exp.begin(), @@ -489,7 +549,8 @@ bool NonLinearExtentionSubstitutionSolver::setSubstitutionConst( Trace("nl-subs-debug") << "...parent term " << m << " is univariate solved." << std::endl; const SubstitutionConstResult result = getSubstitutionConst( - m, d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const, + m, d_rep_sum_unique[m], d_rep_sum_unique_exp[m], + d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const, d_rep_to_const_exp, d_rep_to_const_base); Node eq = (result.term).eqNode(d_rep_to_const[r]); Node v_c = QuantArith::solveEqualityFor(eq, result.variable_term); @@ -520,7 +581,8 @@ bool NonLinearExtentionSubstitutionSolver::setSubstitutionConst( Trace("nl-subs-debug") << "...term " << m << " is univariate solved." << std::endl; const SubstitutionConstResult result = getSubstitutionConst( - m, d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const, + m, d_rep_sum_unique[m], d_rep_sum_unique_exp[m], + d_term_to_sum[m], d_term_to_rep_sum[m], d_rep_to_const, d_rep_to_const_exp, d_rep_to_const_base); Node v = result.variable_term; Node m_t = result.term; -- 2.30.2