Rudimentary working partitioned scalar shifter
authorMichael Nolan <mtnolan2640@gmail.com>
Mon, 10 Feb 2020 20:08:52 +0000 (15:08 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Tue, 11 Feb 2020 18:15:53 +0000 (13:15 -0500)
src/ieee754/part_shift_scalar/formal/proof_shift_scalar.py
src/ieee754/part_shift_scalar/part_shift_scalar.py

index 780c48b726f62943088777e7e36ee0225c772a81..11c8648736a88ae0eabd9966459d0cbf87e3c70a 100644 (file)
@@ -64,9 +64,20 @@ class ShifterDriver(Elaboratable):
 
         with m.Switch(points.as_sig()):
             with m.Case(0b00):
+                comb += Assert(out[0:24] == (data[0:24] << shifter) & 0xffffff)
+
+            with m.Case(0b01):
                 comb += Assert(out[0:8] == expected[0:8])
-                comb += Assert(out[8:16] == expected[8:16])
+                comb += Assert(out[8:24] == (data[8:24] << shifter) & 0xffff)
+
+            with m.Case(0b10):
+                comb += Assert(out[16:24] == (data[16:24] << shifter) & 0xff)
+                comb += Assert(out[0:16] == (data[0:16] << shifter) & 0xffff)
 
+            with m.Case(0b11):
+                comb += Assert(out[0:8] == expected[0:8])
+                comb += Assert(out[8:16] == (data[8:16] << shifter) & 0xff)
+                comb += Assert(out[16:24] == (data[16:24] << shifter) & 0xff)
         
         return m
 
index 13ae5642d25000f27d984be4ab9b773d86a472a2..87908c5c5afc16de27530aae692c9398c7fd1ebe 100644 (file)
@@ -13,7 +13,7 @@ See:
 * http://libre-riscv.org/3d_gpu/architecture/dynamic_simd/shift/
 * http://bugs.libre-riscv.org/show_bug.cgi?id=173
 """
-from nmigen import Signal, Module, Elaboratable, Cat, C
+from nmigen import Signal, Module, Elaboratable, Cat, Mux
 from ieee754.part_mul_add.partpoints import PartitionPoints
 import math
 
@@ -33,13 +33,35 @@ class PartitionedScalarShift(Elaboratable):
         comb = m.d.comb
         width = self.width
         shiftbits = self.shiftbits
-
         shifted = Signal(self.data.width)
+        gates = self.partition_points.as_sig()
         comb += shifted.eq(self.data << self.shifter)
 
-        comb += self.output[0:8].eq(shifted[0:8])
-        comb += self.output[8:16].eq(shifted[8:16])
+        parts = []
+        outputs = []
+        shiftparts = []
+        intervals = []
+        keys = list(self.partition_points.keys()) + [self.width]
+        start = 0
+        for i in range(len(keys)):
+            end = keys[i]
+            parts.append(self.data[start:end])
+            outputs.append(self.output[start:end])
+            intervals.append((start,end))
+
+            sp = Signal(width)
+            comb += sp[start:].eq(self.data[start:end] << self.shifter)
+            shiftparts.append(sp)
+            
+            start = end  # for next time round loop
+
+        for i, interval in enumerate(intervals):
+            start, end = interval
+            if i == 0:
+                intermed = shiftparts[i]
+            else:
+                intermed = shiftparts[i] | Mux(gates[i-1], 0, prev)
+            comb += outputs[i].eq(intermed[start:end])
+            prev = intermed
 
         return m
-        
-