add redirection of __Cat__ to allow overrides for more advanced behaviour
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 28 Sep 2021 17:02:53 +0000 (18:02 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Tue, 28 Sep 2021 17:02:53 +0000 (18:02 +0100)
without changing fundamental language characteristics or semantics in nmigen

https://bugs.libre-soc.org/show_bug.cgi?id=458

this one is slightly more involved than __Repl__, __Mux__, __Switch__ etc.
because Cat() can receive arbitrary objects including generators.
exactly as is done in what is now _InternalCat, the arguments need to
be flattened and individually Value.cast()ed, however *before* doing
so, the first argument needs to be inspected and treated separately,
because the first argument is the one on which __Cat__ shall be called.

    args[0].__Cat__(*args[1:])

therefore it must not be lowered because its type (a derivative of
UserValue) would be entirely lost through the Value.cast().

nmigen/back/rtlil.py
nmigen/hdl/ast.py
nmigen/hdl/xfrm.py
tests/test_hdl_ast.py

index 21986392366e547d62d99b426ae7e7b8d360715c..e0835fb429d646a96b526f2b835eb4de9a24765b 100644 (file)
@@ -591,7 +591,7 @@ class _RHSValueCompiler(_ValueCompiler):
             raise TypeError # :nocov:
 
     def _prepare_value_for_Slice(self, value):
-        if isinstance(value, (ast.Signal, ast.Slice, ast.Cat)):
+        if isinstance(value, (ast.Signal, ast.Slice, ast._InternalCat)):
             sigspec = self(value)
         else:
             sigspec = self.s.rtlil.wire(len(value), src=_src(value.src_loc))
@@ -657,7 +657,7 @@ class _LHSValueCompiler(_ValueCompiler):
         return wire_next or wire_curr
 
     def _prepare_value_for_Slice(self, value):
-        assert isinstance(value, (ast.Signal, ast.Slice, ast.Cat))
+        assert isinstance(value, (ast.Signal, ast.Slice, ast._InternalCat))
         return self(value)
 
     def on_Part(self, value):
index e645dc6d4ec246eef09f52a8a7d8de6129abfb68..2f0625fa62d7d21937db46ff6d7cf4fc9febd9e0 100644 (file)
@@ -15,7 +15,7 @@ __all__ = [
     "Shape", "signed", "unsigned",
     "Value", "Const", "C", "AnyConst", "AnySeq", "Operator", "Mux", "Part", "Slice", "Cat", "Repl",
     "Array", "ArrayProxy",
-    "_InternalSwitch", "_InternalAssign", "_InternalRepl",
+    "_InternalSwitch", "_InternalAssign", "_InternalRepl", "_InternalCat",
     "Signal", "ClockSignal", "ResetSignal",
     "UserValue", "ValueCastable",
     "Sample", "Past", "Stable", "Rose", "Fell", "Initial",
@@ -154,6 +154,10 @@ class Value(metaclass=ABCMeta):
     def __Repl__(self, count, *, src_loc_at=0):
         return _InternalRepl(self, count, src_loc_at=src_loc_at)
 
+    def __Cat__(self, *args, src_loc_at=0):
+        args = [self] + list(args)
+        return _InternalCat(*args, src_loc_at=src_loc_at)
+
     def __Mux__(self, val1, val0):
         return _InternalMux(self, val1, val0)
 
@@ -829,7 +833,33 @@ class Part(Value):
 
 
 @final
-class Cat(Value):
+def Cat(*args, src_loc_at=0):
+    """Concatenate values.
+
+    Behaviour is required to be identical to _InternalCat.
+    The first argument "defines" the way that all others are
+    handled.  If the first argument is derived from UserValue,
+    it is not downcast to a type Value because doing so would
+    lose the opportunity for "redirection" (Value.__Cat__ would
+    always be called).
+    """
+    # flatten the args and convert to tuple (not a generator)
+    args = tuple(flatten(args))
+    # check if there are no arguments (zero-length Signal).
+    if len(args) == 0:
+        return _InternalCat(*args, src_loc_at=src_loc_at)
+    # determine if the first is a UserValue.
+    if isinstance(args[0], UserValue):
+        first = args[0] # take UserValue directly, do not downcast
+    else:
+        first = Value.cast(args[0]) # ok to downcast to Value
+    # all other arguments are safe to downcast to Value
+    rest = [Value.cast(v) for v in flatten(args[1:])]
+    # assume first item defines the "handling" for all others
+    return first.__Cat__(*rest, src_loc_at=src_loc_at)
+
+
+class _InternalCat(Value):
     """Concatenate values.
 
     Form a compound ``Value`` from several smaller ones by concatenation.
@@ -1679,7 +1709,7 @@ class ValueKey:
         elif isinstance(self.value, Part):
             self._hash = hash((ValueKey(self.value.value), ValueKey(self.value.offset),
                               self.value.width, self.value.stride))
-        elif isinstance(self.value, Cat):
+        elif isinstance(self.value, _InternalCat):
             self._hash = hash(tuple(ValueKey(o) for o in self.value.parts))
         elif isinstance(self.value, ArrayProxy):
             self._hash = hash((ValueKey(self.value.index),
@@ -1721,7 +1751,7 @@ class ValueKey:
                     ValueKey(self.value.offset) == ValueKey(other.value.offset) and
                     self.value.width == other.value.width and
                     self.value.stride == other.value.stride)
-        elif isinstance(self.value, Cat):
+        elif isinstance(self.value, _InternalCat):
             return all(ValueKey(a) == ValueKey(b)
                         for a, b in zip(self.value.parts, other.value.parts))
         elif isinstance(self.value, ArrayProxy):
index d55afb8fcb6ab25134a5af2c24b6a94039e87e2b..a0736173eb2ace7c80c88d0441c5c3e507f0ed79 100644 (file)
@@ -104,7 +104,7 @@ class ValueVisitor(metaclass=ABCMeta):
             new_value = self.on_Slice(value)
         elif type(value) is Part:
             new_value = self.on_Part(value)
-        elif type(value) is Cat:
+        elif type(value) is _InternalCat:
             new_value = self.on_Cat(value)
         elif type(value) is _InternalRepl:
             new_value = self.on_Repl(value)
index b3492313ce29883664b72b35a755d40ec305bce5..126eac4fc9e655a7750f689e6a80904dbd8ecba7 100644 (file)
@@ -200,7 +200,7 @@ class ValueTestCase(FHDLTestCase):
         self.assertEqual(s2.start, 1)
         self.assertEqual(s2.stop, 2)
         s3 = Const(31)[::2]
-        self.assertIsInstance(s3, Cat)
+        self.assertIsInstance(s3, _InternalCat)
         self.assertIsInstance(s3.parts[0], Slice)
         self.assertEqual(s3.parts[0].start, 0)
         self.assertEqual(s3.parts[0].stop, 1)