CBQI BV: Refactor solve_bv_constraint. (#1265)
authorAina Niemetz <aina.niemetz@gmail.com>
Thu, 19 Oct 2017 19:31:42 +0000 (12:31 -0700)
committerGitHub <noreply@github.com>
Thu, 19 Oct 2017 19:31:42 +0000 (12:31 -0700)
This refactors function solve_bv_constraint to use a switch-case over kinds rather than an if-else chain.

src/theory/quantifiers/bv_inverter.cpp

index 8a65338a61329cd331c3bc26dd21b2dd875e8f3c..ad1259be025f40332ae94c81b7f2d9dbac81fc9c 100644 (file)
@@ -261,19 +261,25 @@ Node BvInverter::getPathToPv(Node lit, Node pv, Node sv, Node pvs,
   return slit;
 }
 
-Node BvInverter::solve_bv_constraint(Node sv, Node sv_t, Node t, Kind rk,
-                                     bool pol, std::vector<unsigned>& path,
+Node BvInverter::solve_bv_constraint(Node sv,
+                                     Node sv_t,
+                                     Node t,
+                                     Kind rk,
+                                     bool pol,
+                                     std::vector<unsigned>& path,
                                      BvInverterModelQuery* m,
                                      BvInverterStatus& status) {
+  unsigned index;
+  unsigned nchildren;
   NodeManager* nm = NodeManager::currentNM();
+
   while (!path.empty()) {
-    unsigned index = path.back();
+    index = path.back();
     Assert(index < sv_t.getNumChildren());
     path.pop_back();
     Kind k = sv_t.getKind();
-    unsigned nchildren = sv_t.getNumChildren();
+    nchildren = sv_t.getNumChildren();
 
-    /* inversions  */
     if (k == BITVECTOR_CONCAT) {
       /* x = t[upper:lower]
        * where
@@ -302,219 +308,239 @@ Node BvInverter::solve_bv_constraint(Node sv, Node sv_t, Node t, Kind rk,
       Node s = nchildren == 2 ? sv_t[1 - index] : dropChild(sv_t, index);
       /* Note: All n-ary kinds except for CONCAT (i.e., AND, OR, MULT, PLUS)
        *       are commutative (no case split based on index). */
-      if (k == BITVECTOR_PLUS) {
-        t = nm->mkNode(BITVECTOR_SUB, t, s);
-      } else if (k == BITVECTOR_SUB) {
-        t = nm->mkNode(BITVECTOR_PLUS, t, s);
-      } else if (k == BITVECTOR_MULT) {
-        /* t = skv (fresh skolem constant)
-         * with side condition:
-         * ctz(t) >= ctz(s) <-> x * s = t
-         * where
-         * ctz(t) >= ctz(s) -> (t & -t) >= (s & -s)  */
-        TypeNode solve_tn = sv_t[index].getType();
-        Node x = getSolveVariable(solve_tn);
-        /* left hand side of side condition  */
-        Node scl = nm->mkNode(
-            BITVECTOR_UGE,
-            nm->mkNode(BITVECTOR_AND, t, nm->mkNode(BITVECTOR_NEG, t)),
-            nm->mkNode(BITVECTOR_AND, s, nm->mkNode(BITVECTOR_NEG, s)));
-        /* right hand side of side condition  */
-        Node scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_MULT, x, s), t);
-        /* overall side condition  */
-        Node sc = nm->mkNode(IMPLIES, scl, scr);
-        /* add side condition  */
-        status.d_conds.push_back(sc);
-
-        /* get the skolem node for this side condition  */
-        Node skv = getInversionNode(sc, solve_tn);
-        /* now solving with the skolem node as the RHS  */
-        t = skv;
-      } else if (k == BITVECTOR_UREM_TOTAL) {
-        /* t = skv (fresh skolem constant)  */
-        TypeNode solve_tn = sv_t[index].getType();
-        Node x = getSolveVariable(solve_tn);
-        Node scl, scr;
-        if (index == 0) {
-          /* x % s = t is rewritten to x - x / y * y */
-          Trace("bv-invert") << "bv-invert : Unsupported for index " << index
-                             << ", from " << sv_t << std::endl;
-          return Node::null();
-        } else {
-          /* s % x = t
-           * with side conditions:
-           * s > t
-           * && s-t > t
-           * && (t = 0 || t != s-1)  */
-          Node s_gt_t = nm->mkNode(BITVECTOR_UGT, s, t);
-          Node s_m_t = nm->mkNode(BITVECTOR_SUB, s, t);
-          Node smt_gt_t = nm->mkNode(BITVECTOR_UGT, s_m_t, t);
-          Node t_eq_z = nm->mkNode(EQUAL,
-              t, bv::utils::mkZero(bv::utils::getSize(t)));
-          Node s_m_o = nm->mkNode(BITVECTOR_SUB,
-              s, bv::utils::mkOne(bv::utils::getSize(s)));
-          Node t_d_smo = nm->mkNode(DISTINCT, t, s_m_o);
-
-          scl = nm->mkNode(AND,
-              nm->mkNode(AND, s_gt_t, smt_gt_t),
-              nm->mkNode(OR, t_eq_z, t_d_smo));
-          scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UREM_TOTAL, s, x), t);
+      switch(k) {
+        case BITVECTOR_PLUS:
+          t = nm->mkNode(BITVECTOR_SUB, t, s);
+          break;
+        case BITVECTOR_SUB:
+          t = nm->mkNode(BITVECTOR_PLUS, t, s);
+          break;
+
+        case BITVECTOR_MULT: {
+          /* t = skv (fresh skolem constant)
+           * with side condition:
+           * ctz(t) >= ctz(s) <-> x * s = t
+           * where
+           * ctz(t) >= ctz(s) -> (t & -t) >= (s & -s)  */
+          TypeNode solve_tn = sv_t[index].getType();
+          Node x = getSolveVariable(solve_tn);
+          /* left hand side of side condition  */
+          Node scl = nm->mkNode(
+              BITVECTOR_UGE,
+              nm->mkNode(BITVECTOR_AND, t, nm->mkNode(BITVECTOR_NEG, t)),
+              nm->mkNode(BITVECTOR_AND, s, nm->mkNode(BITVECTOR_NEG, s)));
+          /* right hand side of side condition  */
+          Node scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_MULT, x, s), t);
+          /* overall side condition  */
+          Node sc = nm->mkNode(IMPLIES, scl, scr);
+          /* add side condition  */
+          status.d_conds.push_back(sc);
+
+          /* get the skolem node for this side condition  */
+          Node skv = getInversionNode(sc, solve_tn);
+          /* now solving with the skolem node as the RHS  */
+          t = skv;
+          break;
         }
-        Node sc = nm->mkNode(IMPLIES, scl, scr);
-        status.d_conds.push_back(sc);
-        Node skv = getInversionNode(sc, solve_tn);
-        t = skv;
-      } else if (k == BITVECTOR_AND || k == BITVECTOR_OR) {
-        /* t = skv (fresh skolem constant)
-         * with side condition:
-         * t & s = t
-         * t | s = t */
-        TypeNode solve_tn = sv_t[index].getType();
-        Node x = getSolveVariable(solve_tn);
-        Node scl = nm->mkNode(EQUAL, t, nm->mkNode(k, t, s));
-        Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
-        Node sc = nm->mkNode(IMPLIES, scl, scr);
-        status.d_conds.push_back(sc);
-        Node skv = getInversionNode(sc, solve_tn);
-        t = skv;
-      } else if (k == BITVECTOR_LSHR) {
-        /* t = skv (fresh skolem constant)  */
-        TypeNode solve_tn = sv_t[index].getType();
-        Node x = getSolveVariable(solve_tn);
-        Node scl, scr;
-        if (index == 0) {
-          /* x >> s = t
+
+        case BITVECTOR_UREM_TOTAL: {
+          /* t = skv (fresh skolem constant)  */
+          TypeNode solve_tn = sv_t[index].getType();
+          Node x = getSolveVariable(solve_tn);
+          Node scl, scr;
+          if (index == 0) {
+            /* x % s = t is rewritten to x - x / y * y */
+            Trace("bv-invert") << "bv-invert : Unsupported for index " << index
+                               << ", from " << sv_t << std::endl;
+            return Node::null();
+          } else {
+            /* s % x = t
+             * with side conditions:
+             * s > t
+             * && s-t > t
+             * && (t = 0 || t != s-1)  */
+            Node s_gt_t = nm->mkNode(BITVECTOR_UGT, s, t);
+            Node s_m_t = nm->mkNode(BITVECTOR_SUB, s, t);
+            Node smt_gt_t = nm->mkNode(BITVECTOR_UGT, s_m_t, t);
+            Node t_eq_z = nm->mkNode(EQUAL,
+                t, bv::utils::mkZero(bv::utils::getSize(t)));
+            Node s_m_o = nm->mkNode(BITVECTOR_SUB,
+                s, bv::utils::mkOne(bv::utils::getSize(s)));
+            Node t_d_smo = nm->mkNode(DISTINCT, t, s_m_o);
+
+            scl = nm->mkNode(AND,
+                nm->mkNode(AND, s_gt_t, smt_gt_t),
+                nm->mkNode(OR, t_eq_z, t_d_smo));
+            scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UREM_TOTAL, s, x), t);
+          }
+          Node sc = nm->mkNode(IMPLIES, scl, scr);
+          status.d_conds.push_back(sc);
+          Node skv = getInversionNode(sc, solve_tn);
+          t = skv;
+          break;
+        }
+
+        case BITVECTOR_AND:
+        case BITVECTOR_OR: {
+          /* t = skv (fresh skolem constant)
            * with side condition:
-           * s = 0 || clz(t) >= s
-           * ->
-           * s = 0 || ((z o t) << s)[2w-1 : w] = z
-           * with w = getSize(t) = getSize(s) and z = 0 with getSize(z) = w  */
-          unsigned w = bv::utils::getSize(s);
-          Node z = bv::utils::mkZero(w);
-          Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t);
-          Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s);
-          Node zot_shl_zos = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
-          Node ext = bv::utils::mkExtract(zot_shl_zos, 2*w-1, w);
-          scl = nm->mkNode(OR,
-              nm->mkNode(EQUAL, s, z),
-              nm->mkNode(EQUAL, ext, z));
-          scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_LSHR, x, s), t);
+           * t & s = t
+           * t | s = t */
+          TypeNode solve_tn = sv_t[index].getType();
+          Node x = getSolveVariable(solve_tn);
+          Node scl = nm->mkNode(EQUAL, t, nm->mkNode(k, t, s));
+          Node scr = nm->mkNode(EQUAL, nm->mkNode(k, x, s), t);
           Node sc = nm->mkNode(IMPLIES, scl, scr);
           status.d_conds.push_back(sc);
           Node skv = getInversionNode(sc, solve_tn);
           t = skv;
