Make CEGQI term type to enum (#3256)
authorAndrew Reynolds <andrew.j.reynolds@gmail.com>
Fri, 6 Sep 2019 16:46:21 +0000 (11:46 -0500)
committerGitHub <noreply@github.com>
Fri, 6 Sep 2019 16:46:21 +0000 (11:46 -0500)
src/theory/quantifiers/cegqi/ceg_arith_instantiator.cpp
src/theory/quantifiers/cegqi/ceg_arith_instantiator.h
src/theory/quantifiers/cegqi/ceg_epr_instantiator.cpp
src/theory/quantifiers/cegqi/ceg_instantiator.cpp
src/theory/quantifiers/cegqi/ceg_instantiator.h

index b1a2a2533494de5421b881c022c878b08e8b713b..0c3f1f69bf246fb594a7e9d76795272e2cc44e62 100644 (file)
@@ -102,11 +102,11 @@ bool ArithInstantiator::processEquality(CegInstantiator* ci,
   Node vts_coeff_inf;
   Node vts_coeff_delta;
   // isolate pv in the equality
-  int ires = solve_arith(
+  CegTermType ires = solve_arith(
       ci, pv, eq, pv_prop.d_coeff, val, vts_coeff_inf, vts_coeff_delta);
-  if (ires != 0)
+  if (ires != CEG_TT_INVALID)
   {
-    pv_prop.d_type = 0;
+    pv_prop.d_type = CEG_TT_EQUAL;
     if (ci->constructInstantiationInc(pv, val, pv_prop, sf))
     {
       return true;
@@ -160,9 +160,9 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
   Node val;
   TermProperties pv_prop;
   // isolate pv in the inequality
-  int ires = solve_arith(
+  CegTermType ires = solve_arith(
       ci, pv, atom, pv_prop.d_coeff, val, vts_coeff_inf, vts_coeff_delta);
-  if (ires == 0)
+  if (ires == CEG_TT_INVALID)
   {
     return false;
   }
@@ -174,14 +174,14 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
   }
   for (unsigned r = 0; r < rmax; r++)
   {
-    int uires = ires;
+    CegTermType uires = ires;
     Node uval = val;
     if (atom.getKind() == GEQ)
     {
       // push negation downwards
       if (!pol)
       {
-        uires = -ires;
+        uires = mkNegateCTT(ires);
         if (d_type.isInteger())
         {
           uval = nm->mkNode(PLUS, val, nm->mkConst(Rational(uires)));
@@ -191,14 +191,14 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
         {
           Assert(d_type.isReal());
           // now is strict inequality
-          uires = uires * 2;
+          uires = mkStrictCTT(uires);
         }
       }
     }
     else if (pol)
     {
       // equalities are both non-strict upper and lower bounds
-      uires = r == 0 ? 1 : -1;
+      uires = r == 0 ? CEG_TT_UPPER : CEG_TT_LOWER;
     }
     else
     {
@@ -248,14 +248,14 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
       Assert(atom.getKind() == EQUAL && !pol);
       if (d_type.isInteger())
       {
-        uires = is_upper ? -1 : 1;
+        uires = is_upper ? CEG_TT_LOWER : CEG_TT_UPPER;
         uval = nm->mkNode(PLUS, val, nm->mkConst(Rational(uires)));
         uval = Rewriter::rewrite(uval);
       }
       else
       {
         Assert(d_type.isReal());
-        uires = is_upper ? -2 : 2;
+        uires = is_upper ? CEG_TT_LOWER_STRICT : CEG_TT_UPPER_STRICT;
       }
     }
     if (Trace.isOn("cegqi-arith-bound-inf"))
@@ -266,11 +266,12 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
           << pvmod << " -> " << uval << ", styp = " << uires << std::endl;
     }
     // take into account delta
-    if (uires == 2 || uires == -2)
+    if (uires == CEG_TT_UPPER_STRICT || uires == CEG_TT_LOWER_STRICT)
     {
       if (options::cbqiModel())
       {
-        Node delta_coeff = nm->mkConst(Rational(uires > 0 ? 1 : -1));
+        Node delta_coeff =
+            nm->mkConst(Rational(isUpperBoundCTT(uires) ? 1 : -1));
         if (vts_coeff_delta.isNull())
         {
           vts_coeff_delta = delta_coeff;
@@ -284,14 +285,15 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
       else
       {
         Node delta = ci->getQuantifiersEngine()->getTermUtil()->getVtsDelta();
-        uval = nm->mkNode(uires == 2 ? PLUS : MINUS, uval, delta);
+        uval = nm->mkNode(
+            uires == CEG_TT_UPPER_STRICT ? PLUS : MINUS, uval, delta);
         uval = Rewriter::rewrite(uval);
       }
     }
     if (options::cbqiModel())
     {
       // just store bounds, will choose based on tighest bound
-      unsigned index = uires > 0 ? 0 : 1;
+      unsigned index = isUpperBoundCTT(uires) ? 0 : 1;
       d_mbp_bounds[index].push_back(uval);
       d_mbp_coeff[index].push_back(pv_prop.d_coeff);
       Trace("cegqi-arith-debug")
@@ -308,7 +310,7 @@ bool ArithInstantiator::processAssertion(CegInstantiator* ci,
     else
     {
       // try this bound
-      pv_prop.d_type = uires > 0 ? 1 : -1;
+      pv_prop.d_type = isUpperBoundCTT(uires) ? CEG_TT_UPPER : CEG_TT_LOWER;
       if (ci->constructInstantiationInc(pv, uval, pv_prop, sf))
       {
         return true;
@@ -520,7 +522,7 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci,
           {
             TermProperties pv_prop_bound;
             pv_prop_bound.d_coeff = d_mbp_coeff[rr][best];
-            pv_prop_bound.d_type = rr == 0 ? 1 : -1;
+            pv_prop_bound.d_type = rr == 0 ? CEG_TT_UPPER : CEG_TT_LOWER;
             if (ci->constructInstantiationInc(pv, val, pv_prop_bound, sf))
             {
               return true;
@@ -665,7 +667,7 @@ bool ArithInstantiator::processAssertions(CegInstantiator* ci,
         {
           TermProperties pv_prop_nopt_bound;
           pv_prop_nopt_bound.d_coeff = d_mbp_coeff[rr][j];
-          pv_prop_nopt_bound.d_type = rr == 0 ? 1 : -1;
+          pv_prop_nopt_bound.d_type = rr == 0 ? CEG_TT_UPPER : CEG_TT_LOWER;
           if (ci->constructInstantiationInc(pv, val, pv_prop_nopt_bound, sf))
           {
             return true;
@@ -745,7 +747,8 @@ bool ArithInstantiator::postProcessInstantiationForVariable(
     Trace("cegqi-arith-debug")
         << "...bound type is : " << sf.d_props[index].d_type << std::endl;
     // intger division rounding up if from a lower bound
-    if (sf.d_props[index].d_type == 1 && options::cbqiRoundUpLowerLia())
+    if (sf.d_props[index].d_type == CEG_TT_UPPER
+        && options::cbqiRoundUpLowerLia())
     {
       sf.d_subs[index] = nm->mkNode(
           PLUS,
@@ -763,16 +766,15 @@ bool ArithInstantiator::postProcessInstantiationForVariable(
   return true;
 }
 
-int ArithInstantiator::solve_arith(CegInstantiator* ci,
-                                   Node pv,
-                                   Node atom,
-                                   Node& veq_c,
-                                   Node& val,
-                                   Node& vts_coeff_inf,
-                                   Node& vts_coeff_delta)
+CegTermType ArithInstantiator::solve_arith(CegInstantiator* ci,
+                                           Node pv,
+                                           Node atom,
+                                           Node& veq_c,
+                                           Node& val,
+                                           Node& vts_coeff_inf,
+                                           Node& vts_coeff_delta)
 {
   NodeManager* nm = NodeManager::currentNM();
-  int ires = 0;
   Trace("cegqi-arith-debug")
       << "isolate for " << pv << " in " << atom << std::endl;
   std::map<Node, Node> msum;
@@ -780,7 +782,7 @@ int ArithInstantiator::solve_arith(CegInstantiator* ci,
   {
     Trace("cegqi-arith-debug")
         << "fail : could not get monomial sum" << std::endl;
-    return 0;
+    return CEG_TT_INVALID;
   }
   Trace("cegqi-arith-debug") << "got monomial sum: " << std::endl;
   if (Trace.isOn("cegqi-arith-debug"))
@@ -834,11 +836,11 @@ int ArithInstantiator::solve_arith(CegInstantiator* ci,
     }
   }
 
-  ires = ArithMSum::isolate(pv, msum, veq_c, val, atom.getKind());
+  int ires = ArithMSum::isolate(pv, msum, veq_c, val, atom.getKind());
   if (ires == 0)
   {
     Trace("cegqi-arith-debug") << "fail : isolate" << std::endl;
-    return 0;
+    return CEG_TT_INVALID;
   }
   if (Trace.isOn("cegqi-arith-debug"))
   {
@@ -854,7 +856,7 @@ int ArithInstantiator::solve_arith(CegInstantiator* ci,
   if (expr::hasSubterm(val, pv))
   {
     Trace("cegqi-arith-debug") << "fail : contains bad term" << std::endl;
-    return 0;
+    return CEG_TT_INVALID;
   }
   // if its type is integer but the substitution is not integer
   if (pvtn.isInteger()
@@ -938,6 +940,10 @@ int ArithInstantiator::solve_arith(CegInstantiator* ci,
       Trace("cegqi-arith-debug") << "result : " << val << std::endl;
       Assert(val.getType().isInteger());
     }
+    else
+    {
+      return CEG_TT_INVALID;
+    }
   }
   vts_coeff_inf = vts_coeff[0];
   vts_coeff_delta = vts_coeff[1];
@@ -945,7 +951,12 @@ int ArithInstantiator::solve_arith(CegInstantiator* ci,
       << "Return " << veq_c << " * " << pv << " " << atom.getKind() << " "
       << val << ", vts = (" << vts_coeff_inf << ", " << vts_coeff_delta << ")"
       << std::endl;
-  return ires;
+  Assert(ires != 0);
+  if (atom.getKind() == EQUAL)
+  {
+    return CEG_TT_EQUAL;
+  }
+  return ires == 1 ? CEG_TT_UPPER : CEG_TT_LOWER;
 }
 
 Node ArithInstantiator::getModelBasedProjectionValue(CegInstantiator* ci,
index ee3e3e27d24d241effa11775f48eedd351787de9..8ae5383a52d8d07d5dc232472d24e554106daffd 100644 (file)
@@ -150,14 +150,20 @@ class ArithInstantiator : public Instantiator
    *    veq_c * pv <> val + vts_coeff_delta * delta + vts_coeff_inf * inf
    * where we ensure val has Int type if pv has Int type, and val does not
    * contain vts symbols.
+   *
+   * It returns a CegTermType:
+   *   CEG_TT_INVALID if it was not possible to put atom into a solved form,
+   *   CEG_TT_LOWER if <> in the above equation is >=,
+   *   CEG_TT_UPPER if <> in the above equation is <=, or
+   *   CEG_TT_EQUAL if <> in the above equation is =.
    */
-  int solve_arith(CegInstantiator* ci,
-                  Node v,
-                  Node atom,
-                  Node& veq_c,
-                  Node& val,
-                  Node& vts_coeff_inf,
-                  Node& vts_coeff_delta);
+  CegTermType solve_arith(CegInstantiator* ci,
+                          Node v,
+                          Node atom,
+                          Node& veq_c,
+                          Node& val,
+                          Node& vts_coeff_inf,
+                          Node& vts_coeff_delta);
   /** get model based projection value
    *
    * Given a implied (non-strict) bound:
index 2aa2a927bd7f2d1a140e4f2b98cf1483a91f1408..15d426345a9ff2a502a550cc2a7bef4dd9c2ea03 100644 (file)
@@ -57,7 +57,7 @@ bool EprInstantiator::processEqualTerm(CegInstantiator* ci,
     d_equal_terms.push_back(n);
     return false;
   }
-  pv_prop.d_type = 0;
+  pv_prop.d_type = CEG_TT_EQUAL;
   return ci->constructInstantiationInc(pv, n, pv_prop, sf);
 }
 
@@ -93,7 +93,7 @@ bool EprInstantiator::processEqualTerms(CegInstantiator* ci,
   // sort by match score
   std::sort(d_equal_terms.begin(), d_equal_terms.end(), setm);
   TermProperties pv_prop;
-  pv_prop.d_type = 0;
+  pv_prop.d_type = CEG_TT_EQUAL;
   for (unsigned i = 0, size = d_equal_terms.size(); i < size; i++)
   {
     if (ci->constructInstantiationInc(pv, d_equal_terms[i], pv_prop, sf))
index e2a6432dbd6e1d9d2ad7d33ec33744655d320c12..67985527e82d2bde580a33bd921047baffe19570 100644 (file)
@@ -41,6 +41,53 @@ namespace CVC4 {
 namespace theory {
 namespace quantifiers {
 
+CegTermType mkStrictCTT(CegTermType c)
+{
+  Assert(!isStrictCTT(c));
+  if (c == CEG_TT_LOWER)
+  {
+    return CEG_TT_LOWER_STRICT;
+  }
+  else if (c == CEG_TT_UPPER)
+  {
+    return CEG_TT_UPPER_STRICT;
+  }
+  return c;
+}
+
+CegTermType mkNegateCTT(CegTermType c)
+{
+  if (c == CEG_TT_LOWER)
+  {
+    return CEG_TT_UPPER;
+  }
+  else if (c == CEG_TT_UPPER)
+  {
+    return CEG_TT_LOWER;
+  }
+  else if (c == CEG_TT_LOWER_STRICT)
+  {
+    return CEG_TT_UPPER_STRICT;
+  }
+  else if (c == CEG_TT_UPPER_STRICT)
+  {
+    return CEG_TT_LOWER_STRICT;
+  }
+  return c;
+}
+bool isStrictCTT(CegTermType c)
+{
+  return c == CEG_TT_LOWER_STRICT && c == CEG_TT_UPPER_STRICT;
+}
+bool isLowerBoundCTT(CegTermType c)
+{
+  return c == CEG_TT_LOWER || c == CEG_TT_LOWER_STRICT;
+}
+bool isUpperBoundCTT(CegTermType c)
+{
+  return c == CEG_TT_UPPER || c == CEG_TT_UPPER_STRICT;
+}
+
 std::ostream& operator<<(std::ostream& os, CegInstEffort e)
 {
   switch (e)
@@ -1739,7 +1786,7 @@ bool Instantiator::processEqualTerm(CegInstantiator* ci,
                                     Node n,
                                     CegInstEffort effort)
 {
-  pv_prop.d_type = 0;
+  pv_prop.d_type = CEG_TT_EQUAL;
   return ci->constructInstantiationInc(pv, n, pv_prop, sf);
 }
 
index 8110dcd956a177474fde5ecbe9012e26f5774483..76e0869fa794f63b1d3c2b89ba313b1ff5872323 100644 (file)
@@ -34,6 +34,35 @@ class Instantiator;
 class InstantiatorPreprocess;
 class InstStrategyCegqi;
 
+/**
+ * Descriptions of the types of constraints that a term was solved for in.
+ */
+enum CegTermType
+{
+  // invalid
+  CEG_TT_INVALID,
+  // term was the result of solving an equality
+  CEG_TT_EQUAL,
+  // term was the result of solving a non-strict lower bound x >= t
+  CEG_TT_LOWER,
+  // term was the result of solving a strict lower bound x > t
+  CEG_TT_LOWER_STRICT,
+  // term was the result of solving a non-strict upper bound x <= t
+  CEG_TT_UPPER,
+  // term was the result of solving a strict upper bound x < t
+  CEG_TT_UPPER_STRICT,
+};
+/** make (non-strict term type) c a strict term type */
+CegTermType mkStrictCTT(CegTermType c);
+/** negate c (lower/upper bounds are swapped) */
+CegTermType mkNegateCTT(CegTermType c);
+/** is c a strict term type? */
+bool isStrictCTT(CegTermType c);
+/** is c a lower bound? */
+bool isLowerBoundCTT(CegTermType c);
+/** is c an upper bound? */
+bool isUpperBoundCTT(CegTermType c);
+
 /** Term Properties
  *
  * Stores properties for a variable to solve for in counterexample-guided
@@ -43,13 +72,15 @@ class InstStrategyCegqi;
  * for the variable.
  */
 class TermProperties {
-public:
-  TermProperties() : d_type(0) {}
+ public:
+  TermProperties() : d_type(CEG_TT_EQUAL) {}
   virtual ~TermProperties() {}
 
-  // type of property for a term
-  //  for arithmetic this corresponds to bound type (0:equal, 1:upper bound, -1:lower bound)
-  int d_type;
+  /**
+   * Type for the solution term. For arithmetic this corresponds to bound type
+   * of the constraint that the constraint the term was solved for in.
+   */
+  CegTermType d_type;
   // for arithmetic
   Node d_coeff;
   // get cache node