convert PartitionedRepl over to new "PartType" format
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 9 Oct 2021 16:22:36 +0000 (17:22 +0100)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sat, 9 Oct 2021 16:22:36 +0000 (17:22 +0100)
src/ieee754/part/partsig.py
src/ieee754/part/test/test_partsig.py
src/ieee754/part_repl/prepl.py
src/ieee754/part_repl/repl.py

index 880ea18b310ef227c63198118fafdb712e03cc88..d02dac4a1691b8320eb826ec7539d4648edd0902 100644 (file)
@@ -30,7 +30,7 @@ from ieee754.part_cat.pcat import PCat
 from ieee754.part_repl.prepl import PRepl
 from operator import or_, xor, and_, not_
 
-from nmigen import (Signal, Const)
+from nmigen import (Signal, Const, Cat)
 from nmigen.hdl.ast import UserValue, Shape
 
 
@@ -56,6 +56,21 @@ for name in ['add', 'eq', 'gt', 'ge', 'ls', 'xor', 'bool', 'all']:
     modnames[name] = 0
 
 
+
+class PartType: # TODO decide name
+    def __init__(self, psig):
+        self.psig = psig
+    def get_mask(self):
+        return list(self.psig.partpoints.values())
+    def get_switch(self):
+        return Cat(self.get_mask())
+    def get_cases(self):
+        return range(1<<len(self.get_mask()))
+    @property
+    def blanklanes(self):
+        return 0
+
+
 class PartitionedSignal(UserValue):
     # XXX ################################################### XXX
     # XXX Keep these functions in the same order as ast.Value XXX
@@ -69,7 +84,7 @@ class PartitionedSignal(UserValue):
             self.partpoints = mask
         else:
             self.partpoints = make_partition2(mask, width)
-
+        self.ptype = PartType(self)
 
     def set_module(self, m):
         self.m = m
@@ -96,7 +111,7 @@ class PartitionedSignal(UserValue):
     #def __Part__(self, offset, width, stride=1, *, src_loc_at=0):
 
     def __Repl__(self, count, *, src_loc_at=0):
-        return PRepl(self.m, self, count, self.partpoints)
+        return PRepl(self.m, self, count, self.ptype)
 
     def __Cat__(self, *args, src_loc_at=0):
         args = [self] + list(args)
index a1f44fc6361eb8e53b2bb872724f1ea9bfcbb69c..efcc5e8c541082f032692bfc0152724d6c9c8ec5 100644 (file)
@@ -484,7 +484,6 @@ class TestRepl(unittest.TestCase):
                             list(map(hex, apart)), list(map(hex, bpart)))
 
                     yield module.a.lower().eq(a)
-                    yield module.b.lower().eq(b)
                     yield Delay(0.1e-6)
 
                     y = 0
index b2b649543783263569b5c495e725b74b32ee5e14..589954d29413d8b3e2b102a0ee13b7d5a3952229 100644 (file)
@@ -17,11 +17,11 @@ See:
 
 
 modcount = 0 # global for now
-def PRepl(m, repl, qty, mask):
+def PRepl(m, repl, qty, ctx):
     from ieee754.part_repl.repl import PartitionedRepl # recursion issue
     global modcount
     modcount += 1
-    pc = PartitionedRepl(repl, qty, mask)
+    pc = PartitionedRepl(repl, qty, ctx)
     setattr(m.submodules, "repl%d" % modcount, pc)
     return pc.output
 
index 364b7721c9ec896e9e5d0cbf133404f3631ad059..3f372e6ecdb34edcf8a1f69a957cc88ab466e311 100644 (file)
@@ -43,17 +43,16 @@ def get_runlengths(pbit, size):
 
 
 class PartitionedRepl(Elaboratable):
-    def __init__(self, repl, qty, mask):
+    def __init__(self, repl, qty, ctx):
         """Create a ``PartitionedRepl`` operator
         """
         # work out the length (total of all PartitionedSignals)
         self.repl = repl
         self.qty = qty
         width, signed = repl.shape()
-        if isinstance(mask, dict):
-            mask = list(mask.values())
-        self.mask = mask
+        self.ptype = ctx
         self.shape = (width * qty), signed
+        mask = ctx.get_mask()
         self.output = PartitionedSignal(mask, self.shape, reset_less=True)
         self.partition_points = self.output.partpoints
         self.mwidth = len(self.partition_points)+1
@@ -82,14 +81,14 @@ class PartitionedRepl(Elaboratable):
 
         keys = list(self.partition_points.keys())
         print ("keys", keys, "values", self.partition_points.values())
-        print ("mask", self.mask)
+        print ("ptype", self.ptype)
         outpartsize = len(self.output) // self.mwidth
         width, signed = self.output.shape()
         print ("width, signed", width, signed)
 
-        with m.Switch(Cat(self.mask)):
+        with m.Switch(self.ptype.get_switch()):
             # for each partition possibility, create a Repl sequence
-            for pbit in range(1<<len(keys)):
+            for pbit in self.ptype.get_cases():
                 # set up some indices pointing to where things have got
                 # then when called below in the inner nested loop they give
                 # the relevant sequential chunk
@@ -118,7 +117,8 @@ if __name__ == "__main__":
     m = Module()
     mask = Signal(3)
     a = PartitionedSignal(mask, 32)
-    m.submodules.repl = repl = PartitionedRepl(a, 2, mask)
+    print ("a.ptype", a.ptype)
+    m.submodules.repl = repl = PartitionedRepl(a, 2, a.ptype)
     omask = (1<<len(repl.output))-1
 
     traces = repl.ports()
@@ -154,9 +154,23 @@ if __name__ == "__main__":
 
     # Scalar
     m = Module()
+    class PartType:
+        def __init__(self, mask):
+            self.mask = mask
+        def get_mask(self):
+            return mask
+        def get_switch(self):
+            return Cat(self.get_mask())
+        def get_cases(self):
+            return range(1<<len(self.get_mask()))
+        @property
+        def blanklanes(self):
+            return 0
+
     mask = Signal(3)
+    ptype = PartType(mask)
     a = Signal(32)
-    m.submodules.ass = ass = PartitionedRepl(a, 2, mask)
+    m.submodules.ass = ass = PartitionedRepl(a, 2, ptype)
     omask = (1<<len(ass.output))-1
 
     traces = ass.ports()