96334e7cf34b2a4828291037dd81bc4e9f56f2b4
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6 Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
7
8 dynamically partitionable shifter. Unlike part_shift_scalar, both
9 operands can be partitioned
10
11 See:
12
13 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
14 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
15 """
16 from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
17 from ieee754.part_mul_add.partpoints import PartitionPoints
18 import math
19
20
class PartitionedDynamicShift(Elaboratable):
    """Dynamically-partitioned left shifter: ``output = a << b`` per lane.

    ``a``, ``b`` and ``output`` are ``width``-bit signals that may be split
    into independent lanes at the bit positions in ``partition_points``.
    Within each lane, ``output`` is that lane of ``a`` shifted left by a
    shift amount taken from the *lowest* partition of the same lane of
    ``b``.  The amount is masked to a number of bits that grows with the
    lane's width.  Unlike part_shift_scalar, both operands are partitioned.
    """

    def __init__(self, width, partition_points):
        """
        :param width: total width in bits of ``a``, ``b`` and ``output``.
        :param partition_points: anything accepted by ``PartitionPoints``;
            its keys are the bit positions at which the operands may be
            split.
        """
        self.width = width
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width, reset_less=True)       # value to be shifted
        self.b = Signal(width, reset_less=True)       # per-lane shift amounts
        self.output = Signal(width, reset_less=True)  # shifted result

    def elaborate(self, platform):
        """Build the combinatorial shifter.

        :param platform: unused.
        :returns: the nmigen ``Module`` implementing the shifter.
        """
        m = Module()
        comb = m.d.comb
        width = self.width

        # One gate bit per partition point.  A set bit means the point is
        # active, i.e. the lanes are split there.  (See the Mux in the output
        # cascade below, which blocks the carry from lower partitions when
        # the gate is set.)
        pwid = self.partition_points.get_max_partition_count(width)-1
        gates = Signal(pwid, reset_less=True)
        comb += gates.eq(self.partition_points.as_sig())

        # Break both inputs into partition-stratified blocks.
        # NOTE(review): assumes partition_points iterates its keys in
        # ascending bit order -- confirm against PartitionPoints.
        keys = list(self.partition_points.keys()) + [width]
        a_intervals = []
        b_intervals = []
        intervals = []
        start = 0
        for end in keys:
            a_intervals.append(self.a[start:end])
            b_intervals.append(self.b[start:end])
            intervals.append((start, end))
            start = end

        # Shift-amount bits needed for the first (assumed smallest)
        # partition, and for the full width.
        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
        max_bits = math.ceil(math.log2(width))

        # Shifts are normally done as (e.g. for 32 bit)
        # result = a << (b & 0b11111), truncating the b input.  Here the
        # lane size varies dynamically.  Each partition therefore gets a
        # mask of min_bits ones, widened by one bit for every consecutive
        # inactive partition point directly above it.  The mask is capped
        # at max_bits.
        shifter_masks = []
        for i in range(len(b_intervals)):
            mask = Signal(b_intervals[i].shape(), name="shift_mask%d" % i,
                          reset_less=True)
            bits = Signal(pwid-i, name="bits%d" % i, reset_less=True)
            # prefix-AND: bit idx is set only if gates[i..i+idx] are all clear
            bl = []
            for idx, j in enumerate(range(i, pwid)):
                if idx != 0:
                    bl.append((~gates[j]) & bits[idx-1])
                else:
                    bl.append(~gates[j])
            # XXX ARGH, really annoying: simulation bug, can't use Cat(*bl).
            for j, bit in enumerate(bl):
                comb += bits[j].eq(bit)
            comb += mask.eq(Cat((1 << min_bits)-1, bits)
                            & ((1 << max_bits)-1))
            shifter_masks.append(mask)

        # Instead of generating the matrix described in the wiki, I
        # instead calculate the shift amounts for each partition, then
        # calculate the partial results of each partition << shift
        # amount. On the wiki, the following table is given for output #3:
        # p2p1p0 | o3
        # 0 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
        # 0 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
        # 0 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
        # 0 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
        # 1 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
        # 1 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
        # 1 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
        # 1 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
        #
        # Each output for o3 is given by a3bx and the partial results
        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
        # those partial results to calculate a0, a1, a2, and a3.
        element = b_intervals[0] & shifter_masks[0]
        partial_results = []
        partial = Signal(width, name="partial0", reset_less=True)
        comb += partial.eq(a_intervals[0] << element)
        partial_results.append(partial)
        for i in range(1, len(keys)):
            # bits from the start of this partition to the top of the word
            reswid = width - intervals[i][0]
            # just wide enough to hold any amount in 0..reswid
            shiftbits = reswid.bit_length()

            masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
                            reset_less=True)
            comb += masked.eq(b_intervals[i] & shifter_masks[i])

            # Select which partition of b supplies the shift amount: this
            # partition's own b if the point below it is active, otherwise
            # whatever the partition below selected (a mux chain, so the
            # lowest partition of the lane wins).
            elmux = Signal(b_intervals[i].shape(), name="elmux%d" % i,
                           reset_less=True)
            comb += elmux.eq(Mux(gates[i-1], masked, element))
            element = elmux

            # Saturate the shift amount at reswid.  Any shift of reswid or
            # more yields zero in a reswid-bit result.  Saturating lets the
            # shifter stay narrow without a large amount wrapping around
            # to a small one.
            shifter = Signal(shiftbits, name="shifter%d" % i,
                             reset_less=True)
            comb += shifter.eq(Mux(element > reswid, reswid, element))
            partial = Signal(reswid, name="partial%d" % i, reset_less=True)
            comb += partial.eq(a_intervals[i] << shifter)
            partial_results.append(partial)

        # Calculate the outputs o0-o3 from the partial results table above.
        # `result` is a running accumulation whose bit 0 sits at the start
        # of the previous partition.  Only the bits overlapping the current
        # partition and above are carried in, and only when the partition
        # point below is inactive (lanes joined).
        out = []
        result = partial_results[0]
        out.append(result[:intervals[0][1] - intervals[0][0]])
        for i in range(1, len(keys)):
            start, end = intervals[i]
            reswid = width - start
            # distance from the previous partition's start to this one's
            offset = start - intervals[i-1][0]
            carry = Mux(gates[i-1], 0, result[offset:][:reswid])
            res = Signal(reswid, name="res%d" % i, reset_less=True)
            comb += res.eq(partial_results[i] | carry)
            result = res
            out.append(res[:end - start])

        comb += self.output.eq(Cat(*out))

        return m
163