switch to exact version of cython
[ieee754fpu.git] / src / ieee754 / part_repl / repl.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "repl" class, directly equivalent
8 to nmigen Repl
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/repl
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=709
14
15 """
16
17 from nmigen import Signal, Module, Elaboratable, Cat, Repl
18 from nmigen.back.pysim import Simulator, Settle
19 from nmigen.cli import rtlil
20
21 from ieee754.part_mul_add.partpoints import PartitionPoints
22 from ieee754.part.partsig import SimdSignal
23
24
def get_runlengths(pbit, size):
    """Return the length of every partition described by a partition mask.

    * pbit: integer bitmask; a set bit ``i`` (for ``0 <= i < size``) closes
      the current partition and opens a new one
    * size: number of mask bits to examine (bits at or above ``size`` are
      ignored)

    The result is a list of run lengths, lowest partition first.  There is
    always one more entry than there are set bits in the examined range,
    and the entries sum to ``size + 1``.
    """
    # positions of the 1s: each one marks a boundary between two partitions
    boundaries = [pos for pos in range(size) if (pbit >> pos) & 1]
    # bracket the boundaries with a virtual edge at either end, so that the
    # first and last partitions are measured the same way as the inner ones
    edges = [-1] + boundaries + [size]
    # each run length is the distance between two neighbouring edges
    res = [hi - lo for lo, hi in zip(edges, edges[1:])]

    print ("get_runlengths", bin(pbit), size, res)

    return res
43
44
class PartitionedRepl(Elaboratable):
    """Replicate a value ``qty`` times, separately within each partition.

    For every partition pattern offered by the context object, the input is
    cut into one chunk per run of partitions (see ``get_runlengths``), each
    chunk is wrapped in ``Repl(chunk, qty)``, and the replicated chunks are
    concatenated into the underlying signal of ``output``.
    """

    def __init__(self, repl, qty, ctx):
        """Create a ``PartitionedRepl`` operator

        * repl: the value to replicate.  A SimdSignal is sliced at its own
          partition points; anything else is treated as a scalar
        * qty: number of copies made of each chunk
        * ctx: object supplying ``get_mask()``, ``get_switch()`` and
          ``get_cases()``
        """
        self.repl = repl
        self.qty = qty
        # the output is qty times the width of the input, same signedness
        in_width, in_signed = repl.shape()
        self.ptype = ctx
        self.shape = (in_width * qty, in_signed)
        self.output = SimdSignal(ctx.get_mask(), self.shape, reset_less=True)
        self.partition_points = self.output.partpoints
        # one more partition than there are partition points
        self.mwidth = len(self.partition_points)+1

    def get_chunk(self, y, numparts):
        """Return the slice of the input covering the next ``numparts``
        partitions.

        * y: one-element list holding the index of the next unconsumed
          partition.  Advanced in place by ``numparts``, but only when the
          input is a SimdSignal
        * numparts: how many partitions the chunk spans
        """
        src = self.repl
        if isinstance(src, SimdSignal):
            # bit positions at which each partition starts, plus the far end
            boundaries = [0] + list(src.partpoints.keys()) + [len(src)]
            # take the current index, then move it on for the next chunk
            first = y[0]
            y[0] = first + numparts
            print ("getting", first, numparts, boundaries, len(src))
            lo = boundaries[first]
            hi = boundaries[first+numparts]
            print ("start end", lo, hi, len(src))
            # slice the underlying Signal of the SimdSignal directly
            return src.sig[lo:hi]
        # scalar: different rules.  always taken from bit 0 upwards, sized
        # as numparts equal shares of the scalar's width; y is left alone
        return src[:numparts * (len(src) // self.mwidth)]

    def elaborate(self, platform):
        """Build the module: a Switch with one Case per partition pattern,
        each assigning the concatenated, replicated chunks to the output.
        """
        m = Module()
        comb = m.d.comb

        ppoints = self.partition_points
        keys = list(ppoints.keys())
        print ("keys", keys, "values", ppoints.values())
        print ("ptype", self.ptype)
        outpartsize = len(self.output) // self.mwidth  # not used below
        width, signed = self.output.shape()
        print ("width, signed", width, signed)

        with m.Switch(self.ptype.get_switch()):
            # one Case for each partition possibility
            for pbit in self.ptype.get_cases():
                # shared cursor: get_chunk moves it along, so successive
                # calls hand back successive chunks of the input
                cursor = [0]
                runlengths = get_runlengths(pbit, len(keys))
                print ("pbit", bin(pbit), "runs", runlengths)
                # fetch each partition's chunk in order and replicate it
                replicated = [Repl(self.get_chunk(cursor, run), self.qty)
                              for run in runlengths]
                with m.Case(pbit):
                    # direct access to the underlying Signal
                    comb += self.output.sig.eq(Cat(*replicated))

        return m

    def ports(self):
        """Return ``[input, output]``, with SimdSignals passed through
        their ``lower()`` method.
        """
        source = self.repl
        if isinstance(source, SimdSignal):
            source = source.lower()
        return [source, self.output.lower()]
114
115
if __name__ == "__main__":
    # Demonstration: run PartitionedRepl in simulation, first with a
    # SimdSignal input and then with a plain (scalar) Signal input, printing
    # the output for a handful of partition masks.
    from ieee754.part.test.test_partsig import create_simulator

    # partition masks stepped through by both demonstrations, in order
    demo_masks = (0b000, 0b010, 0b110, 0b111)

    # SimdSignal
    m = Module()
    mask = Signal(3)
    a = SimdSignal(mask, 32)
    print ("a.ptype", a.ptype)
    m.submodules.repl = repl = PartitionedRepl(a, 2, a.ptype)
    # restricts the printed value to the width of the output
    omask = (1<<len(repl.output))-1

    traces = repl.ports()
    vl = rtlil.convert(repl, ports=traces)
    with open("part_repl.il", "w") as f:
        f.write(vl)

    sim = create_simulator(m, traces, "partrepl")

    def process():
        # the input value is driven once; only the mask changes afterwards
        yield a.sig.eq(0xa12345c7)
        for pmask in demo_masks:
            yield mask.eq(pmask)
            yield Settle()
            out = yield repl.output.sig
            print("out " + format(pmask, "03b"),
                  bin(out&omask), hex(out&omask))

    sim.add_process(process)
    with sim.write_vcd("partition_repl.vcd", "partition_repl.gtkw",
                       traces=traces):
        sim.run()

    # Scalar
    m = Module()

    class PartType:
        """Minimal stand-in for the context object PartitionedRepl expects,
        wrapping a single partition-mask Signal.
        """
        def __init__(self, mask):
            self.mask = mask

        def get_mask(self):
            # use the instance's own mask, not whichever module-level
            # "mask" happens to exist when this is called
            return self.mask

        def get_switch(self):
            return Cat(self.get_mask())

        def get_cases(self):
            # every possible value of the mask
            return range(1<<len(self.get_mask()))

        @property
        def blanklanes(self):
            # NOTE(review): not read by PartitionedRepl itself; presumably
            # required by SimdSignal - confirm
            return 0

    mask = Signal(3)
    ptype = PartType(mask)
    a = Signal(32)
    m.submodules.ass = ass = PartitionedRepl(a, 2, ptype)
    omask = (1<<len(ass.output))-1

    traces = ass.ports()
    sim = create_simulator(m, traces, "partass")

    def process():
        # the input value is driven once; only the mask changes afterwards
        yield a.eq(0xa12345c7)
        for pmask in demo_masks:
            yield mask.eq(pmask)
            yield Settle()
            out = yield ass.output.sig
            print("out " + format(pmask, "03b"),
                  bin(out&omask), hex(out&omask))

    sim.add_process(process)
    with sim.write_vcd("partition_repl_scalar.vcd",
                       "partition_repl_scalar.gtkw",
                       traces=traces):
        sim.run()