store intermediate in temp, append that to output
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
1 # SPDX-License-Identifier: LGPL-2.1-or-later
2 # See Notices.txt for copyright information
3
4 """
5 Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
6 Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>
7
8 dynamically partitionable shifter. Unlike part_shift_scalar, both
9 operands can be partitioned
10
11 See:
12
13 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
14 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
15 """
16 from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
17 from ieee754.part_mul_add.partpoints import PartitionPoints
18 import math
19
20
class PartitionedDynamicShift(Elaboratable):
    """Dynamically partitionable left-shifter.

    Both operands are split at the partition points: each partition of
    ``b`` supplies the shift amount applied to partitions of ``a``, and the
    partial results are combined into ``output`` according to which
    partition gates are set.

    Attributes:
        width: total bit-width of ``a``, ``b`` and ``output``.
        partition_points: ``PartitionPoints`` built from the constructor
            argument; its keys are used as the bit positions at which the
            operands are split.
        a: input Signal (``width`` bits), the value being shifted.
        b: input Signal (``width`` bits), the per-partition shift amounts.
        output: result Signal (``width`` bits).
    """

    def __init__(self, width, partition_points):
        """Create the input/output signals.

        Args:
            width: bit-width of the operands and of the result.
            partition_points: passed straight to ``PartitionPoints``.
        """
        self.width = width
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width)
        self.b = Signal(width)
        self.output = Signal(width)

    def elaborate(self, platform):
        """Build and return the combinatorial Module for the shifter.

        Args:
            platform: unused by this method.

        Returns:
            The nmigen ``Module`` driving ``self.output``.
        """
        m = Module()
        comb = m.d.comb
        width = self.width
        # Local copy of the partition gate bits; bit i-1 is consulted when
        # combining output partition i (see the final loop).
        gates = Signal(self.partition_points.get_max_partition_count(width)-1)
        comb += gates.eq(self.partition_points.as_sig())

        # Partition boundaries, with the total width appended as the final
        # (upper) boundary.
        keys = list(self.partition_points.keys()) + [self.width]

        # create a matrix of partial shift-results (similar to PartitionedMul
        # matrices).  These however have to be of length suitable to contain
        # the full shifted "contribution".  i.e. B from the LSB *could* contain
        # a number great enough to shift the entirety of A LSB right up to
        # the MSB of the output, however B from the *MSB* is *only* going
        # to contribute to the *MSB* of the output.
        matrix = []
        for i in range(len(keys)):
            row = []
            start = 0
            for j, end in enumerate(keys):
                row.append(Signal(width - start,
                                  name="matrix[%d][%d]" % (i, j)))
                start = end
            matrix.append(row)

        # break out both inputs into partition-stratified blocks, and record
        # the [start, end) bit range of every partition
        a_intervals = []
        b_intervals = []
        intervals = []
        start = 0
        for end in keys:
            a_intervals.append(self.a[start:end])
            b_intervals.append(self.b[start:end])
            intervals.append([start, end])
            start = end

        # XXX (flagged by the original author): the shift-amount width is
        # derived from the full output width, not from the width of the
        # individual partial result.
        bw = math.ceil(math.log2(self.output.width + 1))

        # actually calculate the shift-partials here
        for i, b in enumerate(b_intervals):
            start = 0
            for j in range(i, len(a_intervals)):
                a = a_intervals[j]
                end = keys[i]
                tshift = Signal(bw, name="ts%d_%d" % (i, j), reset_less=True)
                ow = math.ceil(math.log2(width - start))
                maxshift = 1 << ow
                # clamp the requested shift amount to maxshift
                with m.If(b[:bw] < maxshift):
                    comb += tshift.eq(b[:bw])
                with m.Else():
                    comb += tshift.eq(maxshift)
                comb += matrix[i][j].eq(a << tshift)
                start = end

        # now create a switch statement which sums the relevant partial results
        # in each output-partition
        out = []
        intermed = matrix[0][0]
        s, e = intervals[0]
        out.append(intermed[s:e])
        for i in range(1, len(intervals)):
            s, e = intervals[i]
            element = Signal(width, name="element%d" % i)
            # The 'i' least significant gate bits select which matrix row
            # feeds this output partition.
            with m.Switch(gates[:i]):
                for index in range(1 << i):
                    # number of bits needed to represent 'index'
                    # (0 -> 0, 1 -> 1, 2..3 -> 2, 4..7 -> 3, ...)
                    row_idx = math.ceil(math.log2(index + 1))
                    with m.Case(index):
                        comb += element.eq(matrix[row_idx][i])
            # One bit wider than the partition; only the low e-s bits are
            # appended to the output below.
            temp = Signal(e-s+1, name="intermed%d" % i)
            # When gate bit i-1 is set, use this partition's element alone;
            # otherwise OR in the previous intermediate with its low keys[0]
            # bits dropped.
            # NOTE(review): keys[0] is used for every partition, which
            # presumably assumes equally-sized partitions - confirm.
            intermed = Mux(gates[i-1], element, element | intermed[keys[0]:])
            comb += temp.eq(intermed)
            out.append(temp[:e-s])

        comb += self.output.eq(Cat(*out))

        return m
120