Add gated bit reversal module
authorMichael Nolan <mtnolan2640@gmail.com>
Mon, 24 Feb 2020 20:15:30 +0000 (15:15 -0500)
committerMichael Nolan <mtnolan2640@gmail.com>
Mon, 24 Feb 2020 20:15:30 +0000 (15:15 -0500)
src/ieee754/part_shift/bitrev.py [new file with mode: 0644]
src/ieee754/part_shift/formal/proof_shift_scalar.py
src/ieee754/part_shift/part_shift_scalar.py

diff --git a/src/ieee754/part_shift/bitrev.py b/src/ieee754/part_shift/bitrev.py
new file mode 100644 (file)
index 0000000..47523b1
--- /dev/null
@@ -0,0 +1,20 @@
+from nmigen import Signal, Module, Elaboratable, Cat, Mux
+
+class GatedBitReverse(Elaboratable):
+    def __init__(self, width):
+        self.width = width
+        self.data = Signal(width, reset_less=True)
+        self.reverse_en = Signal(reset_less=True)
+        self.output = Signal(width, reset_less=True)
+    def elaborate(self, platform):
+        m = Module()
+        comb = m.d.comb
+        width = self.width
+
+        for i in range(width):
+            with m.If(self.reverse_en):
+                comb += self.output[i].eq(self.data[width-i-1])
+            with m.Else():
+                comb += self.output[i].eq(self.data[i])
+
+        return m
index b150e25b716c04859a07fa09c0c16aa9477c651d..f76d063241319d01a3a241c809de7fec57e95005 100644 (file)
@@ -41,12 +41,14 @@ class ShifterDriver(Elaboratable):
         shifter = Signal(shifterwidth)
         points = PartitionPoints()
         gates = Signal(mwidth-1)
+        bitrev = Signal()
         step = int(width/mwidth)
         for i in range(mwidth-1):
             points[(i+1)*step] = gates[i]
         print(points)
 
         comb += [data.eq(AnyConst(width)),
+                 bitrev.eq(AnyConst(1)),
                  shifter.eq(AnyConst(shifterwidth)),
                  gates.eq(AnyConst(mwidth-1))]
 
@@ -57,34 +59,37 @@ class ShifterDriver(Elaboratable):
 
         comb += [dut.data.eq(data),
                  dut.shifter.eq(shifter),
+                 dut.bitrev.eq(bitrev),
                  out.eq(dut.output)]
 
         expected = Signal(width)
 
-        with m.Switch(points.as_sig()):
-            with m.Case(0b00):
-                comb += Assert(
-                    out[0:24] == (data[0:24] << (shifter & 0x1f)) & 0xffffff)
-
-            with m.Case(0b01):
-                comb += Assert(out[0:8] ==
-                               (data[0:8] << (shifter & 0x7)) & 0xFF)
-                comb += Assert(out[8:24] ==
-                               (data[8:24] << (shifter & 0xf)) & 0xffff)
-
-            with m.Case(0b10):
-                comb += Assert(out[16:24] ==
-                               (data[16:24] << (shifter & 0x7)) & 0xff)
-                comb += Assert(out[0:16] ==
-                               (data[0:16] << (shifter & 0xf)) & 0xffff)
-
-            with m.Case(0b11):
-                comb += Assert(out[0:8] ==
-                               (data[0:8] << (shifter & 0x7)) & 0xFF)
-                comb += Assert(out[8:16] ==
-                               (data[8:16] << (shifter & 0x7)) & 0xff)
-                comb += Assert(out[16:24] ==
-                               (data[16:24] << (shifter & 0x7)) & 0xff)
+        with m.If(bitrev == 0):
+            with m.Switch(points.as_sig()):
+                with m.Case(0b00):
+                    comb += Assert(
+                        out[0:24] == (data[0:24] << (shifter & 0x1f)) &
+                        0xffffff)
+
+                with m.Case(0b01):
+                    comb += Assert(out[0:8] ==
+                                (data[0:8] << (shifter & 0x7)) & 0xFF)
+                    comb += Assert(out[8:24] ==
+                                (data[8:24] << (shifter & 0xf)) & 0xffff)
+
+                with m.Case(0b10):
+                    comb += Assert(out[16:24] ==
+                                (data[16:24] << (shifter & 0x7)) & 0xff)
+                    comb += Assert(out[0:16] ==
+                                (data[0:16] << (shifter & 0xf)) & 0xffff)
+
+                with m.Case(0b11):
+                    comb += Assert(out[0:8] ==
+                                (data[0:8] << (shifter & 0x7)) & 0xFF)
+                    comb += Assert(out[8:16] ==
+                                (data[8:16] << (shifter & 0x7)) & 0xff)
+                    comb += Assert(out[16:24] ==
+                                (data[16:24] << (shifter & 0x7)) & 0xff)
         return m
 
 class PartitionedScalarShiftTestCase(FHDLTestCase):
index 150e96b583e422a94f74aa80a79c8d94e35cd978..2a14fb5da608fef3fce2e39760fb935ede0f9f19 100644 (file)
@@ -15,6 +15,7 @@ See:
 from nmigen import Signal, Module, Elaboratable, Cat, Mux
 from ieee754.part_mul_add.partpoints import PartitionPoints
 from ieee754.part_shift.part_shift_dynamic import ShifterMask
+from ieee754.part_shift.bitrev import GatedBitReverse
 import math
 
 
@@ -27,6 +28,8 @@ class PartitionedScalarShift(Elaboratable):
         self.shiftbits = math.ceil(math.log2(width))
         self.shifter = Signal(self.shiftbits, reset_less=True)
         self.output = Signal(width, reset_less=True)
+        self.bitrev = Signal(reset_less=True) # Whether to bit-reverse the
+                                              # input and output
 
     def elaborate(self, platform):
         m = Module()
@@ -34,20 +37,26 @@ class PartitionedScalarShift(Elaboratable):
         width = self.width
         pwid = self.partition_points.get_max_partition_count(width)-1
         shiftbits = self.shiftbits
-        shifted = Signal(self.data.width, reset_less=True)
         gates = self.partition_points.as_sig()
-        comb += shifted.eq(self.data << self.shifter)
 
         parts = []
         outputs = []
         shiftparts = []
         intervals = []
         keys = list(self.partition_points.keys()) + [self.width]
+
+        m.submodules.in_br = in_br = GatedBitReverse(self.data.width)
+        comb += in_br.data.eq(self.data)
+        comb += in_br.reverse_en.eq(self.bitrev)
+
+        m.submodules.out_br = out_br = GatedBitReverse(self.data.width)
+        comb += out_br.reverse_en.eq(self.bitrev)
+        comb += self.output.eq(out_br.output)
         start = 0
         for i in range(len(keys)):
             end = keys[i]
-            parts.append(self.data[start:end])
-            outputs.append(self.output[start:end])
+            parts.append(in_br.output[start:end])
+            outputs.append(out_br.data[start:end])
             intervals.append((start,end))
             start = end  # for next time round loop
 
@@ -81,7 +90,7 @@ class PartitionedScalarShift(Elaboratable):
             _shifter = Signal(self.shifter.width, name="shifter%d" % i,
                               reset_less=True)
             comb += _shifter.eq(self.shifter & shifter_masks[i])
-            comb += sp[s:].eq(self.data[s:e] << _shifter)
+            comb += sp[s:].eq(in_br.output[s:e] << _shifter)
             shiftparts.append(sp)