# SPDX-License-Identifier: LGPL-2.1-or-later
# See Notices.txt for copyright information

"""
Copyright (C) 2020 Luke Kenneth Casson Leighton <lkcl@lkcl.net>
Copyright (C) 2020 Michael Nolan <mtnolan2640@gmail.com>

Dynamically partitionable shifter. Unlike part_shift_scalar, both
operands can be partitioned.

See:

* http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
* http://bugs.libre-riscv.org/show_bug.cgi?id=173
"""
import math

from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
from ieee754.part_mul_add.partpoints import PartitionPoints
class PartitionedDynamicShift(Elaboratable):
    """Left shifter whose operands are both split at dynamic partition points.

    ``a`` is the value to shift and ``b`` the shift amount.  Both are cut
    into slices at the same partition points; a set partition gate
    separates two neighbouring slices, a clear gate merges them into one
    wider lane.

    NOTE(review): the copy of this file that reached review had statements
    missing (its embedded line numbering has gaps).  Every statement tagged
    ``# reconstructed`` below was restored from the surrounding data flow
    and must be confirmed against the upstream revision.

    NOTE(review): ``min_bits`` and the carry slice ``result[intervals[0][1]:]``
    both use the width of the first partition only, which presumably
    assumes equally sized partitions -- TODO confirm.
    """

    def __init__(self, width, partition_points):
        """Create the operand and result signals.

        :param width: total bit width of ``a``, ``b`` and ``output``.
        :param partition_points: passed unchanged to ``PartitionPoints``.
        """
        # elaborate() reads self.width when building the key list.
        self.width = width  # reconstructed
        self.partition_points = PartitionPoints(partition_points)

        self.a = Signal(width, reset_less=True)
        self.b = Signal(width, reset_less=True)
        self.output = Signal(width, reset_less=True)

    def elaborate(self, platform):
        """Build the combinatorial shifter logic and return the module.

        :param platform: unused by this method.
        """
        m = Module()  # reconstructed
        comb = m.d.comb  # reconstructed
        width = self.width  # reconstructed

        pwid = self.partition_points.get_max_partition_count(width)-1
        gates = Signal(pwid, reset_less=True)
        comb += gates.eq(self.partition_points.as_sig())

        keys = list(self.partition_points.keys()) + [self.width]

        # break out both inputs into partition-stratified blocks;
        # intervals[i] holds the [start, end) bit range of block i
        a_intervals = []  # reconstructed
        b_intervals = []  # reconstructed
        intervals = []  # reconstructed
        start = 0  # reconstructed
        for end in keys:
            a_intervals.append(self.a[start:end])
            b_intervals.append(self.b[start:end])
            intervals.append([start, end])
            start = end  # reconstructed

        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
        max_bits = math.ceil(math.log2(width))

        # shifts are normally done as (e.g. for 32 bit)
        # result = a << (b & 0b11111), truncating the b input.  however
        # here of course the size of the partition varies dynamically.
        #
        # bits[k] is set when gates i..i+k are all clear, so the mask for
        # block i is min_bits ones followed by one extra one per block
        # that is currently merged above it, limited to max_bits.
        shifter_masks = []  # reconstructed
        for i in range(len(b_intervals)):
            mask = Signal(b_intervals[i].shape(), name="shift_mask%d" % i,
                          reset_less=True)  # reset_less reconstructed
            bits = Signal(gates.width-i+1, name="bits%d" % i,
                          reset_less=True)
            bl = []  # reconstructed
            for j in range(i, gates.width):
                # the first term has no predecessor in `bits` to chain
                # from; without this guard it would read bits[-1]
                if bl:  # reconstructed
                    bl.append(~gates[j] & bits[j-i-1])
                else:  # reconstructed
                    bl.append(~gates[j])  # reconstructed
            comb += bits.eq(Cat(*bl))
            comb += mask.eq(Cat((1 << min_bits)-1, bits)
                            & ((1 << max_bits)-1))
            shifter_masks.append(mask)

        # Instead of generating the matrix described in the wiki, I
        # instead calculate the shift amounts for each partition, then
        # calculate the partial results of each partition << shift
        # amount. On the wiki, the following table is given for output #3:
        #
        # 0 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
        # 0 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
        # 0 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
        # 0 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
        # 1 0 0 | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
        # 1 0 1 | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
        # 1 1 0 | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
        # 1 1 1 | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
        #
        # Each output for o3 is given by a3bx and the partial results
        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
        # those partial results to calculate a0, a1, a2, and a3
        element = b_intervals[0] & shifter_masks[0]
        partial_results = []  # reconstructed
        partial = Signal(width, name="partial0", reset_less=True)
        comb += partial.eq(a_intervals[0] << element)
        partial_results.append(partial)
        for i in range(1, len(keys)):
            reswid = width - intervals[i][0]
            shiftbits = math.ceil(math.log2(reswid+1))+1  # hmmm...
            print("partial", reswid, width, intervals[i], shiftbits)

            masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
                            reset_less=True)  # reset_less reconstructed
            comb += masked.eq(b_intervals[i] & shifter_masks[i])

            # This calculates which partition of b to select the
            # shifter from. According to the table above, the
            # partition to select is given by the highest set bit in
            # the partition mask; the running `element` Mux chain picks
            # this block's own masked b when gate i-1 is set, otherwise
            # it keeps the amount selected for the block below.
            element = Mux(gates[i-1], masked, element)
            elmux = Signal(b_intervals[i].shape(), name="elmux%d" % i,
                           reset_less=True)  # reset_less reconstructed
            comb += elmux.eq(element)
            element = elmux  # reconstructed

            # This computes the partial results table
            shifter = Signal(shiftbits, name="shifter%d" % i,
                             reset_less=True)  # reset_less reconstructed
            # `shifter` is shiftbits wide; the shift amount is not
            # clamped to the partial-result width before the assignment.
            comb += shifter.eq(element)
            partial = Signal(reswid, name="partial%d" % i, reset_less=True)
            comb += partial.eq(a_intervals[i] << shifter)

            partial_results.append(partial)

        out = []  # reconstructed

        # This calculates the outputs o0-o3 from the partial results
        s, e = intervals[0]  # reconstructed
        result = partial_results[0]
        out.append(result[s:e])
        for i in range(1, len(keys)):
            start, end = (intervals[i][0], width)
            # OR in the bits shifted out of the blocks below, unless
            # gate i-1 is set (then the lower blocks are a separate lane)
            result = partial_results[i] | \
                Mux(gates[i-1], 0, result[intervals[0][1]:])[:end-start]
            print("select: [%d:%d]" % (start, end))
            res = Signal(width, name="res%d" % i, reset_less=True)
            comb += res.eq(result)
            result = res  # reconstructed
            s, e = intervals[0]  # reconstructed
            out.append(res[s:e])  # reconstructed

        comb += self.output.eq(Cat(*out))

        return m  # reconstructed