7347648370dc346a1e80ad642be82bdb3b19a401
[ieee754fpu.git] / src / ieee754 / part_repl / repl.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2021 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6
7 dynamically-partitionable "repl" class, directly equivalent
8 to nmigen Repl
9
10 See:
11
12 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/repl
13 * http://bugs.libre-riscv.org/show_bug.cgi?id=709
14
15 """
16
17 from nmigen import Signal, Module, Elaboratable, Cat, Repl
18 from nmigen.back.pysim import Simulator, Settle
19 from nmigen.cli import rtlil
20
21 from ieee754.part_mul_add.partpoints import PartitionPoints
22 from ieee754.part.partsig import SimdSignal
23
24
def get_runlengths(pbit, size):
    """Return the lane count of every partition encoded by ``pbit``.

    Bit ``i`` of ``pbit`` (for ``0 <= i < size``) being set marks a
    partition boundary between lane ``i`` and lane ``i + 1``.  The result
    lists how many lanes each partition spans, lowest partition first, so
    for a non-negative ``size`` the entries sum to ``size + 1``.  Bits at
    or above ``size`` are ignored.

    :param pbit: integer partition-boundary bitmask.
    :param size: number of boundary bits (partition points) to inspect.
    :returns: list of run lengths (always at least one entry).
    """
    lanes = range(size)
    # lane index at which each new partition begins
    cuts = [lane + 1 for lane in lanes if pbit & (1 << lane)]
    # len(lanes) rather than size: a non-positive size yields one run of 1
    edges = [0] + cuts + [len(lanes) + 1]
    res = [hi - lo for lo, hi in zip(edges, edges[1:])]

    print ("get_runlengths", bin(pbit), size, res)

    return res
43
44
class PartitionedRepl(Elaboratable):
    """Dynamically-partitioned counterpart of nmigen ``Repl``.

    For every partition configuration reported by the context's
    ``get_cases()``, the source value is split into one chunk per
    (possibly merged) partition.  Each chunk is replicated ``qty`` times
    with ``Repl``, and the replicated chunks are concatenated into
    ``self.output``, a ``SimdSignal`` that is ``qty`` times as wide as the
    source.
    """

    def __init__(self, repl, qty, ctx):
        """Create a ``PartitionedRepl`` operator.

        :param repl: value to replicate.  A ``SimdSignal`` is split along
            its own partition points; anything else is treated as a scalar
            (see ``get_chunk``).
        :param qty: number of times each chunk is replicated.
        :param ctx: partition context; this class calls its
            ``get_mask()``, ``get_switch()`` and ``get_cases()``.
        """
        self.repl = repl
        self.qty = qty
        width, signed = repl.shape()
        self.ptype = ctx
        # output is qty times as wide as the source, same signedness
        self.shape = (width * qty), signed
        mask = ctx.get_mask()
        self.output = SimdSignal(mask, self.shape, reset_less=True)
        self.partition_points = self.output.partpoints
        # N partition points divide the output into N+1 segments
        self.mwidth = len(self.partition_points) + 1

    def get_chunk(self, y, numparts):
        """Return the source slice feeding a run of ``numparts`` partitions.

        :param y: one-element list used as a mutable cursor (index of the
            next source partition).  For a ``SimdSignal`` source it is
            advanced by ``numparts`` so that successive calls return
            successive chunks; for a scalar source it is left untouched.
        :param numparts: number of consecutive partitions in this run.
        :returns: for a scalar, its low ``numparts * (len(x) // mwidth)``
            bits (every run starts from bit 0); for a ``SimdSignal``, the
            bits between the partition points ``y[0]`` and
            ``y[0] + numparts``.
        """
        x = self.repl
        if not isinstance(x, SimdSignal):
            # assume Scalar. totally different rules
            end = numparts * (len(x) // self.mwidth)
            return x[:end]
        # SimdSignal: start at partition point
        keys = [0] + list(x.partpoints.keys()) + [len(x)]
        # get current index and increment it (for next Repl chunk)
        upto = y[0]
        y[0] += numparts
        print("getting", upto, numparts, keys, len(x))
        # get the partition point as far as we are up to
        start = keys[upto]
        end = keys[upto+numparts]
        print("start end", start, end, len(x))
        return x[start:end]

    def elaborate(self, platform):
        """Build a Switch with one Repl/Cat layout per partition setting."""
        m = Module()
        comb = m.d.comb

        keys = list(self.partition_points.keys())
        print("keys", keys, "values", self.partition_points.values())
        print("ptype", self.ptype)
        width, signed = self.output.shape()
        print("width, signed", width, signed)

        with m.Switch(self.ptype.get_switch()):
            # for each partition possibility, create a Repl sequence
            for pbit in self.ptype.get_cases():
                # fresh cursor per case: get_chunk advances it so each
                # call yields the next sequential chunk of the source
                y = [0]
                # number of lanes in each partition for this setting
                runlengths = get_runlengths(pbit, len(keys))
                print("pbit", bin(pbit), "runs", runlengths)
                # order matters: chunks are fetched left-to-right via y
                output = [Repl(self.get_chunk(y, run), self.qty)
                          for run in runlengths]
                with m.Case(pbit):
                    # direct access to the underlying Signal
                    comb += self.output.sig.eq(Cat(*output))  # cat all chunks

        return m

    def ports(self):
        """Return the input and output signals (lowered where partitioned)."""
        if isinstance(self.repl, SimdSignal):
            return [self.repl.lower(), self.output.lower()]
        return [self.repl, self.output.lower()]
113
114
if __name__ == "__main__":
    from ieee754.part.test.test_partsig import create_simulator

    # value driven onto the source in every simulation
    TEST_INPUT = 0xa12345c7
    # partition-mask settings exercised by every simulation, in order
    TEST_MASKS = (0b000, 0b010, 0b110, 0b111)

    def make_process(mask, inp, out_sig, omask):
        """Return a sim process stepping ``mask`` through ``TEST_MASKS``.

        ``inp`` is driven with ``TEST_INPUT``.  After each mask change the
        design is settled and ``out_sig`` (masked with ``omask``) is
        printed in binary and hex, labelled with the mask value.
        """
        def process():
            yield inp.eq(TEST_INPUT)
            for mval in TEST_MASKS:
                yield mask.eq(mval)
                yield Settle()
                val = yield out_sig
                print("out %s" % format(mval, "03b"),
                      bin(val & omask), hex(val & omask))
        return process

    # --- SimdSignal source ---
    m = Module()
    mask = Signal(3)
    a = SimdSignal(mask, 32)
    print("a.ptype", a.ptype)
    m.submodules.repl = repl = PartitionedRepl(a, 2, a.ptype)
    omask = (1 << len(repl.output)) - 1

    traces = repl.ports()
    vl = rtlil.convert(repl, ports=traces)
    with open("part_repl.il", "w") as f:
        f.write(vl)

    sim = create_simulator(m, traces, "partrepl")
    sim.add_process(make_process(mask, a.sig, repl.output.sig, omask))
    with sim.write_vcd("partition_repl.vcd", "partition_repl.gtkw",
                       traces=traces):
        sim.run()

    # --- Scalar source ---
    class PartType:
        """Minimal partition context for the scalar test.

        Supplies ``get_mask``/``get_switch``/``get_cases``, which
        ``PartitionedRepl`` calls.  ``blanklanes`` is not used in this file;
        presumably some other consumer expects it (TODO confirm).
        """

        def __init__(self, mask):
            self.mask = mask

        def get_mask(self):
            # return this instance's mask, not the module-level ``mask``
            return self.mask

        def get_switch(self):
            return Cat(self.get_mask())

        def get_cases(self):
            return range(1 << len(self.get_mask()))

        @property
        def blanklanes(self):
            return 0

    m = Module()
    mask = Signal(3)
    ptype = PartType(mask)
    a = Signal(32)
    m.submodules.ass = ass = PartitionedRepl(a, 2, ptype)
    omask = (1 << len(ass.output)) - 1

    traces = ass.ports()
    sim = create_simulator(m, traces, "partass")
    sim.add_process(make_process(mask, a, ass.output.sig, omask))
    with sim.write_vcd("partition_repl_scalar.vcd",
                       "partition_repl_scalar.gtkw",
                       traces=traces):
        sim.run()