-        } else {
-          // TODO: index == 1
-          /* s >> x = t
-           * with side conditions:
-           * (s = 0 && t = 0)
-           * || (clz(t) >= clz(s)
-           *     && (t = 0
-           *         || "remaining shifted bits in t "
-           *            "match corresponding bits in s"))  */
-          Trace("bv-invert") << "bv-invert : Unsupported for index " << index
-                             << ", from " << sv_t << std::endl;
-          return Node::null();
+          break;
         }
-      } else if (k == BITVECTOR_UDIV_TOTAL) {
-        TypeNode solve_tn = sv_t[index].getType();
-        Node x = getSolveVariable(solve_tn);
-        Node s = sv_t[1 - index];
-        unsigned w = bv::utils::getSize(s);
-        Node scl, scr;
-        Node zero = bv::utils::mkConst(w, 0u);
-
-        /* x udiv s = t */
-        if (index == 0) {
-          /* with side conditions:
-           * !umulo(s * t)
-           */
-          scl = nm->mkNode(NOT, bv::utils::mkUmulo(s, t));
-          scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, x, s), t);
-        /* s udiv x = t */
-        } else {
-          /* with side conditions:
-           * (t = 0 && (s = 0 || s != 2^w-1))
-           * || s >= t
-           * || t = 2^w-1
-           */
-          Node ones = bv::utils::mkOnes(w);
-          Node t_eq_zero = nm->mkNode(EQUAL, t, zero);
-          Node s_eq_zero = nm->mkNode(EQUAL, s, zero);
-          Node s_ne_ones = nm->mkNode(DISTINCT, s, ones);
-          Node s_ge_t = nm->mkNode(BITVECTOR_UGE, s, t);
-          Node t_eq_ones = nm->mkNode(EQUAL, t, ones);
-          scl = nm->mkNode(
-              OR,
-              nm->mkNode(AND, t_eq_zero, nm->mkNode(OR, s_eq_zero, s_ne_ones)),
-              s_ge_t, t_eq_ones);
-          scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, s, x), t);
+
+        case BITVECTOR_LSHR: {
+          /* t = skv (fresh skolem constant)  */
+          TypeNode solve_tn = sv_t[index].getType();
+          Node x = getSolveVariable(solve_tn);
+          Node scl, scr;
+          if (index == 0) {
+            /* x >> s = t
+             * with side condition:
+             * s = 0 || clz(t) >= s
+             * ->
+             * s = 0 || ((z o t) << s)[2w-1 : w] = z
+             * with w = getSize(t) = getSize(s)
+             * and z = 0 with getSize(z) = w  */
+            unsigned w = bv::utils::getSize(s);
+            Node z = bv::utils::mkZero(w);
+            Node z_o_t = nm->mkNode(BITVECTOR_CONCAT, z, t);
+            Node z_o_s = nm->mkNode(BITVECTOR_CONCAT, z, s);
+            Node zot_shl_zos = nm->mkNode(BITVECTOR_SHL, z_o_t, z_o_s);
+            Node ext = bv::utils::mkExtract(zot_shl_zos, 2*w-1, w);
+            scl = nm->mkNode(OR,
+                nm->mkNode(EQUAL, s, z),
+                nm->mkNode(EQUAL, ext, z));
+            scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_LSHR, x, s), t);
+            Node sc = nm->mkNode(IMPLIES, scl, scr);
+            status.d_conds.push_back(sc);
+            Node skv = getInversionNode(sc, solve_tn);
+            t = skv;
+          } else {
+            /* s >> x = t
+             * with side conditions:
+             * (s = 0 && t = 0)
+             * || (clz(t) >= clz(s)
+             *     && (t = 0
+             *         || "remaining shifted bits in t "
+             *            "match corresponding bits in s"))  */
+            Trace("bv-invert") << "bv-invert : Unsupported for index " << index
+                               << ", from " << sv_t << std::endl;
+            return Node::null();
+          }
+          break;
         }
 
