From 6a3072148db6b25b04b4d3a4c05dd2ab8b5db00b Mon Sep 17 00:00:00 2001 From: whitequark Date: Fri, 13 Sep 2019 13:14:52 +0000 Subject: [PATCH] hdl.ast: add Value.{any,all}, mapping to $reduce_{or,and}. Refs #147. --- nmigen/back/pysim.py | 6 ++++++ nmigen/back/rtlil.py | 3 +++ nmigen/hdl/ast.py | 26 +++++++++++++++++++++++--- nmigen/test/test_hdl_ast.py | 8 ++++++++ nmigen/test/test_sim.py | 12 ++++++++++++ 5 files changed, 52 insertions(+), 3 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 82df605..cf4b998 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -129,6 +129,12 @@ class _RHSValueCompiler(_ValueCompiler): return lambda state: normalize(-arg(state), shape) if value.op == "b": return lambda state: normalize(bool(arg(state)), shape) + if value.op == "r|": + return lambda state: normalize(arg(state) != 0, shape) + if value.op == "r&": + val, = value.operands + mask = (1 << len(val)) - 1 + return lambda state: normalize(arg(state) == mask, shape) elif len(value.operands) == 2: lhs, rhs = map(self, value.operands) if value.op == "+": diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index 470bd12..460d1d4 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -375,6 +375,9 @@ class _RHSValueCompiler(_ValueCompiler): (1, "~"): "$not", (1, "-"): "$neg", (1, "b"): "$reduce_bool", + (1, "r|"): "$reduce_or", + (1, "r&"): "$reduce_and", + (1, "r^"): "$reduce_xor", (2, "+"): "$add", (2, "-"): "$sub", (2, "*"): "$mul", diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 56671ed..b91fb50 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -133,10 +133,30 @@ class Value(metaclass=ABCMeta): Returns ------- Value, out - Output ``Value``. If any bits are set, returns ``1``, else ``0``. + ``1`` if any bits are set, ``0`` otherwise. """ return Operator("b", [self]) + def any(self): + """Check if any bits are ``1``. + + Returns + ------- + Value, out + ``1`` if any bits are set, ``0`` otherwise. + """ + return Operator("r|", [self]) + + def all(self): + """Check if all bits are ``1``. + + Returns + ------- + Value, out + ``1`` if all bits are set, ``0`` otherwise. + """ + return Operator("r&", [self]) + def implies(premise, conclusion): """Implication. @@ -361,7 +381,7 @@ class Operator(Value): return a_bits + 1, True else: return a_bits, a_sign - if self.op == "b": + if self.op in ("b", "r|", "r&", "r^"): return 1, False elif len(op_shapes) == 2: (a_bits, a_sign), (b_bits, b_sign) = op_shapes @@ -372,7 +392,7 @@ class Operator(Value): return a_bits + b_bits, a_sign or b_sign if self.op == "%": return a_bits, a_sign - if self.op in ("<", "<=", "==", "!=", ">", ">=", "b"): + if self.op in ("<", "<=", "==", "!=", ">", ">="): return 1, False if self.op in ("&", "^", "|"): return self._bitwise_binary_shape(*op_shapes) diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index 9ad4d32..e0c43ca 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -251,6 +251,14 @@ class OperatorTestCase(FHDLTestCase): self.assertEqual(repr(v), "(b (const 1'd0))") self.assertEqual(v.shape(), (1, False)) + def test_any(self): + v = Const(0b101).any() + self.assertEqual(repr(v), "(r| (const 3'd5))") + + def test_all(self): + v = Const(0b101).all() + self.assertEqual(repr(v), "(r& (const 3'd5))") + def test_hash(self): with self.assertRaises(TypeError): hash(Const(0) + Const(0)) diff --git a/nmigen/test/test_sim.py b/nmigen/test/test_sim.py index 585bbeb..eb283a2 100644 --- a/nmigen/test/test_sim.py +++ b/nmigen/test/test_sim.py @@ -57,6 +57,18 @@ class SimulatorUnitTestCase(FHDLTestCase): self.assertStatement(stmt, [C(1, 4)], C(1)) self.assertStatement(stmt, [C(2, 4)], C(1)) + def test_any(self): + stmt = lambda y, a: y.eq(a.any()) + self.assertStatement(stmt, [C(0b00, 2)], C(0)) + self.assertStatement(stmt, [C(0b01, 2)], C(1)) + self.assertStatement(stmt, [C(0b11, 2)], C(1)) + + def test_all(self): + stmt = lambda y, a: y.eq(a.all()) + self.assertStatement(stmt, [C(0b00, 2)], C(0)) + self.assertStatement(stmt, [C(0b01, 2)], C(0)) + self.assertStatement(stmt, [C(0b11, 2)], C(1)) + def test_add(self): stmt = lambda y, a, b: y.eq(a + b) self.assertStatement(stmt, [C(0, 4), C(1, 4)], C(1, 4)) -- 2.30.2