From d1ffbe06f4b3cd0351d182bbbd986fdb29d31e9c Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 2 Oct 2021 14:18:02 +0000 Subject: [PATCH] hdl.ast: simplify Mux implementation. --- nmigen/back/rtlil.py | 2 ++ nmigen/hdl/ast.py | 3 --- tests/test_hdl_ast.py | 2 +- tests/test_sim.py | 6 ++++++ 4 files changed, 9 insertions(+), 4 deletions(-) diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index 2198639..6da4d9c 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -562,6 +562,8 @@ class _RHSValueCompiler(_ValueCompiler): def on_Operator_mux(self, value): sel, val1, val0 = value.operands + if len(sel) != 1: + sel = sel.bool() val1_bits, val1_sign = val1.shape() val0_bits, val0_sign = val0.shape() res_bits, res_sign = value.shape() diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 685924a..5ed3a77 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -735,9 +735,6 @@ def Mux(sel, val1, val0): Value, out Output ``Value``. If ``sel`` is asserted, the Mux returns ``val1``, else ``val0``. """ - sel = Value.cast(sel) - if len(sel) != 1: - sel = sel.bool() return Operator("m", [sel, val1, val0]) diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 9f0fec6..3604433 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -542,7 +542,7 @@ class OperatorTestCase(FHDLTestCase): def test_mux_wide(self): s = Const(0b100) v = Mux(s, Const(0, unsigned(4)), Const(0, unsigned(6))) - self.assertEqual(repr(v), "(m (b (const 3'd4)) (const 4'd0) (const 6'd0))") + self.assertEqual(repr(v), "(m (const 3'd4) (const 4'd0) (const 6'd0))") def test_mux_bool(self): v = Mux(True, Const(0), Const(0)) diff --git a/tests/test_sim.py b/tests/test_sim.py index ab31bd6..e4bd5c8 100644 --- a/tests/test_sim.py +++ b/tests/test_sim.py @@ -191,6 +191,12 @@ class SimulatorUnitTestCase(FHDLTestCase): self.assertStatement(stmt, [C(2, 4), C(3, 4), C(0)], C(2, 4)) self.assertStatement(stmt, [C(2, 4), C(3, 4), C(1)], C(3, 4)) + def test_mux_wide(self): + stmt = lambda y, a, b, c: y.eq(Mux(c, a, b)) + self.assertStatement(stmt, [C(2, 4), C(3, 4), C(0, 2)], C(3, 4)) + self.assertStatement(stmt, [C(2, 4), C(3, 4), C(1, 2)], C(2, 4)) + self.assertStatement(stmt, [C(2, 4), C(3, 4), C(2, 2)], C(2, 4)) + def test_abs(self): stmt = lambda y, a: y.eq(abs(a)) self.assertStatement(stmt, [C(3, unsigned(8))], C(3, unsigned(8))) -- 2.30.2