Refactor arithmetic pre-rewriter for multiplication (#7930)
[cvc5.git] / src / theory / arith / normal_form.h
index 1a3327e8ff1d73e537faa5d7b82f0fed7dc6d363..577bd052d0372d93f9e6eb75fac0a98a685f261c 100644 (file)
@@ -1,39 +1,35 @@
-/*********************                                                        */
-/*! \file normal_form.h
- ** \verbatim
- ** Original author: Tim King
- ** Major contributors: none
- ** Minor contributors (to current version): Dejan Jovanovic, Morgan Deters
- ** This file is part of the CVC4 project.
- ** Copyright (c) 2009-2013  New York University and The University of Iowa
- ** See the file COPYING in the top-level source directory for licensing
- ** information.\endverbatim
- **
- ** \brief [[ Add one-line brief description here ]]
- **
- ** [[ Add lengthier description here ]]
- ** \todo document this file
- **/
-
-#include "cvc4_private.h"
-
-#ifndef __CVC4__THEORY__ARITH__NORMAL_FORM_H
-#define __CVC4__THEORY__ARITH__NORMAL_FORM_H
+/******************************************************************************
+ * Top contributors (to current version):
+ *   Tim King, Morgan Deters, Gereon Kremer
+ *
+ * This file is part of the cvc5 project.
+ *
+ * Copyright (c) 2009-2021 by the authors listed in the file AUTHORS
+ * in the top-level source directory and their institutional affiliations.
+ * All rights reserved.  See the file COPYING in the top-level source
+ * directory for licensing information.
+ * ****************************************************************************
+ *
+ * [[ Add one-line brief description here ]]
+ *
+ * [[ Add lengthier description here ]]
+ * \todo document this file
+ */
 
-#include "expr/node.h"
-#include "expr/node_self_iterator.h"
-#include "util/rational.h"
-#include "theory/arith/delta_rational.h"
-//#include "theory/arith/arith_utilities.h"
+#include "cvc5_private.h"
+
+#ifndef CVC5__THEORY__ARITH__NORMAL_FORM_H
+#define CVC5__THEORY__ARITH__NORMAL_FORM_H
 
-#include <list>
 #include <algorithm>
 
-#if IS_SORTED_IN_GNUCXX_NAMESPACE
-#  include <ext/algorithm>
-#endif /* IS_SORTED_IN_GNUCXX_NAMESPACE */
+#include "base/output.h"
+#include "expr/node.h"
+#include "expr/node_self_iterator.h"
+#include "theory/arith/delta_rational.h"
+#include "util/rational.h"
 
-namespace CVC4 {
+namespace cvc5 {
 namespace theory {
 namespace arith {
 
@@ -227,38 +223,55 @@ public:
 
 class Variable : public NodeWrapper {
 public:
-  Variable(Node n) : NodeWrapper(n) {
-    Assert(isMember(getNode()));
-  }
-
-  // TODO: check if it's a theory leaf also
-  static bool isMember(Node n) {
-    Kind k = n.getKind();
-    switch(k){
-    case kind::CONST_RATIONAL:
-      return false;
-    case kind::INTS_DIVISION:
-    case kind::INTS_MODULUS:
-    case kind::DIVISION:
-    case kind::INTS_DIVISION_TOTAL:
-    case kind::INTS_MODULUS_TOTAL:
-    case kind::DIVISION_TOTAL:
-      return isDivMember(n);
-    case kind::ABS:
-    case kind::TO_INTEGER:
-      // Treat to_int as a variable; it is replaced in early preprocessing
-      // by a variable.
-      return true;
-    default:
-      return isLeafMember(n);
-    }
-  }
+ Variable(Node n) : NodeWrapper(n) { Assert(isMember(getNode())); }
+
+ // TODO: check if it's a theory leaf also
+ static bool isMember(Node n)
+ {
+   Kind k = n.getKind();
+   switch (k)
+   {
+     case kind::CONST_RATIONAL: return false;
+     case kind::INTS_DIVISION:
+     case kind::INTS_MODULUS:
+     case kind::DIVISION:
+     case kind::INTS_DIVISION_TOTAL:
+     case kind::INTS_MODULUS_TOTAL:
+     case kind::DIVISION_TOTAL: return isDivMember(n);
+     case kind::IAND: return isIAndMember(n);
+     case kind::POW2: return isPow2Member(n);
+     case kind::EXPONENTIAL:
+     case kind::SINE:
+     case kind::COSINE:
+     case kind::TANGENT:
+     case kind::COSECANT:
+     case kind::SECANT:
+     case kind::COTANGENT:
+     case kind::ARCSINE:
+     case kind::ARCCOSINE:
+     case kind::ARCTANGENT:
+     case kind::ARCCOSECANT:
+     case kind::ARCSECANT:
+     case kind::ARCCOTANGENT:
+     case kind::SQRT:
+     case kind::PI: return isTranscendentalMember(n);
+     case kind::ABS:
+     case kind::TO_INTEGER:
+       // Treat to_int as a variable; it is replaced in early preprocessing
+       // by a variable.
+       return true;
+     default: return isLeafMember(n);
+   }
+ }
 
   static bool isLeafMember(Node n);
+  static bool isIAndMember(Node n);
+  static bool isPow2Member(Node n);
   static bool isDivMember(Node n);
   bool isDivLike() const{
     return isDivMember(getNode());
   }
+  static bool isTranscendentalMember(Node n);
 
   bool isNormalForm() { return isMember(getNode()); }
 
@@ -276,7 +289,7 @@ public:
   }
 
   struct VariableNodeCmp {
-    static inline int cmp(Node n, Node m) {
+    static inline int cmp(const Node& n, const Node& m) {
       if ( n == m ) { return 0; }
 
       // this is now slightly off of the old variable order.
@@ -292,7 +305,7 @@ public:
           if(n < m){
             return -1;
           }else{
-            Assert( n != m );
+            Assert(n != m);
             return 1;
           }
         }else{
@@ -312,7 +325,7 @@ public:
       }
     }
 
-    bool operator()(Node n, Node m) const {
+    bool operator()(const Node& n, const Node& m) const {
       return VariableNodeCmp::cmp(n,m) < 0;
     }
   };
@@ -322,23 +335,19 @@ public:
   size_t getComplexity() const;
 };/* class Variable */
 
-
 class Constant : public NodeWrapper {
 public:
-  Constant(Node n) : NodeWrapper(n) {
-    Assert(isMember(getNode()));
-  }
+ Constant(Node n) : NodeWrapper(n) { Assert(isMember(getNode())); }
 
-  static bool isMember(Node n) {
-    return n.getKind() == kind::CONST_RATIONAL;
-  }
+ static bool isMember(Node n) { return n.getKind() == kind::CONST_RATIONAL; }
 
 bool isNormalForm() { return isMember(getNode()); }
+ bool isNormalForm() { return isMember(getNode()); }
 
-  static Constant mkConstant(Node n) {
-    Assert(n.getKind() == kind::CONST_RATIONAL);
-    return Constant(n);
-  }
+ static Constant mkConstant(Node n)
+ {
+   Assert(n.getKind() == kind::CONST_RATIONAL);
+   return Constant(n);
+ }
 
   static Constant mkConstant(const Rational& rat);
 
@@ -413,7 +422,7 @@ public:
 
 template <class GetNodeIterator>
 inline Node makeNode(Kind k, GetNodeIterator start, GetNodeIterator end) {
-  NodeBuilder<> nb(k);
+  NodeBuilder nb(k);
 
   while(start != end) {
     nb << (*start).getNode();
@@ -423,56 +432,6 @@ inline Node makeNode(Kind k, GetNodeIterator start, GetNodeIterator end) {
   return Node(nb);
 }/* makeNode<GetNodeIterator>(Kind, iterator, iterator) */
 
-
-template <class GetNodeIterator, class T>
-static void copy_range(GetNodeIterator begin, GetNodeIterator end, std::vector<T>& result){
-  while(begin != end){
-    result.push_back(*begin);
-    ++begin;
-  }
-}
-
-template <class GetNodeIterator, class T>
-static void merge_ranges(GetNodeIterator first1,
-                  GetNodeIterator last1,
-                  GetNodeIterator first2,
-                  GetNodeIterator last2,
-                  std::vector<T>& result) {
-
-  while(first1 != last1 && first2 != last2){
-    if( (*first1) < (*first2) ){
-      result.push_back(*first1);
-      ++ first1;
-    }else{
-      result.push_back(*first2);
-      ++ first2;
-    }
-  }
-  copy_range(first1, last1, result);
-  copy_range(first2, last2, result);
-}
-
-template <class GetNodeIterator, class T, class Cmp>
-static void merge_ranges(GetNodeIterator first1,
-                         GetNodeIterator last1,
-                         GetNodeIterator first2,
-                         GetNodeIterator last2,
-                         std::vector<T>& result,
-                         const Cmp& cmp) {
-
-  while(first1 != last1 && first2 != last2){
-    if( cmp(*first1, *first2) ){
-      result.push_back(*first1);
-      ++ first1;
-    }else{
-      result.push_back(*first2);
-      ++ first2;
-    }
-  }
-  copy_range(first1, last1, result);
-  copy_range(first2, last2, result);
-}
-
 /**
  * A VarList is a sorted list of variables representing a product.
  * If the VarList is empty, it represents an empty product or 1.
@@ -486,7 +445,7 @@ private:
   static Node multList(const std::vector<Variable>& list) {
     Assert(list.size() >= 2);
 
-    return makeNode(kind::MULT, list.begin(), list.end());
+    return makeNode(kind::NONLINEAR_MULT, list.begin(), list.end());
   }
 
   VarList() : NodeWrapper(Node::null()) {}
@@ -513,11 +472,30 @@ private:
 
 public:
 
-  class iterator : public std::iterator<std::input_iterator_tag, Variable> {
+  class iterator {
   private:
     internal_iterator d_iter;
 
   public:
+    /* The following types are required by trait std::iterator_traits */
+
+    /** Iterator tag */
+    using iterator_category = std::forward_iterator_tag;
+
+    /** The type of the item */
+    using value_type = Variable;
+
+    /** The pointer type of the item */
+    using pointer = Variable*;
+
+    /** The reference type of the item */
+    using reference = Variable&;
+
+    /** The type returned when two iterators are subtracted */
+    using difference_type = std::ptrdiff_t;
+
+    /* End of std::iterator_traits required types */
+
     explicit iterator(internal_iterator i) : d_iter(i) {}
 
     inline Variable operator*() {
@@ -588,7 +566,7 @@ public:
 
   bool empty() const { return getNode().isNull(); }
   bool singleton() const {
-    return !empty() && getNode().getKind() != kind::MULT;
+    return !empty() && getNode().getKind() != kind::NONLINEAR_MULT;
   }
 
   int size() const {
@@ -625,6 +603,7 @@ private:
 };/* class VarList */
 
 
+/** Constructors have side conditions. Use the static mkMonomial functions instead. */ 
 class Monomial : public NodeWrapper {
 private:
   Constant constant;
@@ -632,8 +611,8 @@ private:
   Monomial(Node n, const Constant& c, const VarList& vl):
     NodeWrapper(n), constant(c), varList(vl)
   {
-    Assert(!c.isZero() ||  vl.empty() );
-    Assert( c.isZero() || !vl.empty() );
+    Assert(!c.isZero() || vl.empty());
+    Assert(c.isZero() || !vl.empty());
 
     Assert(!c.isOne() || !multStructured(n));
   }
@@ -651,33 +630,38 @@ private:
       n.getNumChildren() == 2;
   }
 
-public:
-
   Monomial(const Constant& c):
     NodeWrapper(c.getNode()), constant(c), varList(VarList::mkEmptyVarList())
   { }
-
+  
   Monomial(const VarList& vl):
     NodeWrapper(vl.getNode()), constant(Constant::mkConstant(1)), varList(vl)
   {
-    Assert( !varList.empty() );
+    Assert(!varList.empty());
   }
 
   Monomial(const Constant& c, const VarList& vl):
     NodeWrapper(makeMultNode(c,vl)), constant(c), varList(vl)
   {
-    Assert( !c.isZero() );
-    Assert( !c.isOne() );
-    Assert( !varList.empty() );
+    Assert(!c.isZero());
+    Assert(!c.isOne());
+    Assert(!varList.empty());
 
     Assert(multStructured(getNode()));
   }
-
+public:
   static bool isMember(TNode n);
 
   /** Makes a monomial with no restrictions on c and vl. */
   static Monomial mkMonomial(const Constant& c, const VarList& vl);
 
+  /** If vl is empty, this make one. */
+  static Monomial mkMonomial(const VarList& vl);
+
+  static Monomial mkMonomial(const Constant& c){
+    return Monomial(c);
+  }
+  
   static Monomial mkMonomial(const Variable& v){
     return Monomial(VarList(v));
   }
@@ -692,7 +676,7 @@ public:
   }
   const Constant& getConstant() const { return constant; }
   const VarList& getVarList() const { return varList; }
-
+  
   bool isConstant() const {
     return varList.empty();
   }
@@ -735,11 +719,7 @@ public:
   }
 
   static bool isSorted(const std::vector<Monomial>& m) {
-#if IS_SORTED_IN_GNUCXX_NAMESPACE
-    return __gnu_cxx::is_sorted(m.begin(), m.end());
-#else /* IS_SORTED_IN_GNUCXX_NAMESPACE */
     return std::is_sorted(m.begin(), m.end());
-#endif /* IS_SORTED_IN_GNUCXX_NAMESPACE */
   }
 
   static bool isStrictlySorted(const std::vector<Monomial>& m) {
@@ -843,6 +823,25 @@ public:
     internal_iterator d_iter;
 
   public:
+    /* The following types are required by trait std::iterator_traits */
+
+    /** Iterator tag */
+    using iterator_category = std::forward_iterator_tag;
+
+    /** The type of the item */
+    using value_type = Monomial;
+
+    /** The pointer type of the item */
+    using pointer = Monomial*;
+
+    /** The reference type of the item */
+    using reference = Monomial&;
+
+    /** The type returned when two iterators are subtracted */
+    using difference_type = std::ptrdiff_t;
+
+    /* End of std::iterator_traits required types */
+
     explicit iterator(internal_iterator i) : d_iter(i) {}
 
     inline Monomial operator*() {
@@ -877,12 +876,16 @@ public:
   Polynomial(const std::vector<Monomial>& m):
     NodeWrapper(makePlusNode(m)), d_singleton(false)
   {
-    Assert( m.size() >= 2);
-    Assert( Monomial::isStrictlySorted(m) );
+    Assert(m.size() >= 2);
+    Assert(Monomial::isStrictlySorted(m));
+  }
+
+  static Polynomial mkPolynomial(const Constant& c){
+    return Polynomial(Monomial::mkMonomial(c));
   }
 
   static Polynomial mkPolynomial(const Variable& v){
-    return Monomial::mkMonomial(v);
+    return Polynomial(Monomial::mkMonomial(v));
   }
 
   static Polynomial mkPolynomial(const std::vector<Monomial>& m) {
@@ -931,12 +934,12 @@ public:
   }
 
   Polynomial getTail() const {
-    Assert(! singleton());
+    Assert(!singleton());
 
     iterator tailStart = begin();
     ++tailStart;
     std::vector<Monomial> subrange;
-    copy_range(tailStart, end(), subrange);
+    std::copy(tailStart, end(), std::back_inserter(subrange));
     return mkPolynomial(subrange);
   }
 
@@ -981,6 +984,18 @@ public:
   /** Returns true if the polynomial contains a non-linear monomial.*/
   bool isNonlinear() const;
 
+  /** Check whether this polynomial is only a single variable. */
+  bool isVariable() const
+  {
+    return singleton() && getHead().getVarList().singleton()
+           && getHead().coefficientIsOne();
+  }
+  /** Return the variable, given that isVariable() holds. */
+  Variable getVariable() const
+  {
+    Assert(isVariable());
+    return getHead().getVarList().getHead();
+  }
 
   /**
    * Selects a minimal monomial in the polynomial by the absolute value of
@@ -1016,13 +1031,8 @@ public:
    */
   Integer gcd() const;
 
-  Polynomial exactDivide(const Integer& z) const {
-    Assert(isIntegral());
-    Constant invz = Constant::mkConstant(Rational(1,z));
-    Polynomial prod = (*this) * Monomial(invz);
-    Assert(prod.isIntegral());
-    return prod;
-  }
+  /** z must divide all of the coefficients of the polynomial. */
+  Polynomial exactDivide(const Integer& z) const;
 
   Polynomial operator+(const Polynomial& vl) const;
   Polynomial operator-(const Polynomial& vl) const;
@@ -1132,14 +1142,9 @@ private:
     return NodeManager::currentNM()->mkNode(kind::PLUS, p.getNode(), c.getNode());
   }
 
-  SumPair(TNode n) :
-    NodeWrapper(n)
-  {
-    Assert(isNormalForm());
-  }
-
-public:
+  SumPair(TNode n) : NodeWrapper(n) { Assert(isNormalForm()); }
 
+ public:
   SumPair(const Polynomial& p):
     NodeWrapper(toNode(p, Constant::mkConstant(0)))
   {
@@ -1422,10 +1427,23 @@ public:
   Polynomial normalizedVariablePart() const;
   DeltaRational normalizedDeltaRational() const;
 
+  /**
+   * Transforms a Comparison object into a stronger normal form:
+   *    Polynomial ~Kind~ Constant
+   * 
+   * From the comparison, this method resolved a negation (if present) and
+   * moves everything to the left side.
+   * If split_constant is false, the constant is always zero.
+   * If split_constant is true, the polynomial has no constant term and is
+   * normalized to have leading coefficient one.
+   */
+  std::tuple<Polynomial, Kind, Constant> decompose(
+      bool split_constant = false) const;
+
 };/* class Comparison */
 
-}/* CVC4::theory::arith namespace */
-}/* CVC4::theory namespace */
-}/* CVC4 namespace */
+}  // namespace arith
+}  // namespace theory
+}  // namespace cvc5
 
-#endif /* __CVC4__THEORY__ARITH__NORMAL_FORM_H */
+#endif /* CVC5__THEORY__ARITH__NORMAL_FORM_H */