add redirection of __Part__ to allow overrides for advanced behaviour
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Wed, 29 Sep 2021 15:53:32 +0000 (16:53 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 2 Oct 2021 14:58:16 +0000 (15:58 +0100)
without changing fundamental language characteristics or semantics of nmigen
https://bugs.libre-soc.org/show_bug.cgi?id=458

nmigen/hdl/ast.py
nmigen/hdl/xfrm.py
tests/test_hdl_ast.py

index 9d0cbc2eb1d8ced4ff2b8177b1cef317ca733d62..1df331ede92edb1c4815f62260b7d7371ecfdd23 100644 (file)
@@ -16,6 +16,7 @@ __all__ = [
     "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
     "Array", "ArrayProxy",
     "_InternalSwitch", "_InternalAssign", "_InternalRepl", "_InternalCat",
+    "_InternalPart",
     "Signal", "ClockSignal", "ResetSignal",
     "UserValue", "ValueCastable",
     "Sample", "Past", "Stable", "Rose", "Fell", "Initial",
@@ -156,6 +157,10 @@ class Value(metaclass=ABCMeta):
     # *MUST NOT* use "Type (2) - dsl.Module" nmigen language constructs
     # (m.If, m.Else, m.Switch, m.FSM): it creates complications in dsl.Module.
 
+    def __Part__(self, offset, width, stride, *, src_loc_at=0):
+        return _InternalPart(self, offset, width, stride, src_loc_at=src_loc_at)
+    def __Repl__(self, count, *, src_loc_at=0):
+        return _InternalRepl(self, count, src_loc_at=src_loc_at)
     def __Repl__(self, count, *, src_loc_at=0):
         return _InternalRepl(self, count, src_loc_at=src_loc_at)
     def __Cat__(self, *args, src_loc_at=0):
@@ -803,7 +808,11 @@ class Slice(Value):
 
 
 @final
-class Part(Value):
+def Part(value, offset, width, stride=1, *, src_loc_at=0):
+    return value.__Part__(offset, width, stride, src_loc_at=src_loc_at)
+
+
+class _InternalPart(Value):
     def __init__(self, value, offset, width, stride=1, *, src_loc_at=0):
         if not isinstance(width, int) or width < 0:
             raise TypeError("Part width must be a non-negative integer, not {!r}".format(width))
@@ -1703,7 +1712,7 @@ class ValueKey:
                                tuple(ValueKey(o) for o in self.value.operands)))
         elif isinstance(self.value, Slice):
             self._hash = hash((ValueKey(self.value.value), self.value.start, self.value.stop))
-        elif isinstance(self.value, Part):
+        elif isinstance(self.value, _InternalPart):
             self._hash = hash((ValueKey(self.value.value), ValueKey(self.value.offset),
                               self.value.width, self.value.stride))
         elif isinstance(self.value, _InternalCat):
@@ -1743,7 +1752,7 @@ class ValueKey:
             return (ValueKey(self.value.value) == ValueKey(other.value.value) and
                     self.value.start == other.value.start and
                     self.value.stop == other.value.stop)
-        elif isinstance(self.value, Part):
+        elif isinstance(self.value, _InternalPart):
             return (ValueKey(self.value.value) == ValueKey(other.value.value) and
                     ValueKey(self.value.offset) == ValueKey(other.value.offset) and
                     self.value.width == other.value.width and
index a0736173eb2ace7c80c88d0441c5c3e507f0ed79..cd75531eb4291e21a6146850f9169bb005509a86 100644 (file)
@@ -102,7 +102,7 @@ class ValueVisitor(metaclass=ABCMeta):
             new_value = self.on_Operator(value)
         elif type(value) is Slice:
             new_value = self.on_Slice(value)
-        elif type(value) is Part:
+        elif type(value) is _InternalPart:
             new_value = self.on_Part(value)
         elif type(value) is _InternalCat:
             new_value = self.on_Cat(value)
index 7b1239812ae9f302f17e565f56f539ae890732d5..28a12a5e6e0485bd025a26527cd95317a6515ff3 100644 (file)
@@ -663,16 +663,16 @@ class BitSelectTestCase(FHDLTestCase):
 
     def test_shape(self):
         s1 = self.c.bit_select(self.s, 2)
-        self.assertIsInstance(s1, Part)
+        self.assertIsInstance(s1, _InternalPart)
         self.assertEqual(s1.shape(), unsigned(2))
         self.assertIsInstance(s1.shape(), Shape)
         s2 = self.c.bit_select(self.s, 0)
-        self.assertIsInstance(s2, Part)
+        self.assertIsInstance(s2, _InternalPart)
         self.assertEqual(s2.shape(), unsigned(0))
 
     def test_stride(self):
         s1 = self.c.bit_select(self.s, 2)
-        self.assertIsInstance(s1, Part)
+        self.assertIsInstance(s1, _InternalPart)
         self.assertEqual(s1.stride, 1)
 
     def test_const(self):
@@ -696,13 +696,13 @@ class WordSelectTestCase(FHDLTestCase):
 
     def test_shape(self):
         s1 = self.c.word_select(self.s, 2)
-        self.assertIsInstance(s1, Part)
+        self.assertIsInstance(s1, _InternalPart)
         self.assertEqual(s1.shape(), unsigned(2))
         self.assertIsInstance(s1.shape(), Shape)
 
     def test_stride(self):
         s1 = self.c.word_select(self.s, 2)
-        self.assertIsInstance(s1, Part)
+        self.assertIsInstance(s1, _InternalPart)
         self.assertEqual(s1.stride, 2)
 
     def test_const(self):