Add bit reversal to part_shift_dynamic
authorMichael Nolan <mtnolan2640@gmail.com>
Wed, 26 Feb 2020 15:56:50 +0000 (10:56 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Wed, 26 Feb 2020 15:56:50 +0000 (10:56 -0500)
Shift Right not working yet

src/ieee754/part_shift/formal/proof_shift_dynamic.py
src/ieee754/part_shift/part_shift_dynamic.py

index a836771c2262ebd0bf24f2d720561a5367257708..ffb93ece255cded7558867640daa874aed992956 100644 (file)
@@ -38,6 +38,7 @@ class ShifterDriver(Elaboratable):
         # setup the inputs and outputs of the DUT as anyconst
         a = Signal(width)
         b = Signal(width)
+        bitrev = Signal()
         out = Signal(width)
         points = PartitionPoints()
         gates = Signal(mwidth-1)
@@ -48,6 +49,7 @@ class ShifterDriver(Elaboratable):
 
         comb += [a.eq(AnyConst(width)),
                  b.eq(AnyConst(width)),
+                 bitrev.eq(AnyConst(1)),
                  gates.eq(AnyConst(mwidth-1))]
 
         m.submodules.dut = dut = PartitionedDynamicShift(width, points)
@@ -58,60 +60,71 @@ class ShifterDriver(Elaboratable):
 
         comb += [dut.a.eq(a),
                  dut.b.eq(b),
+                 dut.bitrev.eq(bitrev),
                  out.eq(dut.output)]
 
 
-        with m.Switch(points.as_sig()):
-            with m.Case(0b000):
-                comb += Assert(out == (a<<b[0:5]) & 0xffffffff)
-            with m.Case(0b001):
-                comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
-                comb += Assert(Cat(out_intervals[1:4]) ==
-                               (Cat(a_intervals[1:4])
-                                << b_intervals[1][0:5]) & 0xffffff)
-            with m.Case(0b010):
-                comb += Assert(Cat(out_intervals[0:2]) ==
-                               (Cat(a_intervals[0:2])
-                                << (b_intervals[0] & 0xf)) & 0xffff)
-                comb += Assert(Cat(out_intervals[2:4]) ==
-                               (Cat(a_intervals[2:4])
-                                << (b_intervals[2] & 0xf)) & 0xffff)
-            with m.Case(0b011):
-                comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
-                comb += Assert(out_intervals[1] ==
-                               (a_intervals[1] << b_intervals[1][0:3]) & 0xff)
-                comb += Assert(Cat(out_intervals[2:4]) ==
-                               (Cat(a_intervals[2:4])
-                                << b_intervals[2][0:4]) & 0xffff)
-            with m.Case(0b100):
-                comb += Assert(Cat(out_intervals[0:3]) ==
-                               (Cat(a_intervals[0:3])
-                                << b_intervals[0][0:5]) & 0xffffff)
-                comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
-            with m.Case(0b101):
-                comb += Assert(out_intervals[0] ==
-                               (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
-                comb += Assert(Cat(out_intervals[1:3]) ==
-                               (Cat(a_intervals[1:3])
-                                << b_intervals[1][0:4]) & 0xffff)
-                comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
-            with m.Case(0b110):
-                comb += Assert(Cat(out_intervals[0:2]) ==
-                               (Cat(a_intervals[0:2])
-                                << b_intervals[0][0:4]) & 0xffff)
-                comb += Assert(out_intervals[2] ==
-                               (a_intervals[2] << b_intervals[2][0:3]) & 0xff)
-                comb += Assert(out_intervals[3] ==
-                               (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
-            with m.Case(0b111):
-                for i, o in enumerate(out_intervals):
-                    comb += Assert(o ==
-                                   (a_intervals[i] << b_intervals[i][0:3])
-                                   & 0xff)
+        with m.If(bitrev == 0):
+            with m.Switch(points.as_sig()):
+                with m.Case(0b000):
+                    comb += Assert(out == (a<<b[0:5]) & 0xffffffff)
+                with m.Case(0b001):
+                    comb += Assert(out_intervals[0] ==
+                                (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
+                    comb += Assert(Cat(out_intervals[1:4]) ==
+                                (Cat(a_intervals[1:4])
+                                    << b_intervals[1][0:5]) & 0xffffff)
+                with m.Case(0b010):
+                    comb += Assert(Cat(out_intervals[0:2]) ==
+                                (Cat(a_intervals[0:2])
+                                    << (b_intervals[0] & 0xf)) & 0xffff)
+                    comb += Assert(Cat(out_intervals[2:4]) ==
+                                (Cat(a_intervals[2:4])
+                                    << (b_intervals[2] & 0xf)) & 0xffff)
+                with m.Case(0b011):
+                    comb += Assert(out_intervals[0] ==
+                                (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
+                    comb += Assert(out_intervals[1] ==
+                                (a_intervals[1] << b_intervals[1][0:3]) & 0xff)
+                    comb += Assert(Cat(out_intervals[2:4]) ==
+                                (Cat(a_intervals[2:4])
+                                    << b_intervals[2][0:4]) & 0xffff)
+                with m.Case(0b100):
+                    comb += Assert(Cat(out_intervals[0:3]) ==
+                                (Cat(a_intervals[0:3])
+                                    << b_intervals[0][0:5]) & 0xffffff)
+                    comb += Assert(out_intervals[3] ==
+                                (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
+                with m.Case(0b101):
+                    comb += Assert(out_intervals[0] ==
+                                (a_intervals[0] << b_intervals[0][0:3]) & 0xff)
+                    comb += Assert(Cat(out_intervals[1:3]) ==
+                                (Cat(a_intervals[1:3])
+                                    << b_intervals[1][0:4]) & 0xffff)
+                    comb += Assert(out_intervals[3] ==
+                                (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
+                with m.Case(0b110):
+                    comb += Assert(Cat(out_intervals[0:2]) ==
+                                (Cat(a_intervals[0:2])
+                                    << b_intervals[0][0:4]) & 0xffff)
+                    comb += Assert(out_intervals[2] ==
+                                (a_intervals[2] << b_intervals[2][0:3]) & 0xff)
+                    comb += Assert(out_intervals[3] ==
+                                (a_intervals[3] << b_intervals[3][0:3]) & 0xff)
+                with m.Case(0b111):
+                    for i, o in enumerate(out_intervals):
+                        comb += Assert(o ==
+                                    (a_intervals[i] << b_intervals[i][0:3])
+                                    & 0xff)
+        with m.Else():
+            with m.Switch(points.as_sig()):
+                with m.Case(0b000):
+                    comb += Assert(out == (a>>b[0:5]) & 0xffffffff)
+                with m.Case(0b111):
+                    for i, o in enumerate(out_intervals):
+                        comb += Assert(o ==
+                                    (a_intervals[i] >> b_intervals[i][0:3])
+                                    & 0xff)
 
         return m
 
index e5ae557606e62cf802e04ee8a0fa913d9db4227b..230455092914d4a7ce82aca275463cb0caf94a5e 100644 (file)
@@ -15,6 +15,7 @@ See:
 """
 from nmigen import Signal, Module, Elaboratable, Cat, Mux, C
 from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_shift.bitrev import GatedBitReverse
 import math
 
 class ShifterMask(Elaboratable):
@@ -107,6 +108,7 @@ class PartitionedDynamicShift(Elaboratable):
 
         self.a = Signal(width, reset_less=True)
         self.b = Signal(width, reset_less=True)
+        self.bitrev = Signal(reset_less=True)
         self.output = Signal(width, reset_less=True)
 
     def elaborate(self, platform):
@@ -121,6 +123,19 @@ class PartitionedDynamicShift(Elaboratable):
         keys = list(self.partition_points.keys()) + [self.width]
         start = 0
 
+        m.submodules.a_br = a_br = GatedBitReverse(self.a.width)
+        comb += a_br.data.eq(self.a)
+        comb += a_br.reverse_en.eq(self.bitrev)
+
+        m.submodules.out_br = out_br = GatedBitReverse(self.output.width)
+        comb += out_br.reverse_en.eq(self.bitrev)
+        comb += self.output.eq(out_br.output)
+
+        m.submodules.gate_br = gate_br = GatedBitReverse(pwid)
+        comb += gate_br.data.eq(gates)
+        comb += gate_br.reverse_en.eq(self.bitrev)
+
+
         # break out both the input and output into partition-stratified blocks
         a_intervals = []
         b_intervals = []
@@ -130,16 +145,16 @@ class PartitionedDynamicShift(Elaboratable):
         for i in range(len(keys)):
             end = keys[i]
             widths.append(width - start)
-            a_intervals.append(self.a[start:end])
+            a_intervals.append(a_br.output[start:end])
             b_intervals.append(self.b[start:end])
             intervals.append([start,end])
             start = end
 
         min_bits = math.ceil(math.log2(intervals[0][1] - intervals[0][0]))
 
-        # shifts are normally done as (e.g. for 32 bit) result = a & (b&0b11111)
-        # truncating the b input.  however here of course the size of the
-        # partition varies dynamically.
+        # shifts are normally done as (e.g. for 32 bit) result = a &
+        # (b&0b11111) truncating the b input.  however here of course
+        # the size of the partition varies dynamically.
         shifter_masks = []
         for i in range(len(b_intervals)):
             bwid = b_intervals[i].shape()[0]
@@ -151,7 +166,7 @@ class PartitionedDynamicShift(Elaboratable):
             sm = ShifterMask(bitwid, bwid, max_bits, min_bits)
             setattr(m.submodules, "sm%d" % i, sm)
             if bitwid != 0:
-                comb += sm.gates.eq(gates[i:pwid])
+                comb += sm.gates.eq(gate_br.output[i:pwid])
             shifter_masks.append(sm.mask)
 
         print(shifter_masks)
@@ -190,7 +205,7 @@ class PartitionedDynamicShift(Elaboratable):
             masked = Signal(b_intervals[i].shape(), name="masked%d" % i,
                           reset_less=True)
             comb += pr.masked.eq(b_intervals[i] & shifter_masks[i])
-            comb += pr.gate.eq(gates[i-1])
+            comb += pr.gate.eq(gate_br.output[i-1])
             comb += pr.element.eq(element)
             comb += pr.a_interval.eq(a_intervals[i])
             partial_results.append(pr.partial)
@@ -207,7 +222,8 @@ class PartitionedDynamicShift(Elaboratable):
         for i in range(1, len(keys)):
             start, end = (intervals[i][0], width)
             reswid = width - start
-            sel = Mux(gates[i-1], 0, result[intervals[0][1]:][:end-start])
+            sel = Mux(gate_br.output[i-1], 0,
+                      result[intervals[0][1]:][:end-start])
             print("select: [%d:%d]" % (start, end))
             res = Signal(end-start+1, name="res%d" % i, reset_less=True)
             comb += res.eq(partial_results[i] | sel)
@@ -215,7 +231,7 @@ class PartitionedDynamicShift(Elaboratable):
             s,e = intervals[0]
             out.append(res[s:e])
 
-        comb += self.output.eq(Cat(*out))
+        comb += out_br.data.eq(Cat(*out))
 
         return m