hdl.rec: make Record inherit from UserValue. working_23jun2020
authoranuejn <jarohabiger@googlemail.com>
Thu, 16 Apr 2020 16:46:55 +0000 (18:46 +0200)
committerGitHub <noreply@github.com>
Thu, 16 Apr 2020 16:46:55 +0000 (16:46 +0000)
Closes #354.

nmigen/back/pysim.py
nmigen/back/rtlil.py
nmigen/hdl/ast.py
nmigen/hdl/rec.py
nmigen/hdl/xfrm.py
nmigen/test/test_hdl_ast.py
nmigen/test/test_hdl_xfrm.py

index 1ea631c56b880c24dc33aebf4e3839a12e0698cc..3077991e24cb40b078ea15d3bc87e5604c8a26a4 100644 (file)
@@ -374,9 +374,6 @@ class _ValueCompiler(ValueVisitor, _Compiler):
     def on_ResetSignal(self, value):
         raise NotImplementedError # :nocov:
 
-    def on_Record(self, value):
-        return self(Cat(value.fields.values()))
-
     def on_AnyConst(self, value):
         raise NotImplementedError # :nocov:
 
index 120a354ee21332f30a706ebf979136d06eddfce4..3f9936e7882400d37c3bb8ecff43e4d384d12f17 100644 (file)
@@ -365,9 +365,6 @@ class _ValueCompiler(xfrm.ValueVisitor):
     def on_Initial(self, value):
         raise NotImplementedError # :nocov:
 
-    def on_Record(self, value):
-        return self(ast.Cat(value.fields.values()))
-
     def on_Cat(self, value):
         return "{{ {} }}".format(" ".join(reversed([self(o) for o in value.parts])))
 
@@ -378,7 +375,11 @@ class _ValueCompiler(xfrm.ValueVisitor):
         if value.start == 0 and value.stop == len(value.value):
             return self(value.value)
 
-        sigspec = self._prepare_value_for_Slice(value.value)
+        if isinstance(value.value, ast.UserValue):
+            sigspec = self._prepare_value_for_Slice(value.value._lazy_lower())
+        else:
+            sigspec = self._prepare_value_for_Slice(value.value)
+
         if value.start == value.stop:
             return "{}"
         elif value.start + 1 == value.stop:
@@ -644,7 +645,7 @@ class _LHSValueCompiler(_ValueCompiler):
         return wire_next or wire_curr
 
     def _prepare_value_for_Slice(self, value):
-        assert isinstance(value, (ast.Signal, ast.Slice, ast.Cat, rec.Record))
+        assert isinstance(value, (ast.Signal, ast.Slice, ast.Cat))
         return self(value)
 
     def on_Part(self, value):
index 45dab3d06dd8337bff8cf2c1dbef0e0659d97153..d8d9da04120d2de502aab97372401d437e63e499 100644 (file)
@@ -1188,7 +1188,10 @@ class UserValue(Value):
 
     def _lazy_lower(self):
         if self.__lowered is None:
-            self.__lowered = Value.cast(self.lower())
+            lowered = self.lower()
+            if isinstance(lowered, UserValue):
+                lowered = lowered._lazy_lower()
+            self.__lowered = Value.cast(lowered)
         return self.__lowered
 
     def shape(self):
index b60fb00fd74b6633912d91afec9ba3528f45be42..be8934226a1027acbaea076b6ee1c91b5cc583a4 100644 (file)
@@ -85,7 +85,7 @@ class Layout:
 
 
 # Unlike most Values, Record *can* be subclassed.
-class Record(Value):
+class Record(UserValue):
     @staticmethod
     def like(other, *, name=None, name_suffix=None, src_loc_at=0):
         if name is not None:
@@ -113,6 +113,8 @@ class Record(Value):
         return Record(other.layout, name=new_name, fields=fields, src_loc_at=1)
 
     def __init__(self, layout, *, name=None, fields=None, src_loc_at=0):
+        super().__init__(src_loc_at=src_loc_at)
+
         if name is None:
             name = tracer.get_var_name(depth=2 + src_loc_at, default=None)
 
@@ -165,8 +167,8 @@ class Record(Value):
         else:
             return super().__getitem__(item)
 
-    def shape(self):
-        return Shape(sum(len(f) for f in self.fields.values()))
+    def lower(self):
+        return Cat(self.fields.values())
 
     def _lhs_signals(self):
         return union((f._lhs_signals() for f in self.fields.values()), start=SignalSet())
index fb4dad954b704313e56ec73c5eb3da8340cdf4f4..4d83eb72663fdb6639db0627e2d2937b086f2d5b 100644 (file)
@@ -38,10 +38,6 @@ class ValueVisitor(metaclass=ABCMeta):
     def on_Signal(self, value):
         pass # :nocov:
 
-    @abstractmethod
-    def on_Record(self, value):
-        pass # :nocov:
-
     @abstractmethod
     def on_ClockSignal(self, value):
         pass # :nocov:
@@ -98,9 +94,6 @@ class ValueVisitor(metaclass=ABCMeta):
         elif isinstance(value, Signal):
             # Uses `isinstance()` and not `type() is` because nmigen.compat requires it.
             new_value = self.on_Signal(value)
-        elif isinstance(value, Record):
-            # Uses `isinstance()` and not `type() is` to allow inheriting from Record.
-            new_value = self.on_Record(value)
         elif type(value) is ClockSignal:
             new_value = self.on_ClockSignal(value)
         elif type(value) is ResetSignal:
@@ -147,9 +140,6 @@ class ValueTransformer(ValueVisitor):
     def on_Signal(self, value):
         return value
 
-    def on_Record(self, value):
-        return value
-
     def on_ClockSignal(self, value):
         return value
 
@@ -372,8 +362,6 @@ class DomainCollector(ValueVisitor, StatementVisitor):
     def on_ResetSignal(self, value):
         self._add_used_domain(value.domain)
 
-    on_Record = on_ignore
-
     def on_Operator(self, value):
         for o in value.operands:
             self.on_value(o)
index f642f20f6da5cc903f2283c6f88670339ffb4f29..1b7dd58e295db0029d26772ed0aaa11f06ca654b 100644 (file)
@@ -916,6 +916,14 @@ class UserValueTestCase(FHDLTestCase):
         self.assertEqual(uv.shape(), unsigned(1))
         self.assertEqual(uv.lower_count, 1)
 
+    def test_lower_to_user_value(self):
+        uv = MockUserValue(MockUserValue(1))
+        self.assertEqual(uv.shape(), unsigned(1))
+        self.assertIsInstance(uv.shape(), Shape)
+        uv.lowered = MockUserValue(2)
+        self.assertEqual(uv.shape(), unsigned(1))
+        self.assertEqual(uv.lower_count, 1)
+
 
 class SampleTestCase(FHDLTestCase):
     def test_const(self):
index e5f1745f1cd0a0e0117bedc1bb07f5eee716be53..121c87f10cb4b174d6d97094fe06f5ce2632e2c6 100644 (file)
@@ -620,3 +620,12 @@ class UserValueTestCase(FHDLTestCase):
             )
         )
         """)
+
+
+class UserValueRecursiveTestCase(UserValueTestCase):
+    def setUp(self):
+        self.s = Signal()
+        self.c = Signal()
+        self.uv = MockUserValue(MockUserValue(self.s))
+
+    # inherit the test_lower method from UserValueTestCase because the checks are the same