Refactor arithmetic pre-rewriter for multiplication (#7930)
[cvc5.git] / src / theory / arith / normal_form.h
index b054f9804b9ab5bb567d83697f85dce69e98930d..577bd052d0372d93f9e6eb75fac0a98a685f261c 100644 (file)
@@ -1,38 +1,35 @@
-/*********************                                                        */
-/*! \file normal_form.h
- ** \verbatim
- ** Original author: taking
- ** Major contributors: mdeters
- ** Minor contributors (to current version): dejan
- ** This file is part of the CVC4 prototype.
- ** Copyright (c) 2009, 2010, 2011  The Analysis of Computer Systems Group (ACSys)
- ** Courant Institute of Mathematical Sciences
- ** New York University
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-#include "cvc4_private.h"
-
-#ifndef __CVC4__THEORY__ARITH__NORMAL_FORM_H
-#define __CVC4__THEORY__ARITH__NORMAL_FORM_H
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Tim King, Morgan Deters, Gereon Kremer
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * [[ Add one-line brief description here ]]
+ *
+ * [[ Add lengthier description here ]]
+ * \todo document this file
+ */
+
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__ARITH__NORMAL_FORM_H
+#define CVC5__THEORY__ARITH__NORMAL_FORM_H
+
+#include <algorithm>
 
+#include "base/output.h"
 #include "expr/node.h"
 #include "expr/node_self_iterator.h"
+#include "theory/arith/delta_rational.h"
 #include "util/rational.h"
-#include "theory/theory.h"
-#include "theory/arith/arith_utilities.h"
-
-#include <list>
-#include <algorithm>
-#include <ext/algorithm>
 
-namespace CVC4 {
+namespace cvc5 {
 namespace theory {
 namespace arith {
 
@@ -49,7 +46,7 @@ namespace arith {
  *
  * variable := n
  *   where
- *     n.getMetaKind() == metakind::VARIABLE or is foreign
+ *     n.isVar() or is foreign
  *     n.getType() \in {Integer, Real}
  *
  * constant := n
@@ -78,12 +75,13 @@ namespace arith {
  *     (exists realMonomial (monomialList qpolynomial))
  *     abs(monomialCoefficient (head (monomialList qpolynomial))) == 1
  *
- * integer_cmp := (<= zpolynomial constant)
+ * integer_cmp := (>= zpolynomial constant)
  *   where
  *     not (exists constantMonomial (monomialList zpolynomial))
  *     (forall integerMonomial (monomialList zpolynomial))
  *     the gcd of all numerators of coefficients is 1
  *     the denominator of all coefficients and the constant is 1
+ *     the leading coefficient is positive
  *
  * rational_eq := (= qvarlist qpolynomial)
  *   where
@@ -211,7 +209,6 @@ namespace arith {
  *      | (+ [monomial]) -> [monomial]
  */
 
-
 /**
  * A NodeWrapper is a class that is a thinly veiled container of a Node object.
  */
@@ -226,16 +223,55 @@ public:
 
 class Variable : public NodeWrapper {
 public:
-  Variable(Node n) : NodeWrapper(n) {
-    Assert(isMember(getNode()));
-  }
-
-  // TODO: check if it's a theory leaf also
-  static bool isMember(Node n) {
-    if (n.getKind() == kind::CONST_RATIONAL) return false;
-    if (isRelationOperator(n.getKind())) return false;
-    return Theory::isLeafOf(n, theory::THEORY_ARITH);
-  }
+ Variable(Node n) : NodeWrapper(n) { Assert(isMember(getNode())); }
+
+ // TODO: check if it's a theory leaf also
+ static bool isMember(Node n)
+ {
+   Kind k = n.getKind();
+   switch (k)
+   {
+     case kind::CONST_RATIONAL: return false;
+     case kind::INTS_DIVISION:
+     case kind::INTS_MODULUS:
+     case kind::DIVISION:
+     case kind::INTS_DIVISION_TOTAL:
+     case kind::INTS_MODULUS_TOTAL:
+     case kind::DIVISION_TOTAL: return isDivMember(n);
+     case kind::IAND: return isIAndMember(n);
+     case kind::POW2: return isPow2Member(n);
+     case kind::EXPONENTIAL:
+     case kind::SINE:
+     case kind::COSINE:
+     case kind::TANGENT:
+     case kind::COSECANT:
+     case kind::SECANT:
+     case kind::COTANGENT:
+     case kind::ARCSINE:
+     case kind::ARCCOSINE:
+     case kind::ARCTANGENT:
+     case kind::ARCCOSECANT:
+     case kind::ARCSECANT:
+     case kind::ARCCOTANGENT:
+     case kind::SQRT:
+     case kind::PI: return isTranscendentalMember(n);
+     case kind::ABS:
+     case kind::TO_INTEGER:
+       // Treat to_int as a variable; it is replaced in early preprocessing
+       // by a variable.
+       return true;
+     default: return isLeafMember(n);
+   }
+ }
+
+  static bool isLeafMember(Node n);
+  static bool isIAndMember(Node n);
+  static bool isPow2Member(Node n);
+  static bool isDivMember(Node n);
+  bool isDivLike() const{
+    return isDivMember(getNode());
+  }
+  static bool isTranscendentalMember(Node n);
 
   bool isNormalForm() { return isMember(getNode()); }
 
@@ -244,51 +280,76 @@ public:
   }
 
   bool isMetaKindVariable() const {
-    return getNode().getMetaKind() == kind::metakind::VARIABLE;
+    return getNode().isVar();
   }
 
   bool operator<(const Variable& v) const {
-    bool thisIsVariable = isMetaKindVariable();
-    bool vIsVariable = v.isMetaKindVariable();
-
-    if(thisIsVariable == vIsVariable){
-      bool thisIsInteger = isIntegral();
-      bool vIsInteger = v.isIntegral();
-      if(thisIsInteger == vIsInteger){
-        return getNode() < v.getNode();
+    VariableNodeCmp cmp;
+    return cmp(this->getNode(), v.getNode());
+  }
+
+  struct VariableNodeCmp {
+    static inline int cmp(const Node& n, const Node& m) {
+      if ( n == m ) { return 0; }
+
+      // this is now slightly off of the old variable order.
+
+      bool nIsInteger = n.getType().isInteger();
+      bool mIsInteger = m.getType().isInteger();
+
+      if(nIsInteger == mIsInteger){
+        bool nIsVariable = n.isVar();
+        bool mIsVariable = m.isVar();
+
+        if(nIsVariable == mIsVariable){
+          if(n < m){
+            return -1;
+          }else{
+            Assert(n != m);
+            return 1;
+          }
+        }else{
+          if(nIsVariable){
+            return -1; // nIsVariable => !mIsVariable
+          }else{
+            return 1; // !nIsVariable => mIsVariable
+          }
+        }
       }else{
-        return thisIsInteger && !vIsInteger;
+        Assert(nIsInteger != mIsInteger);
+        if(nIsInteger){
+          return 1; // nIsInteger => !mIsInteger
+        }else{
+          return -1; // !nIsInteger => mIsInteger
+        }
       }
-    }else{
-      return thisIsVariable && !vIsVariable;
     }
-  }
+
+    bool operator()(const Node& n, const Node& m) const {
+      return VariableNodeCmp::cmp(n,m) < 0;
+    }
+  };
 
   bool operator==(const Variable& v) const { return getNode() == v.getNode();}
 
+  size_t getComplexity() const;
 };/* class Variable */
 
-
 class Constant : public NodeWrapper {
 public:
-  Constant(Node n) : NodeWrapper(n) {
-    Assert(isMember(getNode()));
-  }
+ Constant(Node n) : NodeWrapper(n) { Assert(isMember(getNode())); }
 
-  static bool isMember(Node n) {
-    return n.getKind() == kind::CONST_RATIONAL;
-  }
+ static bool isMember(Node n) { return n.getKind() == kind::CONST_RATIONAL; }
 
 bool isNormalForm() { return isMember(getNode()); }
+ bool isNormalForm() { return isMember(getNode()); }
 
-  static Constant mkConstant(Node n) {
-    Assert(n.getKind() == kind::CONST_RATIONAL);
-    return Constant(n);
-  }
+ static Constant mkConstant(Node n)
+ {
+   Assert(n.getKind() == kind::CONST_RATIONAL);
+   return Constant(n);
+ }
 
-  static Constant mkConstant(const Rational& rat) {
-    return Constant(mkRationalNode(rat));
-  }
+  static Constant mkConstant(const Rational& rat);
 
   static Constant mkZero() {
     return mkConstant(Rational(0));
@@ -302,6 +363,7 @@ public:
     return getNode().getConst<Rational>();
   }
 
+  static int absCmp(const Constant& a, const Constant& b);
   bool isIntegral() const { return getValue().isIntegral(); }
 
   int sgn() const { return getValue().sgn(); }
@@ -353,12 +415,14 @@ public:
     return getValue().getNumerator().length();
   }
 
+  size_t getComplexity() const;
+
 };/* class Constant */
 
 
 template <class GetNodeIterator>
 inline Node makeNode(Kind k, GetNodeIterator start, GetNodeIterator end) {
-  NodeBuilder<> nb(k);
+  NodeBuilder nb(k);
 
   while(start != end) {
     nb << (*start).getNode();
@@ -368,35 +432,6 @@ inline Node makeNode(Kind k, GetNodeIterator start, GetNodeIterator end) {
   return Node(nb);
 }/* makeNode<GetNodeIterator>(Kind, iterator, iterator) */
 
-
-template <class GetNodeIterator, class T>
-static void copy_range(GetNodeIterator begin, GetNodeIterator end, std::vector<T>& result){
-  while(begin != end){
-    result.push_back(*begin);
-    ++begin;
-  }
-}
-
-template <class GetNodeIterator, class T>
-static void merge_ranges(GetNodeIterator first1,
-                  GetNodeIterator last1,
-                  GetNodeIterator first2,
-                  GetNodeIterator last2,
-                  std::vector<T>& result) {
-
-  while(first1 != last1 && first2 != last2){
-    if( (*first1) < (*first2) ){
-      result.push_back(*first1);
-      ++ first1;
-    }else{
-      result.push_back(*first2);
-      ++ first2;
-    }
-  }
-  copy_range(first1, last1, result);
-  copy_range(first2, last2, result);
-}
-
 /**
  * A VarList is a sorted list of variables representing a product.
  * If the VarList is empty, it represents an empty product or 1.
@@ -410,14 +445,12 @@ private:
   static Node multList(const std::vector<Variable>& list) {
     Assert(list.size() >= 2);
 
-    return makeNode(kind::MULT, list.begin(), list.end());
+    return makeNode(kind::NONLINEAR_MULT, list.begin(), list.end());
   }
 
   VarList() : NodeWrapper(Node::null()) {}
 
-  VarList(Node n) : NodeWrapper(n) {
-    Assert(isSorted(begin(), end()));
-  }
+  VarList(Node n);
 
   typedef expr::NodeSelfIterator internal_iterator;
 
@@ -444,6 +477,25 @@ public:
     internal_iterator d_iter;
 
   public:
+    /* The following types are required by trait std::iterator_traits */
+
+    /** Iterator tag */
+    using iterator_category = std::forward_iterator_tag;
+
+    /** The type of the item */
+    using value_type = Variable;
+
+    /** The pointer type of the item */
+    using pointer = Variable*;
+
+    /** The reference type of the item */
+    using reference = Variable&;
+
+    /** The type returned when two iterators are subtracted */
+    using difference_type = std::ptrdiff_t;
+
+    /* End of std::iterator_traits required types */
+
     explicit iterator(internal_iterator i) : d_iter(i) {}
 
     inline Variable operator*() {
@@ -514,7 +566,7 @@ public:
 
   bool empty() const { return getNode().isNull(); }
   bool singleton() const {
-    return !empty() && getNode().getKind() != kind::MULT;
+    return !empty() && getNode().getKind() != kind::NONLINEAR_MULT;
   }
 
   int size() const {
@@ -543,6 +595,7 @@ public:
     }
     return true;
   }
+  size_t getComplexity() const;
 
 private:
   bool isSorted(iterator start, iterator end);
@@ -550,6 +603,7 @@ private:
 };/* class VarList */
 
 
+/** Constructors have side conditions. Use the static mkMonomial functions instead. */ 
 class Monomial : public NodeWrapper {
 private:
   Constant constant;
@@ -557,8 +611,8 @@ private:
   Monomial(Node n, const Constant& c, const VarList& vl):
     NodeWrapper(n), constant(c), varList(vl)
   {
-    Assert(!c.isZero() ||  vl.empty() );
-    Assert( c.isZero() || !vl.empty() );
+    Assert(!c.isZero() || vl.empty());
+    Assert(c.isZero() || !vl.empty());
 
     Assert(!c.isOne() || !multStructured(n));
   }
@@ -576,33 +630,38 @@ private:
       n.getNumChildren() == 2;
   }
 
-public:
-
   Monomial(const Constant& c):
     NodeWrapper(c.getNode()), constant(c), varList(VarList::mkEmptyVarList())
   { }
-
+  
   Monomial(const VarList& vl):
     NodeWrapper(vl.getNode()), constant(Constant::mkConstant(1)), varList(vl)
   {
-    Assert( !varList.empty() );
+    Assert(!varList.empty());
   }
 
   Monomial(const Constant& c, const VarList& vl):
     NodeWrapper(makeMultNode(c,vl)), constant(c), varList(vl)
   {
-    Assert( !c.isZero() );
-    Assert( !c.isOne() );
-    Assert( !varList.empty() );
+    Assert(!c.isZero());
+    Assert(!c.isOne());
+    Assert(!varList.empty());
 
     Assert(multStructured(getNode()));
   }
-
+public:
   static bool isMember(TNode n);
 
   /** Makes a monomial with no restrictions on c and vl. */
   static Monomial mkMonomial(const Constant& c, const VarList& vl);
 
+  /** If vl is empty, this make one. */
+  static Monomial mkMonomial(const VarList& vl);
+
+  static Monomial mkMonomial(const Constant& c){
+    return Monomial(c);
+  }
+  
   static Monomial mkMonomial(const Variable& v){
     return Monomial(VarList(v));
   }
@@ -617,7 +676,7 @@ public:
   }
   const Constant& getConstant() const { return constant; }
   const VarList& getVarList() const { return varList; }
-
+  
   bool isConstant() const {
     return varList.empty();
   }
@@ -660,13 +719,16 @@ public:
   }
 
   static bool isSorted(const std::vector<Monomial>& m) {
-    return __gnu_cxx::is_sorted(m.begin(), m.end());
+    return std::is_sorted(m.begin(), m.end());
   }
 
   static bool isStrictlySorted(const std::vector<Monomial>& m) {
     return isSorted(m) && std::adjacent_find(m.begin(),m.end()) == m.end();
   }
 
+  static void sort(std::vector<Monomial>& m);
+  static void combineAdjacentMonomials(std::vector<Monomial>& m);
+
   /**
    * The variable product
    */
@@ -688,15 +750,23 @@ public:
     return integralCoefficient() && integralVariables();
   }
 
+  /** Returns true if the VarList is a product of at least 2 Variables.*/
+  bool isNonlinear() const {
+    return getVarList().size() >= 2;
+  }
+
   /**
    * Given a sorted list of monomials, this function transforms this
    * into a strictly sorted list of monomials that does not contain zero.
    */
-  static std::vector<Monomial> sumLikeTerms(const std::vector<Monomial>& monos);
+  //static std::vector<Monomial> sumLikeTerms(const std::vector<Monomial>& monos);
 
-  bool absLessThan(const Monomial& other) const{
-    return getConstant().abs() < other.getConstant().abs();
+  int absCmp(const Monomial& other) const{
+    return getConstant().getValue().absCmp(other.getConstant().getValue());
   }
+  // bool absLessThan(const Monomial& other) const{
+  //   return getConstant().abs() < other.getConstant().abs();
+  // }
 
   uint32_t coefficientLength() const{
     return getConstant().length();
@@ -705,6 +775,7 @@ public:
   void print() const;
   static void printList(const std::vector<Monomial>& list);
 
+  size_t getComplexity() const;
 };/* class Monomial */
 
 class SumPair;
@@ -745,41 +816,32 @@ private:
   bool singleton() const { return d_singleton; }
 
 public:
-  static bool isMember(TNode n) {
-    if(Monomial::isMember(n)){
-      return true;
-    }else if(n.getKind() == kind::PLUS){
-      Assert(n.getNumChildren() >= 2);
-      Node::iterator currIter = n.begin(), end = n.end();
-      Node prev = *currIter;
-      if(!Monomial::isMember(prev)){
-        return false;
-      }
-
-      Monomial mprev = Monomial::parseMonomial(prev);
-      ++currIter;
-      for(; currIter != end; ++currIter){
-        Node curr = *currIter;
-        if(!Monomial::isMember(curr)){
-          return false;
-        }
-        Monomial mcurr = Monomial::parseMonomial(curr);
-        if(!(mprev < mcurr)){
-          return false;
-        }
-        mprev = mcurr;
-      }
-      return true;
-    } else {
-      return false;
-    }
-  }
+  static bool isMember(TNode n);
 
   class iterator {
   private:
     internal_iterator d_iter;
 
   public:
+    /* The following types are required by trait std::iterator_traits */
+
+    /** Iterator tag */
+    using iterator_category = std::forward_iterator_tag;
+
+    /** The type of the item */
+    using value_type = Monomial;
+
+    /** The pointer type of the item */
+    using pointer = Monomial*;
+
+    /** The reference type of the item */
+    using reference = Monomial&;
+
+    /** The type returned when two iterators are subtracted */
+    using difference_type = std::ptrdiff_t;
+
+    /* End of std::iterator_traits required types */
+
     explicit iterator(internal_iterator i) : d_iter(i) {}
 
     inline Monomial operator*() {
@@ -814,12 +876,16 @@ public:
   Polynomial(const std::vector<Monomial>& m):
     NodeWrapper(makePlusNode(m)), d_singleton(false)
   {
-    Assert( m.size() >= 2);
-    Assert( Monomial::isStrictlySorted(m) );
+    Assert(m.size() >= 2);
+    Assert(Monomial::isStrictlySorted(m));
+  }
+
+  static Polynomial mkPolynomial(const Constant& c){
+    return Polynomial(Monomial::mkMonomial(c));
   }
 
   static Polynomial mkPolynomial(const Variable& v){
-    return Monomial::mkMonomial(v);
+    return Polynomial(Monomial::mkMonomial(v));
   }
 
   static Polynomial mkPolynomial(const std::vector<Monomial>& m) {
@@ -868,12 +934,12 @@ public:
   }
 
   Polynomial getTail() const {
-    Assert(! singleton());
+    Assert(!singleton());
 
     iterator tailStart = begin();
     ++tailStart;
     std::vector<Monomial> subrange;
-    copy_range(tailStart, end(), subrange);
+    std::copy(tailStart, end(), std::back_inserter(subrange));
     return mkPolynomial(subrange);
   }
 
@@ -913,6 +979,24 @@ public:
     return true;
   }
 
+  static Polynomial sumPolynomials(const std::vector<Polynomial>& polynomials);
+
+  /** Returns true if the polynomial contains a non-linear monomial.*/
+  bool isNonlinear() const;
+
+  /** Check whether this polynomial is only a single variable. */
+  bool isVariable() const
+  {
+    return singleton() && getHead().getVarList().singleton()
+           && getHead().coefficientIsOne();
+  }
+  /** Return the variable, given that isVariable() holds. */
+  Variable getVariable() const
+  {
+    Assert(isVariable());
+    return getHead().getVarList().getHead();
+  }
+
   /**
    * Selects a minimal monomial in the polynomial by the absolute value of
    * the coefficient.
@@ -925,6 +1009,10 @@ public:
   bool denominatorLCMIsOne() const;
   bool numeratorGCDIsOne() const;
 
+  bool signNormalizedReducedSum() const {
+    return leadingCoefficientIsPositive() && denominatorLCMIsOne() && numeratorGCDIsOne();
+  }
+
   /**
    * Returns the Least Common Multiple of the denominators of the coefficients
    * of the monomials.
@@ -943,13 +1031,8 @@ public:
    */
   Integer gcd() const;
 
-  Polynomial exactDivide(const Integer& z) const {
-    Assert(isIntegral());
-    Constant invz = Constant::mkConstant(Rational(1,z));
-    Polynomial prod = (*this) * Monomial(invz);
-    Assert(prod.isIntegral());
-    return prod;
-  }
+  /** z must divide all of the coefficients of the polynomial. */
+  Polynomial exactDivide(const Integer& z) const;
 
   Polynomial operator+(const Polynomial& vl) const;
   Polynomial operator-(const Polynomial& vl) const;
@@ -987,7 +1070,7 @@ public:
     }else{
       uint32_t max = (*i).coefficientLength();
       ++i;
-      for(; i!=e; ++i){      
+      for(; i!=e; ++i){
         uint32_t curr = (*i).coefficientLength();
         if(curr > max){
           max = curr;
@@ -1026,8 +1109,18 @@ public:
     return getHead().getVarList();
   }
 
+  size_t getComplexity() const;
+
   friend class SumPair;
-  friend class Comparison;;
+  friend class Comparison;
+
+  /** Returns a node that if asserted ensures v is the abs of this polynomial.*/
+  Node makeAbsCondition(Variable v){
+    return makeAbsCondition(v, *this);
+  }
+
+  /** Returns a node that if asserted ensures v is the abs of p.*/
+  static Node makeAbsCondition(Variable v, Polynomial p);
 
 };/* class Polynomial */
 
@@ -1049,14 +1142,9 @@ private:
     return NodeManager::currentNM()->mkNode(kind::PLUS, p.getNode(), c.getNode());
   }
 
-  SumPair(TNode n) :
-    NodeWrapper(n)
-  {
-    Assert(isNormalForm());
-  }
-
-public:
+  SumPair(TNode n) : NodeWrapper(n) { Assert(isNormalForm()); }
 
+ public:
   SumPair(const Polynomial& p):
     NodeWrapper(toNode(p, Constant::mkConstant(0)))
   {
@@ -1133,6 +1221,14 @@ public:
     return getConstant().isZero() && isConstant();
   }
 
+  uint32_t size() const{
+    return getPolynomial().size();
+  }
+
+  bool isNonlinear() const{
+    return getPolynomial().isNonlinear();
+  }
+
   /**
    * Returns the greatest common divisor of gcd(getPolynomial()) and getConstant().
    * The SumPair must be integral.
@@ -1243,7 +1339,7 @@ private:
    * Creates a comparison equivalent to (k l 0).
    * k is either GT or GEQ.
    * It is not the case that all variables in l are integral.
-   */  
+   */
   static Node mkRatInequality(Kind k, const Polynomial& l);
 
 public:
@@ -1324,15 +1420,30 @@ public:
     return parse.isNormalForm();
   }
 
+  size_t getComplexity() const;
+
   SumPair toSumPair() const;
 
   Polynomial normalizedVariablePart() const;
   DeltaRational normalizedDeltaRational() const;
 
+  /**
+   * Transforms a Comparison object into a stronger normal form:
+   *    Polynomial ~Kind~ Constant
+   * 
+   * From the comparison, this method resolved a negation (if present) and
+   * moves everything to the left side.
+   * If split_constant is false, the constant is always zero.
+   * If split_constant is true, the polynomial has no constant term and is
+   * normalized to have leading coefficient one.
+   */
+  std::tuple<Polynomial, Kind, Constant> decompose(
+      bool split_constant = false) const;
+
 };/* class Comparison */
 
-}/* CVC4::theory::arith namespace */
-}/* CVC4::theory namespace */
-}/* CVC4 namespace */
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5
 
-#endif /* __CVC4__THEORY__ARITH__NORMAL_FORM_H */
+#endif /* CVC5__THEORY__ARITH__NORMAL_FORM_H */