From ccfd17f885a34062bbe2bd155a4a71690f5827c6 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Sun, 10 Oct 2021 12:33:49 +0100 Subject: [PATCH] add redirection of __Slice__ to allow overrides for advanced behaviour without changing fundamental language characteristics or semantics of nmigen https://bugs.libre-soc.org/show_bug.cgi?id=458 --- nmigen/back/rtlil.py | 4 ++-- nmigen/hdl/ast.py | 19 ++++++++++++++----- nmigen/hdl/xfrm.py | 2 +- tests/test_hdl_ast.py | 18 +++++++++--------- 4 files changed, 26 insertions(+), 17 deletions(-) diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index ad58ad3..7d44672 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -593,7 +593,7 @@ class _RHSValueCompiler(_ValueCompiler): raise TypeError # :nocov: def _prepare_value_for_Slice(self, value): - if isinstance(value, (ast.Signal, ast.Slice, ast._InternalCat)): + if isinstance(value, (ast.Signal, ast._InternalSlice, ast._InternalCat)): sigspec = self(value) else: sigspec = self.s.rtlil.wire(len(value), src=_src(value.src_loc)) @@ -659,7 +659,7 @@ class _LHSValueCompiler(_ValueCompiler): return wire_next or wire_curr def _prepare_value_for_Slice(self, value): - assert isinstance(value, (ast.Signal, ast.Slice, ast._InternalCat)) + assert isinstance(value, (ast.Signal, ast._InternalSlice, ast._InternalCat)) return self(value) def on_Part(self, value): diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 3721482..03d2be6 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -16,7 +16,7 @@ __all__ = [ "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl", "Array", "ArrayProxy", "_InternalSwitch", "_InternalAssign", "_InternalRepl", "_InternalCat", - "_InternalPart", + "_InternalPart", "_InternalSlice", "Signal", "ClockSignal", "ResetSignal", "UserValue", "ValueCastable", "Sample", "Past", "Stable", "Rose", "Fell", "Initial", @@ -158,6 +158,8 @@ class Value(metaclass=ABCMeta): # (m.If, m.Else, m.Switch, m.FSM): it creates complications # (recursive dependencies) in dsl.Module. + def __Slice__(self, start, stop, *, src_loc_at=0): + return _InternalSlice(self, start, stop, src_loc_at=src_loc_at) def __Part__(self, offset, width, stride=1, *, src_loc_at=0): return _InternalPart(self, offset, width, stride, src_loc_at=src_loc_at) @@ -773,7 +775,14 @@ def _InternalMux(sel, val1, val0): @final -class Slice(Value): +def Slice(value, start, stop, *, src_loc_at=0): + # this relies on Value.cast returning the original object unmodified if + # it is already derived from Value (UserValue) + value = Value.cast(value) + return value.__Slice__(start, stop, src_loc_at=src_loc_at) + + +class _InternalSlice(Value): def __init__(self, value, start, stop, *, src_loc_at=0): if not isinstance(start, int): raise TypeError("Slice start must be an integer, not {!r}".format(start)) @@ -1715,7 +1724,7 @@ class ValueKey: elif isinstance(self.value, Operator): self._hash = hash((self.value.operator, tuple(ValueKey(o) for o in self.value.operands))) - elif isinstance(self.value, Slice): + elif isinstance(self.value, _InternalSlice): self._hash = hash((ValueKey(self.value.value), self.value.start, self.value.stop)) elif isinstance(self.value, _InternalPart): self._hash = hash((ValueKey(self.value.value), ValueKey(self.value.offset), @@ -1753,7 +1762,7 @@ class ValueKey: len(self.value.operands) == len(other.value.operands) and all(ValueKey(a) == ValueKey(b) for a, b in zip(self.value.operands, other.value.operands))) - elif isinstance(self.value, Slice): + elif isinstance(self.value, _InternalSlice): return (ValueKey(self.value.value) == ValueKey(other.value.value) and self.value.start == other.value.start and self.value.stop == other.value.stop) @@ -1791,7 +1800,7 @@ class ValueKey: return self.value < other.value elif isinstance(self.value, (Signal, AnyValue)): return self.value.duid < other.value.duid - elif isinstance(self.value, Slice): + elif isinstance(self.value, _InternalSlice): return (ValueKey(self.value.value) < ValueKey(other.value.value) and self.value.start < other.value.start and self.value.end < other.value.end) diff --git a/nmigen/hdl/xfrm.py b/nmigen/hdl/xfrm.py index cd75531..c39e523 100644 --- a/nmigen/hdl/xfrm.py +++ b/nmigen/hdl/xfrm.py @@ -100,7 +100,7 @@ class ValueVisitor(metaclass=ABCMeta): new_value = self.on_ResetSignal(value) elif type(value) is Operator: new_value = self.on_Operator(value) - elif type(value) is Slice: + elif type(value) is _InternalSlice: new_value = self.on_Slice(value) elif type(value) is _InternalPart: new_value = self.on_Part(value) diff --git a/tests/test_hdl_ast.py b/tests/test_hdl_ast.py index 53985c8..588a1d0 100644 --- a/tests/test_hdl_ast.py +++ b/tests/test_hdl_ast.py @@ -179,11 +179,11 @@ class ValueTestCase(FHDLTestCase): def test_getitem_int(self): s1 = Const(10)[0] - self.assertIsInstance(s1, Slice) + self.assertIsInstance(s1, _InternalSlice) self.assertEqual(s1.start, 0) self.assertEqual(s1.stop, 1) s2 = Const(10)[-1] - self.assertIsInstance(s2, Slice) + self.assertIsInstance(s2, _InternalSlice) self.assertEqual(s2.start, 3) self.assertEqual(s2.stop, 4) with self.assertRaisesRegex(IndexError, @@ -192,22 +192,22 @@ class ValueTestCase(FHDLTestCase): def test_getitem_slice(self): s1 = Const(10)[1:3] - self.assertIsInstance(s1, Slice) + self.assertIsInstance(s1, _InternalSlice) self.assertEqual(s1.start, 1) self.assertEqual(s1.stop, 3) s2 = Const(10)[1:-2] - self.assertIsInstance(s2, Slice) + self.assertIsInstance(s2, _InternalSlice) self.assertEqual(s2.start, 1) self.assertEqual(s2.stop, 2) s3 = Const(31)[::2] self.assertIsInstance(s3, _InternalCat) - self.assertIsInstance(s3.parts[0], Slice) + self.assertIsInstance(s3.parts[0], _InternalSlice) self.assertEqual(s3.parts[0].start, 0) self.assertEqual(s3.parts[0].stop, 1) - self.assertIsInstance(s3.parts[1], Slice) + self.assertIsInstance(s3.parts[1], _InternalSlice) self.assertEqual(s3.parts[1].start, 2) self.assertEqual(s3.parts[1].stop, 3) - self.assertIsInstance(s3.parts[2], Slice) + self.assertIsInstance(s3.parts[2], _InternalSlice) self.assertEqual(s3.parts[2].start, 4) self.assertEqual(s3.parts[2].stop, 5) @@ -677,7 +677,7 @@ class BitSelectTestCase(FHDLTestCase): def test_const(self): s1 = self.c.bit_select(1, 2) - self.assertIsInstance(s1, Slice) + self.assertIsInstance(s1, _InternalSlice) self.assertRepr(s1, """(slice (const 8'd0) 1:3)""") def test_width_wrong(self): @@ -707,7 +707,7 @@ class WordSelectTestCase(FHDLTestCase): def test_const(self): s1 = self.c.word_select(1, 2) - self.assertIsInstance(s1, Slice) + self.assertIsInstance(s1, _InternalSlice) self.assertRepr(s1, """(slice (const 8'd0) 2:4)""") def test_width_wrong(self): -- 2.30.2