Improvements to phase shifting + purification lemmas (#8598)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Thu, 14 Apr 2022 18:00:30 +0000 (13:00 -0500)
committerGitHub <noreply@github.com>
Thu, 14 Apr 2022 18:00:30 +0000 (18:00 +0000)
Makes 3 improvements to phase shifting + purification lemmas:
(1) we use purification skolems when possible, e.g. for exp(t), or sin(c) where -pi <= c <= pi, since it is unnecessary to assert c is in proper phase. This also adds proofs for purification of exp, which is hence trivial.
(2) we make the phase shift variable a real for which is_int holds. This avoids mixed int/real arithmetic, in preparation for eliminating subtyping
(3) the proof checker and sine solver use a common utility for getting the phase shift lemma.

src/theory/arith/nl/transcendental/exponential_solver.cpp
src/theory/arith/nl/transcendental/proof_checker.cpp
src/theory/arith/nl/transcendental/sine_solver.cpp
src/theory/arith/nl/transcendental/sine_solver.h
src/theory/arith/nl/transcendental/transcendental_state.cpp
src/theory/arith/nl/transcendental/transcendental_state.h
src/theory/inference_id.cpp
src/theory/inference_id.h

index 9c2b6b3af160a81d5dd2477f8529a52404cbb7e3..27aaa7730c3546de0d3df5d121eae42db7769576 100644 (file)
@@ -46,13 +46,21 @@ ExponentialSolver::~ExponentialSolver() {}
 
 void ExponentialSolver::doPurification(TNode a, TNode new_a)
 {
+  Assert(TranscendentalState::isSimplePurify(a));
   NodeManager* nm = NodeManager::currentNM();
   // do both equalities to ensure that new_a becomes a preregistered term
   Node lem = nm->mkNode(Kind::AND, a.eqNode(new_a), a[0].eqNode(new_a[0]));
   // note we must do preprocess on this lemma
   Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : purify : " << lem
                         << std::endl;
-  d_data->d_im.addPendingLemma(lem, InferenceId::ARITH_NL_T_PURIFY_ARG);
+  CDProof* proof = nullptr;
+  if (d_data->isProofEnabled())
+  {
+    // simple to justify
+    proof = d_data->getProof();
+    proof->addStep(lem, PfRule::MACRO_SR_PRED_INTRO, {}, {lem});
+  }
+  d_data->d_im.addPendingLemma(lem, InferenceId::ARITH_NL_T_PURIFY_ARG, proof);
 }
 
 void ExponentialSolver::checkInitialRefine()
index 272f4ba0a5bca723283fc6cae7d1fe4951732b20..488e58ce7a30cc777e24fbefd415bc4d0050a61d 100644 (file)
@@ -17,6 +17,7 @@
 
 #include "expr/sequence.h"
 #include "theory/arith/arith_utilities.h"
+#include "theory/arith/nl/transcendental/sine_solver.h"
 #include "theory/arith/nl/transcendental/taylor_generator.h"
 #include "theory/evaluator.h"
 
@@ -226,22 +227,7 @@ Node TranscendentalProofRuleChecker::checkInternal(
     const auto& x = args[0];
     const auto& y = args[1];
     const auto& s = args[2];
-    return nm->mkAnd(std::vector<Node>{
-        nm->mkAnd(std::vector<Node>{
-            nm->mkNode(Kind::GEQ, y, nm->mkNode(Kind::MULT, mone, pi)),
-            nm->mkNode(Kind::LEQ, y, pi)}),
-        nm->mkNode(
-            Kind::ITE,
-            nm->mkAnd(std::vector<Node>{
-                nm->mkNode(Kind::GEQ, x, nm->mkNode(Kind::MULT, mone, pi)),
-                nm->mkNode(Kind::LEQ, x, pi),
-            }),
-            x.eqNode(y),
-            x.eqNode(nm->mkNode(
-                Kind::ADD,
-                y,
-                nm->mkNode(Kind::MULT, nm->mkConstReal(2), s, pi)))),
-        nm->mkNode(Kind::SINE, y).eqNode(nm->mkNode(Kind::SINE, x))});
+    return SineSolver::getPhaseShiftLemma(x, y, s);
   }
   else if (id == PfRule::ARITH_TRANS_SINE_SYMMETRY)
   {
index dfd052018122c99a06ce0416523ee943486e9535..40de74d0a74eb31af5022511ffa6ef57ea583694 100644 (file)
@@ -37,19 +37,6 @@ namespace theory {
 namespace arith {
 namespace nl {
 namespace transcendental {
-namespace {
-
-/**
- * Ensure a is in the main phase:
- *   -pi <= a <= pi
- */
-inline Node mkValidPhase(TNode a, TNode pi)
-{
-  NodeManager* nm = NodeManager::currentNM();
-  return mkBounded(
-      nm->mkNode(Kind::MULT, nm->mkConstReal(Rational(-1)), pi), a, pi);
-}
-}  // namespace
 
 SineSolver::SineSolver(Env& env, TranscendentalState* tstate)
     : EnvObj(env), d_data(tstate)
@@ -159,39 +146,64 @@ void SineSolver::doReductions()
   }
 }
 
+Node SineSolver::getPhaseShiftLemma(const Node& x, const Node& y, const Node& s)
+{
+  NodeManager* nm = NodeManager::currentNM();
+  Node mone = nm->mkConstReal(Rational(-1));
+  Node pi = nm->mkNullaryOperator(nm->realType(), PI);
+  return nm->mkAnd(std::vector<Node>{
+      nm->mkNode(GEQ, y, nm->mkNode(MULT, mone, pi)),
+      nm->mkNode(LEQ, y, pi),
+      nm->mkNode(IS_INTEGER, s),
+      nm->mkNode(ITE,
+                 nm->mkAnd(std::vector<Node>{
+                     nm->mkNode(GEQ, x, nm->mkNode(MULT, mone, pi)),
+                     nm->mkNode(LEQ, x, pi),
+                 }),
+                 x.eqNode(y),
+                 x.eqNode(nm->mkNode(
+                     ADD, y, nm->mkNode(MULT, nm->mkConstReal(2), s, pi)))),
+      nm->mkNode(SINE, y).eqNode(nm->mkNode(SINE, x))});
+}
+
 void SineSolver::doPhaseShift(TNode a, TNode new_a)
 {
-  TNode y = new_a[0];
   NodeManager* nm = NodeManager::currentNM();
   SkolemManager* sm = nm->getSkolemManager();
   Assert(a.getKind() == Kind::SINE);
-  Trace("nl-ext-tf") << "Basis sine : " << new_a << " for " << a << std::endl;
-  Node shift = sm->mkDummySkolem("s", nm->integerType(), "number of shifts");
-  // TODO (cvc4-projects #47) : do not introduce shift here, instead needs model-based
-  // refinement for constant shifts (cvc4-projects #1284)
-  Node lem = nm->mkNode(
-      Kind::AND,
-      mkValidPhase(y, d_pi),
-      nm->mkNode(
-          Kind::ITE,
-          mkValidPhase(a[0], d_pi),
-          a[0].eqNode(y),
-          a[0].eqNode(nm->mkNode(
-              Kind::ADD,
-              y,
-              nm->mkNode(
-                  Kind::MULT, nm->mkConstReal(Rational(2)), shift, d_pi)))),
-      new_a.eqNode(a));
   CDProof* proof = nullptr;
-  if (d_data->isProofEnabled())
+  Node lem;
+  Trace("nl-ext-tf") << "Basis sine : " << new_a << " for " << a << std::endl;
+  InferenceId iid;
+  if (TranscendentalState::isSimplePurify(a))
   {
-    proof = d_data->getProof();
-    proof->addStep(lem, PfRule::ARITH_TRANS_SINE_SHIFT, {}, {a[0], y, shift});
+    lem = nm->mkNode(Kind::AND, a.eqNode(new_a), a[0].eqNode(new_a[0]));
+    if (d_data->isProofEnabled())
+    {
+      // simple to justify
+      proof = d_data->getProof();
+      proof->addStep(lem, PfRule::MACRO_SR_PRED_INTRO, {}, {lem});
+    }
+    iid = InferenceId::ARITH_NL_T_PURIFY_ARG;
+  }
+  else
+  {
+    Node shift = sm->mkDummySkolem("s", nm->realType(), "number of shifts");
+    // TODO (cvc4-projects #47) : do not introduce shift here, instead needs
+    // model-based refinement for constant shifts (cvc4-projects #1284)
+    lem = getPhaseShiftLemma(a[0], new_a[0], shift);
+    if (d_data->isProofEnabled())
+    {
+      proof = d_data->getProof();
+      proof->addStep(
+          lem, PfRule::ARITH_TRANS_SINE_SHIFT, {}, {a[0], new_a[0], shift});
+    }
+    iid = InferenceId::ARITH_NL_T_PURIFY_ARG_PHASE_SHIFT;
   }
   // note we must do preprocess on this lemma
   Trace("nl-ext-lemma") << "NonlinearExtension::Lemma : purify : " << lem
                         << std::endl;
-  d_data->d_im.addPendingLemma(lem, InferenceId::ARITH_NL_T_PURIFY_ARG, proof);
+  d_data->d_im.addPendingLemma(lem, iid, proof);
 }
 
 void SineSolver::checkInitialRefine()
index c2f5b327e8a487ddbd3a0265d53a9f60374b0050..0f6f127db607d1663d08ad6198869d90b420b22f 100644 (file)
@@ -114,6 +114,14 @@ class SineSolver : protected EnvObj
    */
   bool hasExactModelValue(TNode n) const;
 
+  /**
+   * Make the lemma for the phase shift of arguments to SINE x and y, where
+   * s is the (integral) shift. The lemma conceptually says that y is
+   * in the bounds [-pi, pi] and y is offset from x by an integral factor of
+   * 2*pi.
+   */
+  static Node getPhaseShiftLemma(const Node& x, const Node& y, const Node& s);
+
  private:
   std::pair<Node, Node> getSecantBounds(TNode e,
                                         TNode c,
index efb1ab8c342d33b4ea24ad553e9c0ca1052d101b..470862311ea4f7b27eaf136a17dd15127788c6f3 100644 (file)
@@ -222,8 +222,8 @@ void TranscendentalState::mkPi()
   {
     d_pi = nm->mkNullaryOperator(nm->realType(), Kind::PI);
     // initialize bounds
-    d_pi_bound[0] = nm->mkConstReal(Rational(103993) / Rational(33102));
-    d_pi_bound[1] = nm->mkConstReal(Rational(104348) / Rational(33215));
+    d_pi_bound[0] = nm->mkConstReal(getPiInitialLowerBound());
+    d_pi_bound[1] = nm->mkConstReal(getPiInitialUpperBound());
   }
 }
 
@@ -472,8 +472,16 @@ Node TranscendentalState::getPurifiedForm(TNode n)
   }
   Kind k = n.getKind();
   Assert(k == Kind::SINE || k == Kind::EXPONENTIAL);
-  Node y = sm->mkSkolemFunction(
-      SkolemFunId::TRANSCENDENTAL_PURIFY_ARG, nm->realType(), n);
+  Node y;
+  if (isSimplePurify(n))
+  {
+    y = sm->mkPurifySkolem(n[0], "transk");
+  }
+  else
+  {
+    y = sm->mkSkolemFunction(
+        SkolemFunId::TRANSCENDENTAL_PURIFY_ARG, nm->realType(), n);
+  }
   Node new_n = nm->mkNode(k, y);
   d_trPurify[n] = new_n;
   d_trPurify[new_n] = new_n;
@@ -482,6 +490,22 @@ Node TranscendentalState::getPurifiedForm(TNode n)
   return new_n;
 }
 
+bool TranscendentalState::isSimplePurify(TNode n)
+{
+  if (n.getKind() != kind::SINE)
+  {
+    return true;
+  }
+  if (!n[0].isConst())
+  {
+    return false;
+  }
+  Rational r = n[0].getConst<Rational>();
+  // use a fixed value of pi
+  Rational piLower = getPiInitialLowerBound();
+  return -piLower <= r && r <= piLower;
+}
+
 bool TranscendentalState::addModelBoundForPurifyTerm(TNode n, TNode l, TNode u)
 {
   Assert(d_funcCongClass.find(n) != d_funcCongClass.end());
@@ -509,6 +533,16 @@ bool TranscendentalState::addModelBoundForPurifyTerm(TNode n, TNode l, TNode u)
   return true;
 }
 
+Rational TranscendentalState::getPiInitialLowerBound()
+{
+  return Rational(103993) / Rational(33102);
+}
+
+Rational TranscendentalState::getPiInitialUpperBound()
+{
+  return Rational(104348) / Rational(33215);
+}
+
 }  // namespace transcendental
 }  // namespace nl
 }  // namespace arith
