230455092914d4a7ce82aca275463cb0caf94a5e
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6 Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
7
8 dynamically partitionable shifter. Unlike part_shift_scalar, both
9 operands can be partitioned
10
11 See:
12
13 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
14 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
15 """
16 from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
17 from ieee754.part_mul_add.partpoints import PartitionPoints
18 from ieee754.part_shift.bitrev import GatedBitReverse
19 import math
20
class ShifterMask(Elaboratable):
    """Generate the mask applied to one partition's shift amount.

    The low ``min_bits`` bits of ``mask`` are always set.  Starting
    from ``gates[0]``, every consecutive *cleared* gate sets one
    further mask bit above those; the first set gate ends the run.
    The result is finally limited to the low ``max_bits`` bits (and
    to the ``bwid`` width of ``mask``).

    :param pwid: number of gate bits.  0 means there are no gates and
                 ``mask`` is just the low ``min_bits`` bits (no
                 ``gates`` attribute is created in that case).
    :param bwid: width of the generated ``mask`` signal
    :param max_bits: maximum number of low mask bits that may be set
    :param min_bits: number of low mask bits that are always set
    """
    def __init__(self, pwid, bwid, max_bits, min_bits):
        self.max_bits = max_bits
        self.min_bits = min_bits
        self.pwid = pwid
        self.mask = Signal(bwid, reset_less=True)
        if pwid != 0:
            self.gates = Signal(pwid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        minm = (1<<self.min_bits)-1
        maxm = (1<<self.max_bits)-1

        # zero-width mustn't try to do anything
        if self.pwid == 0:
            comb += self.mask.eq(minm)
            return m

        # bl[j] is 1 only while gates[0..j] are all clear (running AND),
        # giving a thermometer code of the run of cleared gates.
        bl = []
        for j in range(self.pwid):
            bit = Signal(name="bit%d" % j, reset_less=True)
            if j == 0:
                comb += bit.eq(~self.gates[j])
            else:
                comb += bit.eq(~self.gates[j] & bl[j-1])
            bl.append(bit)

        # XXX: assigned bit-by-bit rather than with Cat(*bl); the
        # original code reported a simulation problem with Cat(*bl).
        bits = Signal(self.pwid, reset_less=True)
        for j, bit in enumerate(bl):
            comb += bits[j].eq(bit)

        # explicit width: the always-on low part occupies exactly
        # min_bits bits (also when min_bits == 0) instead of relying on
        # the implicit width of a bare integer inside Cat().
        low = C(minm, self.min_bits)
        comb += self.mask.eq(Cat(low, bits) & C(maxm, len(self.mask)))

        return m
59
60
class PartialResult(Elaboratable):
    """Compute one partition's partial shift result.

    ``elmux`` selects the shift amount for this partition: when
    ``gate`` is set it is this partition's own ``masked`` amount,
    otherwise it is ``element`` (the amount forwarded from the
    previous partition).  ``a_interval`` is then shifted left by that
    amount, clamped to ``reswid``, and truncated into the
    ``reswid``-bit ``partial`` output.

    :param pwid: number of partition gates (stored; unused here)
    :param bwid: width of ``element``, ``elmux``, ``masked`` and
                 ``a_interval``
    :param reswid: width of ``partial``
    """
    def __init__(self, pwid, bwid, reswid):
        self.pwid = pwid
        self.bwid = bwid
        self.reswid = reswid
        self.element = Signal(bwid, reset_less=True)
        self.elmux = Signal(bwid, reset_less=True)
        self.a_interval = Signal(bwid, reset_less=True)
        self.masked = Signal(bwid, reset_less=True)
        self.gate = Signal(reset_less=True)
        self.partial = Signal(reswid, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb

        # enough bits to hold any value in 0..reswid (plus one spare)
        shiftbits = math.ceil(math.log2(self.reswid+1))+1

        # one link of the mux chain: take this partition's own masked
        # shift amount if its gate is set, else pass the previous one on.
        comb += self.elmux.eq(Mux(self.gate, self.masked, self.element))

        # Shifting by reswid or more leaves nothing inside the reswid-bit
        # result, so the shift amount is clamped to reswid.  reswid is
        # compared/assigned as a plain int so nmigen widens it as needed;
        # casting it to elmux's width would truncate it whenever
        # reswid >= 2**bwid and corrupt both the threshold and the value.
        shifter = Signal(shiftbits, reset_less=True)
        with m.If(self.elmux > self.reswid):
            comb += shifter.eq(self.reswid)
        with m.Else():
            comb += shifter.eq(self.elmux)
        comb += self.partial.eq(self.a_interval << shifter)

        return m
102
103
class PartitionedDynamicShift(Elaboratable):
    """Dynamically partitioned left shifter where both operands are split.

    ``a`` and ``b`` are cut at ``partition_points``.  A set gate bit
    marks an active partition boundary.  Each partition of ``a`` is
    shifted left by a masked shift amount taken from ``b``.  Partition 0
    always uses its own amount.  A later partition uses its own amount
    when the gate below it is set, otherwise the amount forwarded from
    the partition below (see PartialResult).  Bits shifted out of a
    partition carry into higher partitions only where the boundary
    gate is clear.

    When ``bitrev`` is set, ``a``, the gate bits and ``output`` are
    passed through GatedBitReverse with reversal enabled.

    NOTE(review): the shift-amount masks (min_bits derived from the
    first partition's width) and the output assembly (every output
    slice taken at the first partition's width) appear to assume
    equal-width partitions -- confirm before using uneven points.

    :param width: width of ``a``, ``b`` and ``output``
    :param partition_points: points at which the operands may be split
    """
    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.bitrev = Signal(reset_less=True)
        self.output = Signal(width, reset_less=True)

    def elaborate(self, platform):
        m = Module()
        comb = m.d.comb
        width = self.width
        pwid = self.partition_points.get_max_partition_count(width)-1
        gates = Signal(pwid, reset_less=True)
        comb += gates.eq(self.partition_points.as_sig())

        # optional bit-reversal of a, of the output and of the gates
        m.submodules.a_br = a_br = GatedBitReverse(self.a.width)
        comb += a_br.data.eq(self.a)
        comb += a_br.reverse_en.eq(self.bitrev)

        m.submodules.out_br = out_br = GatedBitReverse(self.output.width)
        comb += out_br.reverse_en.eq(self.bitrev)
        comb += self.output.eq(out_br.output)

        m.submodules.gate_br = gate_br = GatedBitReverse(pwid)
        comb += gate_br.data.eq(gates)
        comb += gate_br.reverse_en.eq(self.bitrev)

        # break both operands into partition-stratified blocks
        keys = list(self.partition_points.keys()) + [width]
        intervals = []
        start = 0
        for end in keys:
            intervals.append((start, end))
            start = end
        a_intervals = [a_br.output[s:e] for s, e in intervals]
        b_intervals = [self.b[s:e] for s, e in intervals]

        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))

        # shifts are normally done as (e.g. for 32 bit) result = a <<
        # (b&0b11111), truncating the b input.  here the number of kept
        # shift-amount bits depends on how partitions are combined.
        shifter_masks = []
        for i, b_interval in enumerate(b_intervals):
            bwid = len(b_interval)
            bitwid = pwid-i
            if bitwid == 0:
                # no gate bits left for this partition: fixed mask
                shifter_masks.append(C((1<<min_bits)-1, bwid))
                continue
            max_bits = math.ceil(math.log2(width-intervals[i][0]))
            sm = ShifterMask(bitwid, bwid, max_bits, min_bits)
            setattr(m.submodules, "sm%d" % i, sm)
            comb += sm.gates.eq(gate_br.output[i:pwid])
            shifter_masks.append(sm.mask)

        # Instead of generating the matrix described in the wiki, the
        # shift amount of each partition is worked out and then the
        # partial result of each partition << shift amount.  The wiki
        # gives this table for output #3:
        # p2p1p0 | o3
        # 0 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
        # 0 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
        # 0 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
        # 0 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
        # 1 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
        # 1 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
        # 1 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
        # 1 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
        #
        # Each output for o3 is given by a3bx and the partial results
        # for o2 (namely a2bx, a1bx and a0b0), so the partial results
        # [a0b0, a1bx, a2bx, a3bx] are enough to build o0..o3.

        # partition 0 always shifts by its own masked amount
        element = Signal(len(b_intervals[0]), reset_less=True)
        comb += element.eq(b_intervals[0] & shifter_masks[0])
        partial = Signal(width, name="partial0", reset_less=True)
        comb += partial.eq(a_intervals[0] << element)
        partial_results = [partial]

        # later partitions: PartialResult chains the selected shift
        # amount (elmux) from one partition to the next
        for i in range(1, len(keys)):
            reswid = width - intervals[i][0]
            pr = PartialResult(pwid, len(b_intervals[i]), reswid)
            setattr(m.submodules, "pr%d" % i, pr)
            comb += pr.masked.eq(b_intervals[i] & shifter_masks[i])
            comb += pr.gate.eq(gate_br.output[i-1])
            comb += pr.element.eq(element)
            comb += pr.a_interval.eq(a_intervals[i])
            partial_results.append(pr.partial)
            element = pr.elmux

        # Build the outputs o0..on from the partial results.  Only the
        # relevant bits of each running result are carried forward in a
        # Mux cascade; a set gate stops the carry from the lower
        # partition(s).
        s0, e0 = intervals[0]
        result = partial_results[0]
        out = [result[s0:e0]]
        for i in range(1, len(keys)):
            start = intervals[i][0]
            sel = Mux(gate_br.output[i-1], 0, result[e0:][:width-start])
            res = Signal(width-start+1, name="res%d" % i, reset_less=True)
            comb += res.eq(partial_results[i] | sel)
            out.append(res[s0:e0])
            result = res

        comb += out_br.data.eq(Cat(*out))

        return m
237