From 62c2271daa9dd538a2abf25baca4cddb19a74986 Mon Sep 17 00:00:00 2001 From: Tim King Date: Wed, 26 Mar 2014 16:56:13 -0400 Subject: [PATCH] Fixes an idempotency issue for non-linear multiplication with integer and real variables. --- src/theory/arith/normal_form.cpp | 38 ++++++++++++----- src/theory/arith/normal_form.h | 71 +++++++++++++++++++++++++------- 2 files changed, 83 insertions(+), 26 deletions(-) diff --git a/src/theory/arith/normal_form.cpp b/src/theory/arith/normal_form.cpp index 4edc55cca..f38bfa2f4 100644 --- a/src/theory/arith/normal_form.cpp +++ b/src/theory/arith/normal_form.cpp @@ -25,6 +25,13 @@ namespace CVC4 { namespace theory { namespace arith { + +VarList::VarList(Node n) + : NodeWrapper(n) +{ + Assert(isSorted(begin(), end())); +} + bool Variable::isDivMember(Node n){ switch(n.getKind()){ case kind::DIVISION: @@ -52,9 +59,15 @@ bool VarList::isMember(Node n) { Node prev = *curr; if(!Variable::isMember(prev)) return false; + Variable::VariableNodeCmp cmp; + while( (++curr) != end) { if(!Variable::isMember(*curr)) return false; - if(!(prev <= *curr)) return false; + // prev <= curr : accept + // !(prev <= curr) : reject + // !(!(prev > curr)) : reject + // curr < prev : reject + if((cmp(*curr, prev))) return false; prev = *curr; } return true; @@ -74,15 +87,16 @@ int VarList::cmp(const VarList& vl) const { } VarList VarList::parseVarList(Node n) { - if(Variable::isMember(n)) { - return VarList(Variable(n)); - } else { - Assert(n.getKind() == kind::MULT); - for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) { - Assert(Variable::isMember(*i)); - } - return VarList(n); - } + return VarList(n); + // if(Variable::isMember(n)) { + // return VarList(Variable(n)); + // } else { + // Assert(n.getKind() == kind::MULT); + // for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) { + // Assert(Variable::isMember(*i)); + // } + // return VarList(n); + // } } VarList VarList::operator*(const VarList& other) const { @@ -99,7 +113,9 @@ VarList VarList::operator*(const VarList& other) const { otherBegin = other.internalBegin(), otherEnd = other.internalEnd(); - merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result); + Variable::VariableNodeCmp cmp; + + merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result, cmp); Assert(result.size() >= 2); Node mult = NodeManager::currentNM()->mkNode(kind::MULT, result); diff --git a/src/theory/arith/normal_form.h b/src/theory/arith/normal_form.h index cd5f047b5..050bf8124 100644 --- a/src/theory/arith/normal_form.h +++ b/src/theory/arith/normal_form.h @@ -268,21 +268,43 @@ public: } bool operator<(const Variable& v) const { - bool thisIsVariable = isMetaKindVariable(); - bool vIsVariable = v.isMetaKindVariable(); - - if(thisIsVariable == vIsVariable){ - bool thisIsInteger = isIntegral(); - bool vIsInteger = v.isIntegral(); - if(thisIsInteger == vIsInteger){ - return getNode() < v.getNode(); + VariableNodeCmp cmp; + return cmp(this->getNode(), v.getNode()); + + // bool thisIsVariable = isMetaKindVariable(); + // bool vIsVariable = v.isMetaKindVariable(); + + // if(thisIsVariable == vIsVariable){ + // bool thisIsInteger = isIntegral(); + // bool vIsInteger = v.isIntegral(); + // if(thisIsInteger == vIsInteger){ + // return getNode() < v.getNode(); + // }else{ + // return thisIsInteger && !vIsInteger; + // } + // }else{ + // return thisIsVariable && !vIsVariable; + // } + } + + struct VariableNodeCmp { + bool operator()(Node n, Node m) const { + bool nIsVariable = n.isVar(); + bool mIsVariable = m.isVar(); + + if(nIsVariable == mIsVariable){ + bool nIsInteger = n.getType().isInteger(); + bool mIsInteger = m.getType().isInteger(); + if(nIsInteger == mIsInteger){ + return n < m; + }else{ + return nIsInteger && !mIsInteger; + } }else{ - return thisIsInteger && !vIsInteger; + return nIsVariable && !mIsVariable; } - }else{ - return thisIsVariable && !vIsVariable; } - } + }; bool operator==(const Variable& v) const { return getNode() == v.getNode();} @@ -417,6 +439,27 @@ static void merge_ranges(GetNodeIterator first1, copy_range(first2, last2, result); } +template +static void merge_ranges(GetNodeIterator first1, + GetNodeIterator last1, + GetNodeIterator first2, + GetNodeIterator last2, + std::vector& result, + const Cmp& cmp) { + + while(first1 != last1 && first2 != last2){ + if( cmp(*first1, *first2) ){ + result.push_back(*first1); + ++ first1; + }else{ + result.push_back(*first2); + ++ first2; + } + } + copy_range(first1, last1, result); + copy_range(first2, last2, result); +} + /** * A VarList is a sorted list of variables representing a product. * If the VarList is empty, it represents an empty product or 1. @@ -435,9 +478,7 @@ private: VarList() : NodeWrapper(Node::null()) {} - VarList(Node n) : NodeWrapper(n) { - Assert(isSorted(begin(), end())); - } + VarList(Node n); typedef expr::NodeSelfIterator internal_iterator; -- 2.30.2