From a931e51840a5f4a2cfd6f8ebdf944b13a2bdc8e6 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 15 Dec 2018 09:46:20 +0000 Subject: [PATCH] fhdl.ast: refactor Operator.shape(). NFC. --- nmigen/fhdl/ast.py | 102 +++++++++++++++++++---------------- nmigen/test/test_fhdl_ast.py | 5 ++ 2 files changed, 61 insertions(+), 46 deletions(-) diff --git a/nmigen/fhdl/ast.py b/nmigen/fhdl/ast.py index 01a43ba..d878968 100644 --- a/nmigen/fhdl/ast.py +++ b/nmigen/fhdl/ast.py @@ -259,61 +259,71 @@ class Operator(Value): self.operands = [Value.wrap(o) for o in operands] @staticmethod - def _bitwise_binary_shape(a, b): - if not a[1] and not b[1]: + def _bitwise_binary_shape(a_shape, b_shape): + a_bits, a_sign = a_shape + b_bits, b_sign = b_shape + if not a_sign and not b_sign: # both operands unsigned - return max(a[0], b[0]), False - elif a[1] and b[1]: + return max(a_bits, b_bits), False + elif a_sign and b_sign: # both operands signed - return max(a[0], b[0]), True - elif not a[1] and b[1]: + return max(a_bits, b_bits), True + elif not a_sign and b_sign: # first operand unsigned (add sign bit), second operand signed - return max(a[0] + 1, b[0]), True + return max(a_bits + 1, b_bits), True else: # first signed, second operand unsigned (add sign bit) - return max(a[0], b[0] + 1), True + return max(a_bits, b_bits + 1), True def shape(self): - obs = list(map(lambda x: x.shape(), self.operands)) - if self.op == "+" or self.op == "-": - if len(obs) == 1: - if self.op == "-" and not obs[0][1]: - return obs[0][0] + 1, True + op_shapes = list(map(lambda x: x.shape(), self.operands)) + if len(op_shapes) == 1: + (a_bits, a_sign), = op_shapes + if self.op in ("+", "~"): + return a_bits, a_sign + if self.op == "-": + if not a_sign: + return a_bits + 1, True else: - return obs[0] - n, s = self._bitwise_binary_shape(*obs) - return n + 1, s - elif self.op == "*": - if not obs[0][1] and not obs[1][1]: - # both operands unsigned - return obs[0][0] + obs[1][0], False - elif obs[0][1] and obs[1][1]: - # both operands signed - return obs[0][0] + obs[1][0] - 1, True - else: + return a_bits, a_sign + if self.op == "b": + return 1, False + elif len(op_shapes) == 2: + (a_bits, a_sign), (b_bits, b_sign) = op_shapes + if self.op == "+" or self.op == "-": + bits, sign = self._bitwise_binary_shape(*op_shapes) + return bits + 1, sign + if self.op == "*": + if not a_sign and not b_sign: + # both operands unsigned + return a_bits + b_bits, False + if a_sign and b_sign: + # both operands signed + return a_bits + b_bits - 1, True # one operand signed, the other unsigned (add sign bit) - return obs[0][0] + obs[1][0] + 1 - 1, True - elif self.op == "<<<": - if obs[1][1]: - extra = 2**(obs[1][0] - 1) - 1 - else: - extra = 2**obs[1][0] - 1 - return obs[0][0] + extra, obs[0][1] - elif self.op == ">>>": - if obs[1][1]: - extra = 2**(obs[1][0] - 1) - else: - extra = 0 - return obs[0][0] + extra, obs[0][1] - elif self.op in ("&", "^", "|"): - return self._bitwise_binary_shape(*obs) - elif self.op in ("<", "<=", "==", "!=", ">", ">=", "b"): - return 1, False - elif self.op == "~": - return obs[0] - elif self.op == "m": - return self._bitwise_binary_shape(obs[1], obs[2]) - raise NotImplementedError("Operator '{}' not implemented".format(self.op)) # :nocov: + return a_bits + b_bits + 1 - 1, True + if self.op in ("<", "<=", "==", "!=", ">", ">=", "b"): + return 1, False + if self.op in ("&", "^", "|"): + return self._bitwise_binary_shape(*op_shapes) + if self.op == "<<<": + if b_sign: + extra = 2**(b_bits - 1) - 1 + else: + extra = 2**b_bits - 1 + return a_bits + extra, a_sign + if self.op == ">>>": + if b_sign: + extra = 2**(b_bits - 1) + else: + extra = 0 + return a_bits + extra, a_sign + elif len(op_shapes) == 3: + if self.op == "m": + s_shape, a_shape, b_shape = op_shapes + return self._bitwise_binary_shape(a_shape, b_shape) + raise NotImplementedError("Operator {}/{} not implemented" + .format(self.op, len(op_shapes))) # :nocov: def _rhs_signals(self): return union(op._rhs_signals() for op in self.operands) diff --git a/nmigen/test/test_fhdl_ast.py b/nmigen/test/test_fhdl_ast.py index 66ddec7..433b609 100644 --- a/nmigen/test/test_fhdl_ast.py +++ b/nmigen/test/test_fhdl_ast.py @@ -88,6 +88,11 @@ class ConstTestCase(FHDLTestCase): class OperatorTestCase(FHDLTestCase): + def test_bool(self): + v = Const(0, 4).bool() + self.assertEqual(repr(v), "(b (const 4'd0))") + self.assertEqual(v.shape(), (1, False)) + def test_invert(self): v = ~Const(0, 4) self.assertEqual(repr(v), "(~ (const 4'd0))") -- 2.30.2