covert PartitionedCat (and PCat) over to PartType format
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 9 Oct 2021 16:26:32 +0000 (17:26 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 9 Oct 2021 16:26:32 +0000 (17:26 +0100)
https://bugs.libre-soc.org/show_bug.cgi?id=713#c56

src/ieee754/part/partsig.py
src/ieee754/part_cat/cat.py
src/ieee754/part_cat/pcat.py

index d02dac4a1691b8320eb826ec7539d4648edd0902..78658826514dbeae344c95824879d23d529b8d15 100644 (file)
@@ -119,7 +119,7 @@ class PartitionedSignal(UserValue):
             assert isinstance(sig, PartitionedSignal), \
                 "All PartitionedSignal.__Cat__ arguments must be " \
                 "a PartitionedSignal. %s is not." % repr(sig)
-        return PCat(self.m, args, self.partpoints)
+        return PCat(self.m, args, self.ptype)
 
     def __Mux__(self, val1, val2):
         # print ("partsig mux", self, val1, val2)
index 27abb19abc644b467bb060e5234adb43b68da18d..8fcf8bcf842dadd3db4f18db05a67d9122076c39 100644 (file)
@@ -59,18 +59,17 @@ def get_runlengths(pbit, size):
 
 
 class PartitionedCat(Elaboratable):
-    def __init__(self, catlist, mask):
+    def __init__(self, catlist, ctx):
         """Create a ``PartitionedCat`` operator
         """
         # work out the length (total of all PartitionedSignals)
         self.catlist = catlist
-        if isinstance(mask, dict):
-            mask = list(mask.values())
-        self.mask = mask
+        self.ptype = ctx
         width = 0
         for p in catlist:
             width += len(p.sig)
         self.width = width
+        mask = ctx.get_mask()
         self.output = PartitionedSignal(mask, self.width, reset_less=True)
         self.partition_points = self.output.partpoints
         self.mwidth = len(self.partition_points)+1
@@ -94,10 +93,10 @@ class PartitionedCat(Elaboratable):
 
         keys = list(self.partition_points.keys())
         print ("keys", keys, "values", self.partition_points.values())
-        print ("mask", self.mask)
-        with m.Switch(Cat(self.mask)):
+        print ("ptype", self.ptype)
+        with m.Switch(self.ptype.get_switch()):
             # for each partition possibility, create a Cat sequence
-            for pbit in range(1<<len(keys)):
+            for pbit in self.ptype.get_cases():
                 # set up some indices pointing to where things have got
                 # then when called below in the inner nested loop they give
                 # the relevant sequential chunk
@@ -129,7 +128,7 @@ if __name__ == "__main__":
     a = PartitionedSignal(mask, 32)
     b = PartitionedSignal(mask, 16)
     catlist = [a, b]
-    m.submodules.cat = cat = PartitionedCat(catlist, mask)
+    m.submodules.cat = cat = PartitionedCat(catlist, a.ptype)
 
     traces = cat.ports()
     sim = create_simulator(m, traces, "partcat")
index f15bc869dabd464c5688ca93eb971eaea32082c7..6b4630decd061bb9ffd6acf046ac613e9238dac3 100644 (file)
@@ -4,10 +4,10 @@
 
 
 modcount = 0 # global for now
-def PCat(m, arglist, mask):
+def PCat(m, arglist, ctx):
     from ieee754.part_cat.cat import PartitionedCat # avoid recursive import
     global modcount
     modcount += 1
-    pc = PartitionedCat(arglist, mask)
+    pc = PartitionedCat(arglist, ctx)
     setattr(m.submodules, "pcat%d" % modcount, pc)
     return pc.output