index cbb24a59bc00c3a24b454d19f8df7220d63ff216..b424b58af92e3acd9f77e9f67704b2862a119d25 100644 (file)
@@ -169,10 +169,22 @@ class TranscendentalState : protected EnvObj
   bool isPurified(TNode n) const;
   /** get the purified form of node n */
   Node getPurifiedForm(TNode n);
+  /**
+   * Can we do "simple" purification for n? If this is the case, then
+   * f(x) is purified by f(k) where k is the purification variable for x.
+   *
+   * This is true for sin(x) where x is guaranteed to be a constant in the
+   * bound [-pi, pi] (note that there may be some x in [-pi, pi] for which
+   * this function returns false, because the check is not precise).
+   */
+  static bool isSimplePurify(TNode n);
   /**
    * Add bound for n, and for what (if anything) it purifies
    */
   bool addModelBoundForPurifyTerm(TNode n, TNode l, TNode u);
+  /** initial lower and upper bounds for PI */
+  static Rational getPiInitialLowerBound();
+  static Rational getPiInitialUpperBound();
 
   Node d_true;
   Node d_false;
index f11c6b87fb358e934e2cc5d3e95fd424aa81ab4f..22e693f9ff2f1608c7c608c2348c7d89102971ed 100644 (file)
@@ -79,6 +79,8 @@ const char* toString(InferenceId i)
     case InferenceId::ARITH_NL_T_SINE_BOUNDARY_REDUCE:
       return "ARITH_NL_T_SINE_BOUNDARY_REDUCE";
     case InferenceId::ARITH_NL_T_PURIFY_ARG: return "ARITH_NL_T_PURIFY_ARG";
+    case InferenceId::ARITH_NL_T_PURIFY_ARG_PHASE_SHIFT:
+      return "ARITH_NL_T_PURIFY_ARG_PHASE_SHIFT";
     case InferenceId::ARITH_NL_T_INIT_REFINE: return "ARITH_NL_T_INIT_REFINE";
     case InferenceId::ARITH_NL_T_PI_BOUND: return "ARITH_NL_T_PI_BOUND";
     case InferenceId::ARITH_NL_T_MONOTONICITY: return "ARITH_NL_T_MONOTONICITY";
index a687fac3cc85a97ed70ee10a2b6fbf0168fece5b..9f0692a6fb39ee6f59dafd45c1f836f94f2e7145 100644 (file)
@@ -130,6 +130,8 @@ enum class InferenceId
   ARITH_NL_T_SINE_BOUNDARY_REDUCE,
   // purification of arguments to transcendental functions
   ARITH_NL_T_PURIFY_ARG,
+  // purification of arguments to transcendental functions with phase shifting
+  ARITH_NL_T_PURIFY_ARG_PHASE_SHIFT,
   // initial refinement (TranscendentalSolver::checkTranscendentalInitialRefine)
   ARITH_NL_T_INIT_REFINE,
   // pi bounds