Fixes an idempotency issue for non-linear multiplication with integer and real variables.
authorTim King <taking@cs.nyu.edu>
Wed, 26 Mar 2014 20:56:13 +0000 (16:56 -0400)
committerTim King <taking@cs.nyu.edu>
Wed, 26 Mar 2014 20:56:13 +0000 (16:56 -0400)
src/theory/arith/normal_form.cpp
src/theory/arith/normal_form.h

index 4edc55cca9013bc69d6a25708418e6f79aa6a7b2..f38bfa2f4d69c8ad49bfcc69763c3cb4abd1e1bc 100644 (file)
@@ -25,6 +25,13 @@ namespace CVC4 {
 namespace theory {
 namespace arith {
 
+
+VarList::VarList(Node n)
+  : NodeWrapper(n)
+{
+  Assert(isSorted(begin(), end()));
+}
+
 bool Variable::isDivMember(Node n){
   switch(n.getKind()){
   case kind::DIVISION:
@@ -52,9 +59,15 @@ bool VarList::isMember(Node n) {
     Node prev = *curr;
     if(!Variable::isMember(prev)) return false;
 
+    Variable::VariableNodeCmp cmp;
+
     while( (++curr) != end) {
       if(!Variable::isMember(*curr)) return false;
-      if(!(prev <= *curr)) return false;
+      // prev <= curr : accept
+      // !(prev <= curr) : reject
+      // !(!(prev > curr)) : reject
+      // curr < prev : reject
+      if((cmp(*curr, prev))) return false;
       prev = *curr;
     }
     return true;
@@ -74,15 +87,16 @@ int VarList::cmp(const VarList& vl) const {
 }
 
 VarList VarList::parseVarList(Node n) {
-  if(Variable::isMember(n)) {
-    return VarList(Variable(n));
-  } else {
-    Assert(n.getKind() == kind::MULT);
-    for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) {
-      Assert(Variable::isMember(*i));
-    }
-    return VarList(n);
-  }
+  return VarList(n);
+  // if(Variable::isMember(n)) {
+  //   return VarList(Variable(n));
+  // } else {
+  //   Assert(n.getKind() == kind::MULT);
+  //   for(Node::iterator i=n.begin(), end = n.end(); i!=end; ++i) {
+  //     Assert(Variable::isMember(*i));
+  //   }
+  //   return VarList(n);
+  // }
 }
 
 VarList VarList::operator*(const VarList& other) const {
@@ -99,7 +113,9 @@ VarList VarList::operator*(const VarList& other) const {
       otherBegin = other.internalBegin(),
       otherEnd = other.internalEnd();
 
-    merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result);
+    Variable::VariableNodeCmp cmp;
+
+    merge_ranges(thisBegin, thisEnd, otherBegin, otherEnd, result, cmp);
 
     Assert(result.size() >= 2);
     Node mult = NodeManager::currentNM()->mkNode(kind::MULT, result);
index cd5f047b5965ee523a8cf34a952e499ebbf7b5c8..050bf8124736858fc7704802d994c1a524769546 100644 (file)
@@ -268,21 +268,43 @@ public:
   }
 
   bool operator<(const Variable& v) const {
-    bool thisIsVariable = isMetaKindVariable();
-    bool vIsVariable = v.isMetaKindVariable();
-
-    if(thisIsVariable == vIsVariable){
-      bool thisIsInteger = isIntegral();
-      bool vIsInteger = v.isIntegral();
-      if(thisIsInteger == vIsInteger){
-        return getNode() < v.getNode();
+    VariableNodeCmp cmp;
+    return cmp(this->getNode(), v.getNode());
+
+    // bool thisIsVariable = isMetaKindVariable();
+    // bool vIsVariable = v.isMetaKindVariable();
+
+    // if(thisIsVariable == vIsVariable){
+    //   bool thisIsInteger = isIntegral();
+    //   bool vIsInteger = v.isIntegral();
+    //   if(thisIsInteger == vIsInteger){
+    //     return getNode() < v.getNode();
+    //   }else{
+    //     return thisIsInteger && !vIsInteger;
+    //   }
+    // }else{
+    //   return thisIsVariable && !vIsVariable;
+    // }
+  }
+
+  struct VariableNodeCmp {
+    bool operator()(Node n, Node m) const {
+      bool nIsVariable = n.isVar();
+      bool mIsVariable = m.isVar();
+
+      if(nIsVariable == mIsVariable){
+        bool nIsInteger = n.getType().isInteger();
+        bool mIsInteger = m.getType().isInteger();
+        if(nIsInteger == mIsInteger){
+          return n < m;
+        }else{
+          return nIsInteger && !mIsInteger;
+        }
       }else{
-        return thisIsInteger && !vIsInteger;
+        return nIsVariable && !mIsVariable;
       }
-    }else{
-      return thisIsVariable && !vIsVariable;
     }
-  }
+  };
 
   bool operator==(const Variable& v) const { return getNode() == v.getNode();}
 
@@ -417,6 +439,27 @@ static void merge_ranges(GetNodeIterator first1,
   copy_range(first2, last2, result);
 }
 
+template <class GetNodeIterator, class T, class Cmp>
+static void merge_ranges(GetNodeIterator first1,
+                         GetNodeIterator last1,
+                         GetNodeIterator first2,
+                         GetNodeIterator last2,
+                         std::vector<T>& result,
+                         const Cmp& cmp) {
+
+  while(first1 != last1 && first2 != last2){
+    if( cmp(*first1, *first2) ){
+      result.push_back(*first1);
+      ++ first1;
+    }else{
+      result.push_back(*first2);
+      ++ first2;
+    }
+  }
+  copy_range(first1, last1, result);
+  copy_range(first2, last2, result);
+}
+
 /**
  * A VarList is a sorted list of variables representing a product.
  * If the VarList is empty, it represents an empty product or 1.
@@ -435,9 +478,7 @@ private:
 
   VarList() : NodeWrapper(Node::null()) {}
 
-  VarList(Node n) : NodeWrapper(n) {
-    Assert(isSorted(begin(), end()));
-  }
+  VarList(Node n);
 
   typedef expr::NodeSelfIterator internal_iterator;