-        /* overall side condition */
-        Node sc = nm->mkNode(IMPLIES, scl, scr);
-        /* add side condition */
-        status.d_conds.push_back(sc);
-
-        /* get the skolem node for this side condition*/
-        Node skv = getInversionNode(sc, solve_tn);
-        /* now solving with the skolem node as the RHS */
-        t = skv;
-      } else if (k == BITVECTOR_SHL) {
-        TypeNode solve_tn = sv_t[index].getType();
-        Node x = getSolveVariable(solve_tn);
-        Node s = sv_t[1 - index];
-        unsigned w = bv::utils::getSize(s);
-        Node scl, scr;
-
-        /* x << s = t */
-        if (index == 0) {
-          /* with side conditions:
-           * (s = 0 || ctz(t) >= s)
-           * <->
-           * (s = 0 || ((t o z) >> (z o s))[w-1:0] = z)
-           *
-           * where
-           * w = getSize(s) = getSize(t) = getSize (z) && z = 0
-           */
+        case BITVECTOR_UDIV_TOTAL: {
+          TypeNode solve_tn = sv_t[index].getType();
+          Node x = getSolveVariable(solve_tn);
+          Node s = sv_t[1 - index];
+          unsigned w = bv::utils::getSize(s);
+          Node scl, scr;
           Node zero = bv::utils::mkConst(w, 0u);
-          Node s_eq_zero = nm->mkNode(EQUAL, s, zero);
-          Node t_conc_zero = nm->mkNode(BITVECTOR_CONCAT, t, zero);
-          Node zero_conc_s = nm->mkNode(BITVECTOR_CONCAT, zero, s);
-          Node shr_s = nm->mkNode(BITVECTOR_LSHR, t_conc_zero, zero_conc_s);
-          Node extr_shr_s = bv::utils::mkExtract(shr_s, w - 1, 0);
-          Node ctz_t_ge_s = nm->mkNode(EQUAL, extr_shr_s, zero);
-          scl = nm->mkNode(OR, s_eq_zero, ctz_t_ge_s);
-          scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_SHL, x, s), t);
-          /* s << x = t */
-        } else {
-          /* with side conditions:
-           * (s = 0 && t = 0)
-           * || (ctz(t) >= ctz(s)
-           *     && (t = 0 ||
-           *         "remaining shifted bits in t match corresponding bits in s"))
-           */
-          Trace("bv-invert") << "bv-invert : Unsupported for index " << index
-                             << ", from " << sv_t << std::endl;
-          return Node::null();
+
+          if (index == 0) {
+            /* x udiv s = t
+             * with side conditions:
+             * !umulo(s * t)
+             */
+            scl = nm->mkNode(NOT, bv::utils::mkUmulo(s, t));
+            scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, x, s), t);
+          } else {
+            /* s udiv x = t
+             * with side conditions:
+             * (t = 0 && (s = 0 || s != 2^w-1))
+             * || s >= t
+             * || t = 2^w-1
+             */
+            Node ones = bv::utils::mkOnes(w);
+            Node t_eq_zero = nm->mkNode(EQUAL, t, zero);
+            Node s_eq_zero = nm->mkNode(EQUAL, s, zero);
+            Node s_ne_ones = nm->mkNode(DISTINCT, s, ones);
+            Node s_ge_t = nm->mkNode(BITVECTOR_UGE, s, t);
+            Node t_eq_ones = nm->mkNode(EQUAL, t, ones);
+            scl = nm->mkNode(OR,
+                             nm->mkNode(AND, t_eq_zero,
+                                        nm->mkNode(OR, s_eq_zero, s_ne_ones)),
+                             s_ge_t, t_eq_ones);
+            scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_UDIV_TOTAL, s, x), t);
+          }
+
+          /* overall side condition */
+          Node sc = nm->mkNode(IMPLIES, scl, scr);
+          /* add side condition */
+          status.d_conds.push_back(sc);
+
+          /* get the skolem node for this side condition*/
+          Node skv = getInversionNode(sc, solve_tn);
+          /* now solving with the skolem node as the RHS */
+          t = skv;
+          break;
         }
 
