limit matrix output width and limit shift amount
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6 Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
7
8 dynamically partitionable shifter. Unlike part_shift_scalar, both
9 operands can be partitioned
10
11 See:
12
13 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
14 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
15 """
16 from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
17 from ieee754.part_mul_add.partpoints import PartitionPoints
18 import math
19
20
class PartitionedDynamicShift(Elaboratable):
    """Dynamically-partitioned left shifter: ``output = a << b`` per lane.

    Both ``a`` (the value being shifted) and ``b`` (the shift amount) are
    split at the same partition points.  The partition gates select, at
    run-time, which adjacent partitions are merged into a wider lane.  A
    merged lane takes its shift amount from the ``b`` partition at the
    lane's least-significant end.

    :param width: total width in bits of ``a``, ``b`` and ``output``.
    :param partition_points: anything accepted by ``PartitionPoints``; its
        keys are the bit positions at which the operands may be split
        (presumably mapped to the gate values - see ``PartitionPoints``).
    """

    def __init__(self, width, partition_points):
        self.width = width
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width)       # value(s) to be shifted
        self.b = Signal(width)       # shift amount(s)
        self.output = Signal(width)  # shifted result(s)

    def elaborate(self, platform):
        """Build the partitioned shifter and return its ``Module``."""
        m = Module()
        comb = m.d.comb
        width = self.width

        # one bit per partition point.  A set bit starts a new lane at that
        # point (see the Mux in the combine stage below); a clear bit merges
        # the partitions either side of it.
        gates = Signal(self.partition_points.get_max_partition_count(width)-1)
        comb += gates.eq(self.partition_points.as_sig())

        # keys[i] is the (exclusive) end bit of partition i
        keys = list(self.partition_points.keys()) + [self.width]

        # create a matrix of partial shift-results (similar to PartitionedMul
        # matrices).  These however have to be of length suitable to contain
        # the full shifted "contribution".  i.e. B from the LSB *could* contain
        # a number great enough to shift the entirety of A LSB right up to
        # the MSB of the output, however B from the *MSB* is *only* going
        # to contribute to the *MSB* of the output.
        # matrix[i][j] is a-partition j shifted by b-partition i; its width
        # spans from the start of partition j to the MSB of the output.
        matrix = []
        for i in range(len(keys)):
            row = []
            start = 0
            for j, end in enumerate(keys):
                row.append(Signal(width - start,
                                  name="matrix[%d][%d]" % (i, j)))
                start = end
            matrix.append(row)

        # break out both the inputs into partition-stratified blocks
        a_intervals = []
        b_intervals = []
        intervals = []
        start = 0
        for end in keys:
            a_intervals.append(self.a[start:end])
            b_intervals.append(self.b[start:end])
            intervals.append([start, end])
            start = end

        # number of shift-amount bits examined (enough to express
        # the full output width)
        bw = math.ceil(math.log2(self.output.width + 1))

        # actually calculate the shift-partials here.  A b-partition only
        # ever shifts a-partitions at or above itself, hence j starts at i.
        for i, b in enumerate(b_intervals):
            for j in range(i, len(a_intervals)):
                a = a_intervals[j]
                result_width = matrix[i][j].width
                # bits needed to express result_width: shifting by 1<<rw
                # (which exceeds result_width) clears the partial entirely
                rw = math.ceil(math.log2(result_width + 1))
                # NOTE(review): only the low bw bits of this b-partition are
                # examined (the original "XXX"), so set bits above bw are
                # ignored instead of clearing the result - confirm intent.
                # When rw == bw the Else branch is unreachable, because
                # b[:bw] is at most bw bits wide.
                tshift = Signal(bw, name="ts%d_%d" % (i, j), reset_less=True)
                with m.If(b[:bw] < (1 << rw)):
                    comb += tshift.eq(b[:bw])
                with m.Else():
                    comb += tshift.eq(1 << rw)
                comb += matrix[i][j].eq(a << tshift)

        # now create a switch statement which sums the relevant partial results
        # in each output-partition.  `intermed` ORs together the contributions
        # of every a-partition in the current lane, aligned so that its bit 0
        # is the LSB of the output partition currently being produced.
        out = []
        intermed = matrix[0][0]
        s, e = intervals[0]
        out.append(intermed[s:e])
        for i in range(1, len(intervals)):
            s, e = intervals[i]
            element = Signal(width, name="element%d" % i)
            # gates[:i] are the partition points below partition i.  The lane
            # containing partition i starts just above the highest set gate,
            # i.e. at partition index pattern.bit_length() (0 when none set).
            # That lane's b-partition supplies the shift amount.
            with m.Switch(gates[:i]):
                for pattern in range(1 << i):
                    with m.Case(pattern):
                        comb += element.eq(matrix[pattern.bit_length()][i])
            temp = Signal(e-s+1, name="intermed%d" % i)  # debug probe only
            # realign the previous accumulator to the start of partition i by
            # dropping the bits of partition i-1.  (Previously this dropped
            # keys[0] bits, correct only when all partitions are equal width.)
            prev_s, prev_e = intervals[i-1]
            intermed = Mux(gates[i-1], element,
                           element | intermed[prev_e - prev_s:])
            comb += temp.eq(intermed)
            out.append(intermed[:e-s])

        comb += self.output.eq(Cat(*out))

        return m
117