switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part_cat / cat.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "cat" class, directly equivalent
8 to nmigen Cat
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/cat
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=707
14
15 m.Switch()
16 for pbits cases: 0b000 to 0b111
17 output = []
18 # set up some yielders which will retain where they each got to
19 # then when called below in the inner nested loop they give
20 # the relevant sequential chunk
21 yielders = [Yielder(a), Yielder(b), ....]
22 runlist = split pbits into runs of zeros
23 for y in yielders: # for each signal a b c d ...
24 for i in runlist: # for each partition
25 for _ in range(i+1): # for the length of each partition
26 thing = yield from y # grab sequential chunks
27 output.append(thing)
28 with m.Case(pbits):
29 comb += out.eq(Cat(*output))
30
31 """
32
33 from nmigen import Signal, Module, Elaboratable, Cat, C
34 from nmigen.back.pysim import Simulator, Settle
35
36 from ieee754.part_mul_add.partpoints import PartitionPoints
37 from ieee754.part.partsig import SimdSignal
38 from ieee754.part.test.test_partsig import create_simulator
39
40
def get_runlengths(pbit, size):
    """Return the length of every partition described by a bitmask.

    Bits 0 .. size-1 of ``pbit`` are examined starting from the least
    significant one.  A set bit closes the partition currently being
    measured and opens a new one; a clear bit extends the current
    partition by one.  The last (still open) partition is always added
    at the end, so the result has one more entry than there are set
    bits in range, and the entries sum to ``size + 1``.

    The result is also printed to stdout as a debugging aid.
    """
    lengths = []
    run = 1
    for bitpos in range(size):
        if (pbit >> bitpos) & 1:
            # boundary found: record the finished run, begin a new one
            lengths.append(run)
            run = 1
            continue
        run += 1
    # no terminating "fake" bit exists, so flush the final run by hand
    lengths.append(run)

    print ("get_runlengths", bin(pbit), size, lengths)

    return lengths
59
60
class PartitionedCat(Elaboratable):
    """Dynamically-partitionable counterpart of nmigen ``Cat``.

    For each partition case supplied by the context, ``elaborate``
    collects the per-element slice of every input (element by element,
    inputs in ``catlist`` order) and connects their concatenation to,
    or from, the underlying Signal of ``self.output``.
    """

    def __init__(self, catlist, ctx):
        """Create a ``PartitionedCat`` operator

        catlist: the objects to concatenate.  each must have a ``sig``
                 attribute (and a ``ptype`` attribute for elaborate)
        ctx:     the partition context, kept as ``self.ptype``.  it is
                 asked for the mask here and for the switch, cases and
                 element counts during elaborate
        """
        self.catlist = catlist
        self.ptype = ctx
        # the output is as wide as all of the inputs put together
        self.width = sum(len(entry.sig) for entry in catlist)
        self.output = SimdSignal(ctx.get_mask(), self.width, reset_less=True)
        # XXX errr... this is a bit of a hack, but should work
        # obtain the module for the output Signal
        self.output.set_module(ctx.psig.m)
        self.partition_points = self.output.partpoints
        self.mwidth = len(self.partition_points) + 1

    def set_lhs_mode(self, is_lhs):
        """record whether this Cat is being used on the LHS

        self.is_lhs is intentionally NOT given a default value in the
        constructor: that way a missing call to this function is
        detected (elaborate reads self.is_lhs)
        """
        self.is_lhs = is_lhs

    def get_chunk(self, y, idx, numparts):
        """Slice the next ``numparts`` partitions out of input ``idx``.

        y:        mutable per-input position table.  ``y[idx]`` is read
                  as the first partition to take, then advanced by
                  ``numparts`` ready for the next call
        idx:      index into ``self.catlist``
        numparts: number of consecutive partitions to take

        returns the corresponding slice of the input's ``sig``
        """
        src = self.catlist[idx]
        # partition boundaries, with both ends of the signal included
        bounds = [0, *src.partpoints.keys(), len(src.sig)]
        # fetch the current position, then move it on for the next chunk
        first = y[idx]
        y[idx] += numparts
        print ("getting", idx, first, numparts, bounds, len(src.sig))
        # convert partition numbers into bit positions
        lo = bounds[first]
        hi = bounds[first + numparts]
        print ("start end", lo, hi, len(src.sig))
        return src.sig[lo:hi]

    def elaborate(self, platform):
        """Build the Switch that selects one Cat per partition case."""
        print ("PartitionedCat start", self.is_lhs)
        m = Module()
        comb = m.d.comb

        ppoints = self.partition_points
        print ("keys", list(ppoints.keys()), "values", ppoints.values())
        print ("ptype", self.ptype)
        ctx = self.ptype
        with m.Switch(ctx.get_switch()):
            # one Case, and one Cat sequence, per partition possibility
            for pbit in ctx.get_cases():
                nelts = ctx.get_num_elements(pbit)
                print ("pbit", bin(pbit), "nelts", nelts)
                # element-major order: element 0 of every input, then
                # element 1 of every input, and so on
                pieces = []
                for elnum in range(nelts):
                    for src in self.catlist:
                        span = src.ptype.get_el_range(pbit, elnum)
                        pieces.append(
                            src.sig[span.start:span.stop:span.step])
                with m.Case(pbit):
                    # direct access to the underlying Signal
                    joined = Cat(*pieces)
                    target = self.output.sig
                    if self.is_lhs:
                        # LHS mode: the Cat is what gets assigned to
                        dest, value = joined, target
                    else:
                        # RHS mode: the output is what gets assigned to
                        dest, value = target, joined
                    comb += dest.eq(value)

        print ("PartitionedCat end")
        return m

    def ports(self):
        """Return the ``sig`` of each input followed by the output's."""
        return [p.sig for p in self.catlist + [self.output]]
139
140
if __name__ == "__main__":
    # demo: concatenate a 32-bit and a 16-bit SimdSignal sharing one
    # 3-bit partition mask, then print the result for several masks
    m = Module()
    mask = Signal(3)
    a = SimdSignal(mask, 32)
    b = SimdSignal(mask, 16)
    for simdsig in (a, b):
        simdsig.set_module(m)
    catlist = [a, b]
    cat = PartitionedCat(catlist, a.ptype)
    m.submodules.cat = cat
    cat.set_lhs_mode(False)  # the Cat is used as a source (RHS)

    traces = cat.ports()
    sim = create_simulator(m, traces, "partcat")

    def process():
        # (mask value, label to print) for each partitioning to try
        stimuli = (
            (0b000, "out 000"),
            (0b010, "out 010"),
            (0b110, "out 110"),
            (0b111, "out 111"),
        )
        for maskval, label in stimuli:
            yield mask.eq(maskval)
            # the inputs are re-driven with the same values every time
            yield a.sig.eq(0x01234567)
            yield b.sig.eq(0xfdbc)
            yield Settle()
            out = yield cat.output.sig
            print(label, bin(out), hex(out))

    sim.add_process(process)
    vcd_writer = sim.write_vcd("partition_cat.vcd", "partition_cat.gtkw",
                               traces=traces)
    with vcd_writer:
        sim.run()