-        /* overall side condition */
-        Node sc = nm->mkNode(IMPLIES, scl, scr);
-        /* add side condition */
-        status.d_conds.push_back(sc);
-
-        /* get the skolem node for this side condition*/
-        Node skv = getInversionNode(sc, solve_tn);
-        /* now solving with the skolem node as the RHS */
-        t = skv;
-      //}else if( k==BITVECTOR_ASHR ){
-      // TODO
-      } else {
-        Trace("bv-invert") << "bv-invert : Unknown kind for bit-vector term "
-                           << k
-                           << ", from " << sv_t << std::endl;
-        return Node::null();
+        case BITVECTOR_SHL: {
+          TypeNode solve_tn = sv_t[index].getType();
+          Node x = getSolveVariable(solve_tn);
+          Node s = sv_t[1 - index];
+          unsigned w = bv::utils::getSize(s);
+          Node scl, scr;
+
+          if (index == 0) {
+            /* x << s = t
+             * with side conditions:
+             * (s = 0 || ctz(t) >= s)
+             * <->
+             * (s = 0 || ((t o z) >> (z o s))[w-1:0] = z)
+             *
+             * where
+             * w = getSize(s) = getSize(t) = getSize (z) && z = 0
+             */
+            Node zero = bv::utils::mkConst(w, 0u);
+            Node s_eq_zero = nm->mkNode(EQUAL, s, zero);
+            Node t_conc_zero = nm->mkNode(BITVECTOR_CONCAT, t, zero);
+            Node zero_conc_s = nm->mkNode(BITVECTOR_CONCAT, zero, s);
+            Node shr_s = nm->mkNode(BITVECTOR_LSHR, t_conc_zero, zero_conc_s);
+            Node extr_shr_s = bv::utils::mkExtract(shr_s, w - 1, 0);
+            Node ctz_t_ge_s = nm->mkNode(EQUAL, extr_shr_s, zero);
+            scl = nm->mkNode(OR, s_eq_zero, ctz_t_ge_s);
+            scr = nm->mkNode(EQUAL, nm->mkNode(BITVECTOR_SHL, x, s), t);
+          } else {
+            /* s << x = t
+             * with side conditions:
+             * (s = 0 && t = 0)
+             * || (ctz(t) >= ctz(s)
+             *     && (t = 0 ||
+             *         "remaining shifted bits in t"
+             *         "match corresponding bits in s"))
+             */
+            Trace("bv-invert") << "bv-invert : Unsupported for index " << index
+                               << "for bit-vector term " << sv_t << std::endl;
+            return Node::null();
+          }
+
+          /* overall side condition */
+          Node sc = nm->mkNode(IMPLIES, scl, scr);
+          /* add side condition */
+          status.d_conds.push_back(sc);
+
+          /* get the skolem node for this side condition*/
+          Node skv = getInversionNode(sc, solve_tn);
+          /* now solving with the skolem node as the RHS */
+          t = skv;
+          break;
+        }
+        default:
+          Trace("bv-invert") << "bv-invert : Unknown kind " << k
+                             << " for bit-vector term " << sv_t << std::endl;
+          return Node::null();
       }
     }
     sv_t = sv_t[index];