convert PartitionedAssign and PAssign over to PartType
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 9 Oct 2021 16:30:13 +0000 (17:30 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 9 Oct 2021 16:30:13 +0000 (17:30 +0100)
https://bugs.libre-soc.org/show_bug.cgi?id=713#c56

src/ieee754/part/partsig.py
src/ieee754/part_ass/assign.py
src/ieee754/part_ass/passign.py

index 78658826514dbeae344c95824879d23d529b8d15..2701fedda9cac2395a044c8fdbe8fb549f36540c 100644 (file)
@@ -130,7 +130,7 @@ class PartitionedSignal(UserValue):
 
     def __Assign__(self, val, *, src_loc_at=0):
         # print ("partsig ass", self, val)
-        return PAssign(self.m, self, val, self.partpoints)
+        return PAssign(self.m, self, val, self.ptype)
 
     # TODO, http://bugs.libre-riscv.org/show_bug.cgi?id=458
     #def __Switch__(self, cases, *, src_loc=None, src_loc_at=0,
index 00cebc6dab33a25f93d4e790dd7883fc859bae9e..abb599627ae2c6247b527cae7fa8e108235f27ee 100644 (file)
@@ -43,15 +43,14 @@ def get_runlengths(pbit, size):
 
 
 class PartitionedAssign(Elaboratable):
-    def __init__(self, shape, assign, mask):
+    def __init__(self, shape, assign, ctx):
         """Create a ``PartitionedAssign`` operator
         """
         # work out the length (total of all PartitionedSignals)
         self.assign = assign
-        if isinstance(mask, dict):
-            mask = list(mask.values())
-        self.mask = mask
+        self.ptype = ctx
         self.shape = shape
+        mask = ctx.get_mask()
         self.output = PartitionedSignal(mask, self.shape, reset_less=True)
         self.partition_points = self.output.partpoints
         self.mwidth = len(self.partition_points)+1
@@ -80,14 +79,14 @@ class PartitionedAssign(Elaboratable):
 
         keys = list(self.partition_points.keys())
         print ("keys", keys, "values", self.partition_points.values())
-        print ("mask", self.mask)
+        print ("ptype", self.ptype)
         outpartsize = len(self.output) // self.mwidth
         width, signed = self.output.shape()
         print ("width, signed", width, signed)
 
-        with m.Switch(Cat(self.mask)):
+        with m.Switch(self.ptype.get_switch()):
             # for each partition possibility, create a Assign sequence
-            for pbit in range(1<<len(keys)):
+            for pbit in self.ptype.get_cases():
                 # set up some indices pointing to where things have got
                 # then when called below in the inner nested loop they give
                 # the relevant sequential chunk
@@ -120,7 +119,7 @@ if __name__ == "__main__":
     m = Module()
     mask = Signal(3)
     a = PartitionedSignal(mask, 32)
-    m.submodules.ass = ass = PartitionedAssign(signed(48), a, mask)
+    m.submodules.ass = ass = PartitionedAssign(signed(48), a, a.ptype)
     omask = (1<<len(ass.output))-1
 
     traces = ass.ports()
@@ -154,7 +153,20 @@ if __name__ == "__main__":
     m = Module()
     mask = Signal(3)
     a = Signal(32)
-    m.submodules.ass = ass = PartitionedAssign(signed(48), a, mask)
+    class PartType:
+        def __init__(self, mask):
+            self.mask = mask
+        def get_mask(self):
+            return mask
+        def get_switch(self):
+            return Cat(self.get_mask())
+        def get_cases(self):
+            return range(1<<len(self.get_mask()))
+        @property
+        def blanklanes(self):
+            return 0
+    ptype = PartType(mask)
+    m.submodules.ass = ass = PartitionedAssign(signed(48), a, ptype)
     omask = (1<<len(ass.output))-1
 
     traces = ass.ports()
index 8e497643f473af5765f4f89fd13fed05f3093dde..56f22d399641689a1c8101c9ec278624ea346978 100644 (file)
@@ -17,11 +17,11 @@ See:
 
 
 modcount = 0 # global for now
-def PAssign(m, val, assign, mask):
+def PAssign(m, val, assign, ctx):
     from ieee754.part_ass.assign import PartitionedAssign # recursion issue
     global modcount
     modcount += 1
-    pc = PartitionedAssign(val.shape(), assign, mask)
+    pc = PartitionedAssign(val.shape(), assign, ctx)
     setattr(m.submodules, "pass%d" % modcount, pc)
     return val.lower().eq(pc.output.lower())