make all signals resetless
[ieee754fpu.git] / src / ieee754 / part_shift / part_shift_dynamic.py
index 97be1118e016164896d5f35927b36deb9f32e560..5dd72ed792f9bc9c58526eaf2b0ebc4d37c7eaa0 100644 (file)
@@ -13,7 +13,7 @@ See:
 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
 """
-from nmigen import Signal, Module, Elaboratable, Cat, Mux
+from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
 from ieee754.part_mul_add.partpoints import PartitionPoints
 import math
 
@@ -23,81 +23,120 @@ class PartitionedDynamicShift(Elaboratable):
         self.width = width
         self.partition_points = PartitionPoints(partition_points)
 
-        self.a = Signal(width)
-        self.b = Signal(width)
-        self.output = Signal(width)
+        self.a = Signal(width, reset_less=True)
+        self.b = Signal(width, reset_less=True)
+        self.output = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
         m = Module()
         comb = m.d.comb
         width = self.width
-        gates = Signal(self.partition_points.get_max_partition_count(width)-1)
+        pwid = self.partition_points.get_max_partition_count(width)-1
+        gates = Signal(pwid, reset_less=True)
         comb += gates.eq(self.partition_points.as_sig())
 
         matrix = []
         keys = list(self.partition_points.keys()) + [self.width]
         start = 0
 
-        # create a matrix of partial shift-results (similar to PartitionedMul
-        # matrices).  These however have to be of length suitable to contain
-        # the full shifted "contribution".  i.e. B from the LSB *could* contain
-        # a number great enough to shift the entirety of A LSB right up to
-        # the MSB of the output, however B from the *MSB* is *only* going
-        # to contribute to the *MSB* of the output.
-        for i in range(len(keys)):
-            row = []
-            start = 0
-            for j in range(len(keys)):
-                end = keys[j]
-                row.append(Signal(width - start,
-                           name="matrix[%d][%d]" % (i, j)))
-            matrix.append(row)
-
+        # break out both the input and output into partition-stratified blocks
         a_intervals = []
         b_intervals = []
-        out_intervals = []
         intervals = []
+        widths = []
         start = 0
         for i in range(len(keys)):
             end = keys[i]
+            widths.append(width - start)
             a_intervals.append(self.a[start:end])
             b_intervals.append(self.b[start:end])
-            out_intervals.append(self.output[start:end])
             intervals.append([start,end])
             start = end
 
-        # actually calculate the shift-partials here
-        for i, b in enumerate(b_intervals):
-            start = 0
-            for j, a in enumerate(a_intervals):
-                end = keys[i]
-                result_width = matrix[i][j].width
-                bwidth = math.ceil(math.log2(result_width + 1))
-                comb += matrix[i][j].eq(a << b[:bwidth])
-                start = end
-
-        # now create a switch statement which sums the relevant partial results
-        # in each output-partition
-
-        intermed = matrix[0][0]
-        comb += out_intervals[0].eq(intermed)
-        for i in range(1, len(out_intervals)):
-            index = gates[:i]  # selects the 'i' least significant bits
-                               # of gates
-            element = Signal(width, name="element%d" % i)
-            for index in range(1<<i):
-                print(index)
-                with m.Switch(gates[:i]):
-                    with m.Case(index):
-                        index = math.ceil(math.log2(index + 1))
-                        comb += element.eq(matrix[index][i])
-            print(keys[i-1])
-            temp = Signal(width, name="intermed%d" % i)
-            print(intermed[keys[0]:])
-            intermed = Mux(gates[i-1], element, element | intermed[keys[0]:])
-            comb += temp.eq(intermed)
-            comb += out_intervals[i].eq(intermed)
-
+        min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
+        max_bits = math.ceil(math.log2(width))
+
+        # shifts are normally done as (e.g. 32 bit) result = a << (b&0b11111)
+        # truncating the b input.  however here of course the size of the
+        # partition varies dynamically.
+        shifter_masks = []
+        for i in range(len(b_intervals)):
+            mask = Signal(b_intervals[i].shape(), name="shift_mask%d" % i,
+                          reset_less=True)
+            bits = []
+            for j in range(i, gates.width):
+                if bits:
+                    bits.append(~gates[j] & bits[-1])
+                else:
+                    bits.append(~gates[j])
+            comb += mask.eq(Cat((1 << min_bits)-1, bits)
+                            & ((1 << max_bits)-1))
+            shifter_masks.append(mask)
+
+        print(shifter_masks)
+
+        # Rather than generating the matrix described in the wiki, I
+        # calculate the shift amounts for each partition, then
+        # calculate the partial results of each partition << shift
+        # amount. On the wiki, the following table is given for output #3:
+        # p2p1p0 | o3
+        # 0 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b0[7:0]
+        # 0 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b1[7:0]
+        # 0 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b2[7:0]
+        # 0 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b2[7:0]
+        # 1 0 0  | a0b0[31:24] | a1b0[23:16] | a2b0[15:8] | a3b3[7:0]
+        # 1 0 1  | a0b0[31:24] | a1b1[23:16] | a2b1[15:8] | a3b3[7:0]
+        # 1 1 0  | a0b0[31:24] | a1b0[23:16] | a2b2[15:8] | a3b3[7:0]
+        # 1 1 1  | a0b0[31:24] | a1b1[23:16] | a2b2[15:8] | a3b3[7:0]
+
+        # Each output for o3 is given by a3bx and the partial results
+        # for o2 (namely, a2bx, a1bx, and a0b0). If I calculate the
+        # partial results [a0b0, a1bx, a2bx, a3bx], I can use just
+        # those partial results to calculate o0, o1, o2, and o3
+        shiftbits = math.ceil(math.log2(width))
+        element = b_intervals[0] & shifter_masks[0]
+        partial_results = []
+        partial_results.append(a_intervals[0] << element)
+        for i in range(1, len(keys)):
+            s, e = intervals[i]
+            masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
+                          reset_less=True)
+            comb += masked.eq(b_intervals[i] & shifter_masks[i])
+            element = Mux(gates[i-1], masked, element)
+
+            # The Mux above selects which partition of b to take the
+            # shifter from. According to the table above, the
+            # partition to select is given by the highest set bit in
+            # the partition mask; chaining the Mux through each loop
+            # iteration calculates that.
+
+            # This computes the partial results table
+            shifter = Signal(shiftbits, name="shifter%d" % i,
+                          reset_less=True)
+            comb += shifter.eq(element)
+            partial = Signal(width, name="partial%d" % i, reset_less=True)
+            comb += partial.eq(a_intervals[i] << shifter)
+
+            partial_results.append(partial)
+
+        out = []
+
+        # This calculates the outputs o0-o3 from the partial results
+        # table above.
+        s,e = intervals[0]
+        result = partial_results[0]
+        out.append(result[s:e])
+        for i in range(1, len(keys)):
+            start, end = (intervals[i][0], width)
+            result = partial_results[i] | \
+                Mux(gates[i-1], 0, result[intervals[0][1]:])[:end-start]
+            print("select: [%d:%d]" % (start, end))
+            res = Signal(width, name="res%d" % i, reset_less=True)
+            comb += res.eq(result)
+            s,e = intervals[0]
+            out.append(res[s:e])
+
+        comb += self.output.eq(Cat(*out))
 
         return m