fhdl.ast: refactor Operator.shape(). NFC.
authorwhitequark <whitequark@whitequark.org>
Sat, 15 Dec 2018 09:46:20 +0000 (09:46 +0000)
committerwhitequark <whitequark@whitequark.org>
Sat, 15 Dec 2018 09:46:20 +0000 (09:46 +0000)
nmigen/fhdl/ast.py
nmigen/test/test_fhdl_ast.py

index 01a43badadf1d3d88ace3b9bf2a46da512aa0892..d8789687761e29f7c01f6286cff6e090e2581849 100644 (file)
@@ -259,61 +259,71 @@ class Operator(Value):
         self.operands = [Value.wrap(o) for o in operands]
 
     @staticmethod
-    def _bitwise_binary_shape(a, b):
-        if not a[1] and not b[1]:
+    def _bitwise_binary_shape(a_shape, b_shape):
+        a_bits, a_sign = a_shape
+        b_bits, b_sign = b_shape
+        if not a_sign and not b_sign:
             # both operands unsigned
-            return max(a[0], b[0]), False
-        elif a[1] and b[1]:
+            return max(a_bits, b_bits), False
+        elif a_sign and b_sign:
             # both operands signed
-            return max(a[0], b[0]), True
-        elif not a[1] and b[1]:
+            return max(a_bits, b_bits), True
+        elif not a_sign and b_sign:
             # first operand unsigned (add sign bit), second operand signed
-            return max(a[0] + 1, b[0]), True
+            return max(a_bits + 1, b_bits), True
         else:
             # first signed, second operand unsigned (add sign bit)
-            return max(a[0], b[0] + 1), True
+            return max(a_bits, b_bits + 1), True
 
     def shape(self):
-        obs = list(map(lambda x: x.shape(), self.operands))
-        if self.op == "+" or self.op == "-":
-            if len(obs) == 1:
-                if self.op == "-" and not obs[0][1]:
-                    return obs[0][0] + 1, True
+        op_shapes = list(map(lambda x: x.shape(), self.operands))
+        if len(op_shapes) == 1:
+            (a_bits, a_sign), = op_shapes
+            if self.op in ("+", "~"):
+                return a_bits, a_sign
+            if self.op == "-":
+                if not a_sign:
+                    return a_bits + 1, True
                 else:
-                    return obs[0]
-            n, s = self._bitwise_binary_shape(*obs)
-            return n + 1, s
-        elif self.op == "*":
-            if not obs[0][1] and not obs[1][1]:
-                # both operands unsigned
-                return obs[0][0] + obs[1][0], False
-            elif obs[0][1] and obs[1][1]:
-                # both operands signed
-                return obs[0][0] + obs[1][0] - 1, True
-            else:
+                    return a_bits, a_sign
+            if self.op == "b":
+                return 1, False
+        elif len(op_shapes) == 2:
+            (a_bits, a_sign), (b_bits, b_sign) = op_shapes
+            if self.op == "+" or self.op == "-":
+                bits, sign = self._bitwise_binary_shape(*op_shapes)
+                return bits + 1, sign
+            if self.op == "*":
+                if not a_sign and not b_sign:
+                    # both operands unsigned
+                    return a_bits + b_bits, False
+                if a_sign and b_sign:
+                    # both operands signed
+                    return a_bits + b_bits - 1, True
                 # one operand signed, the other unsigned (add sign bit)
-                return obs[0][0] + obs[1][0] + 1 - 1, True
-        elif self.op == "<<<":
-            if obs[1][1]:
-                extra = 2**(obs[1][0] - 1) - 1
-            else:
-                extra = 2**obs[1][0] - 1
-            return obs[0][0] + extra, obs[0][1]
-        elif self.op == ">>>":
-            if obs[1][1]:
-                extra = 2**(obs[1][0] - 1)
-            else:
-                extra = 0
-            return obs[0][0] + extra, obs[0][1]
-        elif self.op in ("&", "^", "|"):
-            return self._bitwise_binary_shape(*obs)
-        elif self.op in ("<", "<=", "==", "!=", ">", ">=", "b"):
-            return 1, False
-        elif self.op == "~":
-            return obs[0]
-        elif self.op == "m":
-            return self._bitwise_binary_shape(obs[1], obs[2])
-        raise NotImplementedError("Operator '{}' not implemented".format(self.op)) # :nocov:
+                return a_bits + b_bits + 1 - 1, True
+            if self.op in ("<", "<=", "==", "!=", ">", ">=", "b"):
+                return 1, False
+            if self.op in ("&", "^", "|"):
+                return self._bitwise_binary_shape(*op_shapes)
+            if self.op == "<<<":
+                if b_sign:
+                    extra = 2**(b_bits - 1) - 1
+                else:
+                    extra = 2**b_bits - 1
+                return a_bits + extra, a_sign
+            if self.op == ">>>":
+                if b_sign:
+                    extra = 2**(b_bits - 1)
+                else:
+                    extra = 0
+                return a_bits + extra, a_sign
+        elif len(op_shapes) == 3:
+            if self.op == "m":
+                s_shape, a_shape, b_shape = op_shapes
+                return self._bitwise_binary_shape(a_shape, b_shape)
+        raise NotImplementedError("Operator {}/{} not implemented"
+                                  .format(self.op, len(op_shapes))) # :nocov:
 
     def _rhs_signals(self):
         return union(op._rhs_signals() for op in self.operands)
index 66ddec734754fdf0b4068f12a619de1e3e9191a7..433b609181a14f404160be30ffdddd6535c6dded 100644 (file)
@@ -88,6 +88,11 @@ class ConstTestCase(FHDLTestCase):
 
 
 class OperatorTestCase(FHDLTestCase):
+    def test_bool(self):
+        v = Const(0, 4).bool()
+        self.assertEqual(repr(v), "(b (const 4'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
     def test_invert(self):
         v = ~Const(0, 4)
         self.assertEqual(repr(v), "(~ (const 4'd0))")