more fun comments
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6 Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
7
8 dynamically partitionable shifter. Unlike part_shift_scalar, both
9 operands can be partitioned
10
11 See:
12
13 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
14 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
15 """
16 from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
17 from ieee754.part_mul_add.partpoints import PartitionPoints
18 from ieee754.part_shift.bitrev import GatedBitReverse
19 import math
20
class ShifterMask(Elaboratable):
    """Generates the mask that truncates a partition's shift amount.

    The low ``min_bits`` bits of ``mask`` are always set.  Above them sits
    a cascade of ``pwid`` bits: cascade bit ``j`` is set only while
    ``gates[0]`` .. ``gates[j]`` are all clear.  The whole value is finally
    ANDed with ``(1 << max_bits) - 1``.

    Parameters:

    * pwid:     number of partition gates seen by this mask (0 is allowed,
                in which case no ``gates`` input is created)
    * bwid:     width of the ``mask`` output
    * max_bits: upper limit on the number of mask bits that may be set
    * min_bits: number of low mask bits that are unconditionally set

    Attributes:

    * gates: input, ``pwid`` bits (only present when ``pwid != 0``)
    * mask:  output, ``bwid`` bits
    """

    def __init__(self, pwid, bwid, max_bits, min_bits):
        self.max_bits = max_bits
        self.min_bits = min_bits
        self.pwid = pwid
        self.mask = Signal(bwid, reset_less=True)
        if pwid != 0:
            self.gates = Signal(pwid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        minm = (1 << self.min_bits) - 1
        maxm = (1 << self.max_bits) - 1

        # zero-width mustn't try to do anything
        if self.pwid == 0:
            comb += self.mask.eq(minm)
            return m

        # create bit-cascade: each stage is a single bit which stays set
        # only for as long as every gate up to and including its own is
        # clear.
        bits = Signal(self.pwid, reset_less=True)
        bl = []
        for j in range(self.pwid):
            bit = Signal(name="bit%d" % j, reset_less=True)
            if j != 0:
                comb += bit.eq((~self.gates[j]) & bl[j-1])
            else:
                comb += bit.eq(~self.gates[j])
            bl.append(bit)

        # the cascade is copied bit-by-bit rather than with Cat(*bl):
        # the original author reported a simulation problem with Cat here.
        for j in range(self.pwid):
            comb += bits[j].eq(bl[j])

        # the low field is given an explicit width of min_bits so that
        # its size does not depend on how a bare integer would be sized.
        low = C(minm, self.min_bits)
        comb += self.mask.eq(Cat(low, bits) & C(maxm, self.mask.shape()))

        return m
62
63
class PartialResult(Elaboratable):
    """Computes one partial result: ``a_interval << b``, with ``b`` clamped.

    The shift amount is limited to ``reswid`` because there is no point
    shifting by more than the width of the result (e.g. by 64 bits when
    the result is only 8 bits wide).

    Parameters:

    * pwid:   stored on the instance; not used by ``elaborate``
    * bwid:   width of both the ``b`` and ``a_interval`` inputs
    * reswid: width of the ``partial`` output, and the clamp limit

    Attributes:

    * b:          input, shift amount
    * a_interval: input, value to be shifted
    * gate:       input, assigned by the parent but not read here
    * partial:    output, ``reswid`` bits
    """
    def __init__(self, pwid, bwid, reswid):
        self.pwid = pwid
        self.bwid = bwid
        self.reswid = reswid
        self.b = Signal(bwid, reset_less=True)
        self.a_interval = Signal(bwid, reset_less=True)
        self.gate = Signal(reset_less=True)
        self.partial = Signal(reswid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # wide enough to hold any value up to and including reswid
        shiftbits = math.ceil(math.log2(self.reswid+1))+1
        element = self.b

        # This computes the partial results table. note that
        # the shift amount is truncated because there's no point
        # trying to shift data by 64 bits if the result width
        # is only 8.
        shifter = Signal(shiftbits, reset_less=True)
        # the limit is left at its natural width: forcing it into the
        # width of b would cut off its upper bits whenever reswid does
        # not fit in bwid bits, clamping to the wrong value.
        maxval = C(self.reswid)
        with m.If(element > maxval):
            comb += shifter.eq(maxval)
        with m.Else():
            comb += shifter.eq(element)
        comb += self.partial.eq(self.a_interval << shifter)

        return m
101
102
class PartitionedDynamicShift(Elaboratable):
    """Dynamically partitionable shifter where both operands are partitioned.

    ``a`` is shifted by ``b``, partition by partition, with the partition
    boundaries opened and closed at runtime by ``partition_points``.
    Right-shift is implemented by bit-reversing ``a``, performing a
    left-shift, then bit-reversing the result.

    Parameters:

    * width:            width in bits of ``a``, ``b`` and ``output``
    * partition_points: passed to ``PartitionPoints``

    Attributes:

    * a:           input, value to shift
    * b:           input, per-partition shift amounts
    * shift_right: input, set for right-shift, clear for left-shift
    * output:      output, shifted result
    """

    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.shift_right = Signal(reset_less=True)
        self.output = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()

        # temporaries
        comb = m.d.comb
        width = self.width
        pwid = self.partition_points.get_max_partition_count(width)-1
        gates = Signal(pwid, reset_less=True)
        comb += gates.eq(self.partition_points.as_sig())

        # end position of every partition, the last one being the full width
        keys = list(self.partition_points.keys()) + [self.width]

        # create gated-reversed versions of a, b and the output
        # left-shift is non-reversed, right-shift is reversed
        m.submodules.a_br = a_br = GatedBitReverse(self.a.width)
        comb += a_br.data.eq(self.a)
        comb += a_br.reverse_en.eq(self.shift_right)

        m.submodules.out_br = out_br = GatedBitReverse(self.output.width)
        comb += out_br.reverse_en.eq(self.shift_right)
        comb += self.output.eq(out_br.output)

        m.submodules.gate_br = gate_br = GatedBitReverse(pwid)
        comb += gate_br.data.eq(gates)
        comb += gate_br.reverse_en.eq(self.shift_right)

        # break out both the input and output into partition-stratified blocks
        a_intervals = []
        b_intervals = []
        intervals = []
        start = 0
        for end in keys:
            a_intervals.append(a_br.output[start:end])
            b_intervals.append(self.b[start:end])
            intervals.append([start, end])
            start = end

        # number of shift-amount bits needed for the first partition alone
        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))

        # shifts are normally done as (e.g. for 32 bit) result = a &
        # (b&0b11111) truncating the b input.  however here of course
        # the size of the partition varies dynamically.
        shifter_masks = []
        for i, b_interval in enumerate(b_intervals):
            bwid = len(b_interval)
            bitwid = pwid-i
            if bitwid == 0:
                # no gates remain above this partition: the mask is fixed
                shifter_masks.append(C((1<<min_bits)-1, bwid))
                continue
            max_bits = math.ceil(math.log2(width-intervals[i][0]))
            sm = ShifterMask(bitwid, bwid, max_bits, min_bits)
            setattr(m.submodules, "sm%d" % i, sm)
            comb += sm.gates.eq(gates[i:pwid])
            shifter_masks.append(sm.mask)

        # Instead of generating the matrix described in the wiki, I
        # instead calculate the shift amounts for each partition, then
        # calculate the partial results of each partition << shift
        # amount. On the wiki, the following table is given for output #3:
        # p2p1p0 | o3
        # 0 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
        # 0 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
        # 0 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
        # 0 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
        # 1 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
        # 1 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
        # 1 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
        # 1 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]

        # Each output for o3 is given by a3bx and the partial results
        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
        # those partial results to calculate a0, a1, a2, and a3

        masked_b = []
        for i in range(0, len(keys)):
            masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
                            reset_less=True)
            comb += masked.eq(b_intervals[i] & shifter_masks[i])
            masked_b.append(masked)

        # left-shift amounts: each partition uses its own masked b when the
        # gate below it is set, otherwise it inherits the amount chosen for
        # the partition below.  every stage is captured in its own Signal
        # so that stage N refers to stage N-1 by name instead of nesting
        # the whole of the preceding Mux chain inside itself.
        b_shl_amount = []
        element = Signal(b_intervals[0].shape(), reset_less=True)
        comb += element.eq(masked_b[0])
        b_shl_amount.append(element)
        for i in range(1, len(keys)):
            # same width that the bare Mux expression would have had
            stage_wid = max(len(masked_b[i]), len(element))
            stage = Signal(stage_wid, name="b_shl%d" % i, reset_less=True)
            comb += stage.eq(Mux(gates[i-1], masked_b[i], element))
            element = stage
            b_shl_amount.append(element)

        # because the right-shift input is reversed, we have to also
        # reverse the *order* of the shift amounts (not the bits *in* the
        # shift amounts)
        b_shr_amount = list(reversed(b_shl_amount))

        # select shift-amount (b) for partition based on op being left or right
        shift_amounts = []
        for i in range(len(b_shl_amount)):
            shift_amount = Signal(len(masked_b[i]), name="shift_amount%d" % i,
                                  reset_less=True)
            sel = Mux(self.shift_right, b_shr_amount[i], b_shl_amount[i])
            comb += shift_amount.eq(sel)
            shift_amounts.append(shift_amount)

        # now calculate partial results

        # first item (simple)
        partial_results = []
        partial = Signal(width, name="partial0", reset_less=True)
        comb += partial.eq(a_intervals[0] << shift_amounts[0])
        partial_results.append(partial)

        # rest of list
        for i in range(1, len(keys)):
            reswid = width - intervals[i][0]
            pr = PartialResult(pwid, len(b_intervals[i]), reswid)
            setattr(m.submodules, "pr%d" % i, pr)
            comb += pr.gate.eq(gate_br.output[i-1])
            comb += pr.b.eq(shift_amounts[i])
            comb += pr.a_interval.eq(a_intervals[i])
            partial_results.append(pr.partial)

        # This calculates the outputs o0-o3 from the partial results
        # table above.  Note: only relevant bits of the partial result equal
        # to the width of the output column are accumulated in a Mux-cascade.
        out = []
        s, e = intervals[0]
        result = partial_results[0]
        out.append(result[s:e])
        for i in range(1, len(keys)):
            start, end = (intervals[i][0], width)
            # carry-in from the partition below: blocked when the gate
            # between the two partitions is set
            sel = Mux(gate_br.output[i-1], 0,
                      result[intervals[0][1]:][:end-start])
            res = Signal(end-start+1, name="res%d" % i, reset_less=True)
            comb += res.eq(partial_results[i] | sel)
            result = res
            out.append(res[s:e])

        comb += out_br.data.eq(Cat(*out))

        return m