From 1cadb3c1034a9a5b0a778b25769ab1a8101de4f1 Mon Sep 17 00:00:00 2001 From: Tim King Date: Wed, 26 Mar 2014 16:56:13 -0400 Subject: [PATCH] Merging in a fix from 1.3.x. Fixes an idempotency issue for non-linear multiplication with integer and real variables. Conflicts: src/theory/arith/normal_form.cpp --- src/theory/arith/normal_form.cpp | 37 ++++++++++++----- src/theory/arith/normal_form.h | 71 +++++++++++++++++++++++++------- 2 files changed, 82 insertions(+), 26 deletions(-) diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp index 3adb72f37..9ccf057b1 100644 --- a/src/theory/arith/normal_form.cpp +++ b/src/theory/arith/normal_form.cpp @@ -67,6 +67,12 @@ bool Variable::isLeafMember(Node n){ (Theory::isLeafOf(n, theory::THEORY_ARITH)); } +VarList::VarList(Node n) + : NodeWrapper(n) +{ + Assert(isSorted(begin(), end())); +} + bool Variable::isDivMember(Node n){ switch(n.getKind()){ case kind::DIVISION: @@ -96,9 +102,15 @@ bool VarList::isMember(Node n) { Node prev = *curr; if(!Variable::isMember(prev)) return false; + Variable::VariableNodeCmp cmp; + while( (++curr) != end) { if(!Variable::isMember(*curr)) return false; - if(!(prev <= *curr)) return false; + // prev <= curr : accept + // !(prev <= curr) : reject + // !(!(prev > curr)) : reject + // curr < prev : reject + if((cmp(*curr, prev))) return false; prev = *curr; } return true; @@ -118,15 +130,16 @@ int VarList::cmp(const VarList& vl) const { } VarList VarList::parseVarList(Node n) { - if(Variable::isMember(n)) { - return VarList(Variable(n)); - } else { - Assert(n.getKind() == kind::MULT); - for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) { - Assert(Variable::isMember(*i)); - } - return VarList(n); - } + return VarList(n); + // if(Variable::isMember(n)) { + // return VarList(Variable(n)); + // } else { + // Assert(n.getKind() == kind::MULT); + // for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) { + // Assert(Variable::isMember(*i)); + // } + // return VarList(n); + // } } VarList VarList::operator*(const VarList& other) const { @@ -143,7 +156,9 @@ VarList VarList::operator*(const VarList& other) const { otherBegin = other.internalBegin(), otherEnd = other.internalEnd(); - merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result); + Variable::VariableNodeCmp cmp; + + merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result, cmp); Assert(result.size() >= 2); Node mult = NodeManager::currentNM()->mkNode(kind::MULT, result); diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h index f098d8b54..7e8ff556d 100644 --- a/src/theory/arith/normal_form.h +++ b/src/theory/arith/normal_form.h @@ -268,21 +268,43 @@ public: } bool operator<(const Variable& v) const { - bool thisIsVariable = isMetaKindVariable(); - bool vIsVariable = v.isMetaKindVariable(); - - if(thisIsVariable == vIsVariable){ - bool thisIsInteger = isIntegral(); - bool vIsInteger = v.isIntegral(); - if(thisIsInteger == vIsInteger){ - return getNode() < v.getNode(); + VariableNodeCmp cmp; + return cmp(this->getNode(), v.getNode()); + + // bool thisIsVariable = isMetaKindVariable(); + // bool vIsVariable = v.isMetaKindVariable(); + + // if(thisIsVariable == vIsVariable){ + // bool thisIsInteger = isIntegral(); + // bool vIsInteger = v.isIntegral(); + // if(thisIsInteger == vIsInteger){ + // return getNode() < v.getNode(); + // }else{ + // return thisIsInteger && !vIsInteger; + // } + // }else{ + // return thisIsVariable && !vIsVariable; + // } + } + + struct VariableNodeCmp { + bool operator()(Node n, Node m) const { + bool nIsVariable = n.isVar(); + bool mIsVariable = m.isVar(); + + if(nIsVariable == mIsVariable){ + bool nIsInteger = n.getType().isInteger(); + bool mIsInteger = m.getType().isInteger(); + if(nIsInteger == mIsInteger){ + return n < m; + }else{ + return nIsInteger && !mIsInteger; + } }else{ - return thisIsInteger && !vIsInteger; + return nIsVariable && !mIsVariable; } - }else{ - return thisIsVariable && !vIsVariable; } - } + }; bool operator==(const Variable& v) const { return getNode() == v.getNode();} @@ -419,6 +441,27 @@ static void merge_ranges(GetNodeIterator first1, copy_range(first2, last2, result); } +template +static void merge_ranges(GetNodeIterator first1, + GetNodeIterator last1, + GetNodeIterator first2, + GetNodeIterator last2, + std::vector& result, + const Cmp& cmp) { + + while(first1 != last1 && first2 != last2){ + if( cmp(*first1, *first2) ){ + result.push_back(*first1); + ++ first1; + }else{ + result.push_back(*first2); + ++ first2; + } + } + copy_range(first1, last1, result); + copy_range(first2, last2, result); +} + /** * A VarList is a sorted list of variables representing a product. * If the VarList is empty, it represents an empty product or 1. @@ -437,9 +480,7 @@ private: VarList() : NodeWrapper(Node::null()) {} - VarList(Node n) : NodeWrapper(n) { - Assert(isSorted(begin(), end())); - } + VarList(Node n); typedef expr::NodeSelfIterator internal_iterator; -- 2.30.2