unit test for multi-bit shift right with merge (sticky bit)
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 3 Mar 2019 23:13:51 +0000 (23:13 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 3 Mar 2019 23:13:51 +0000 (23:13 +0000)
src/add/fpbase.py
src/add/test_multishift.py

index 58b4994a4de4ea8904156dca641e41bb6e0f3d27..d6a059e0339b0ea4c7464a2274a1e474cecb1535 100644 (file)
@@ -184,30 +184,19 @@ class FPNumOut(FPNumBase):
         return self.create(s, self.N127, 0)
 
 
-class FPNumShiftMultiRight(FPNumBase):
-    """ shifts a mantissa down. exponent is increased to compensate
-
-        accuracy is lost as a result in the mantissa however there are 3
-        guard bits (the latter of which is the "sticky" bit)
-
-        this code works by variable-shifting the mantissa by up to
-        its maximum bit-length: no point doing more (it'll still be
-        zero).
-
-        the sticky bit is computed by shifting a batch of 1s by
-        the same amount, which will introduce zeros.  it's then
-        inverted and used as a mask to get the LSBs of the mantissa.
-        those are then |'d into the sticky bit.
+class MultiShiftRMerge:
+    """ shifts down (right) and merges lower bits into m[0].
+        m[0] is the "sticky" bit, basically
     """
-    def __init__(self, inp, diff, width):
+    def __init__(self, width):
+        self.smax = int(log(width) / log(2))
         self.m = Signal(width, reset_less=True)
-        self.inp = inp
-        self.diff = diff
+        self.inp = Signal(width, reset_less=True)
+        self.diff = Signal(self.smax, reset_less=True)
         self.width = width
 
     def elaborate(self, platform):
         m = Module()
-        #m.submodules.inp = self.inp
 
         rs = Signal(self.width, reset_less=True)
         m_mask = Signal(self.width, reset_less=True)
@@ -215,18 +204,17 @@ class FPNumShiftMultiRight(FPNumBase):
         stickybit = Signal(reset_less=True)
 
         sm = MultiShift(self.width-1)
+        m0s = Const(0, self.width-1)
         mw = Const(self.width-1, len(self.diff))
         maxslen = Mux(self.diff > mw, mw, self.diff)
         maxsleni = mw - maxslen
         m.d.comb += [
                 # shift mantissa by maxslen, mask by inverse
-                rs.eq(sm.rshift(self.inp.m[1:], maxslen)),
-                m_mask.eq(sm.rshift(self.inp.m1s[1:], maxsleni)),
-                smask.eq(self.inp.m[1:] & m_mask),
+                rs.eq(sm.rshift(self.inp[1:], maxslen)),
+                m_mask.eq(sm.rshift(~m0s, maxsleni)),
+                smask.eq(self.inp[1:] & m_mask),
                 # sticky bit combines all mask (and mantissa low bit)
-                stickybit.eq(smask.bool() | self.inp.m[0]),
-                #self.s.eq(self.inp.s),
-                #self.e.eq(self.inp.e + diff),
+                stickybit.eq(smask.bool() | self.inp[0]),
                 # mantissa result contains m[0] already.
                 self.m.eq(Cat(stickybit, rs))
            ]
index 5fa649ef83f79dcbce7d3338da418d02dd9b9b2d..2aa6ba330dc28a933fad01951146d441a88c0bce 100644 (file)
@@ -2,7 +2,7 @@ from random import randint
 from nmigen import Module, Signal
 from nmigen.compat.sim import run_simulation
 
-from fpbase import MultiShift, MultiShiftR
+from fpbase import MultiShift, MultiShiftR, MultiShiftRMerge
 
 class MultiShiftModL:
     def __init__(self, width):
@@ -49,6 +49,24 @@ class MultiShiftModRMod:
 
         return m
 
+class MultiShiftRMergeMod:
+    def __init__(self, width):
+        self.ms = MultiShiftRMerge(width)
+        self.a = Signal(width)
+        self.b = Signal(self.ms.smax)
+        self.x = Signal(width)
+
+    def get_fragment(self, platform=None):
+
+        m = Module()
+        m.submodules += self.ms
+        m.d.comb += self.ms.inp.eq(self.a)
+        m.d.comb += self.ms.diff.eq(self.b)
+        m.d.comb += self.x.eq(self.ms.m)
+
+        return m
+
+
 def check_case(dut, width, a, b):
     yield dut.a.eq(a)
     yield dut.b.eq(b)
@@ -69,6 +87,27 @@ def check_caser(dut, width, a, b):
     out_x = yield dut.x
     assert out_x == x, "Output x 0x%x not equal to expected 0x%x" % (out_x, x)
 
+
+def check_case_merge(dut, width, a, b):
+    yield dut.a.eq(a)
+    yield dut.b.eq(b)
+    yield
+
+    x = (a >> b) & ((1<<width)-1) # actual shift
+    if (a & ((2<<b)-1)) != 0: # mask for sticky bit
+        x |= 1 # set LSB
+
+    out_x = yield dut.x
+    assert out_x == x, \
+                "\nshift %d\nInput\n%+32s\nOutput x\n%+32s != \n%+32s" % \
+                        (b, bin(a), bin(out_x), bin(x))
+
+def testmerge(dut):
+    for i in range(32):
+        for j in range(1000):
+            a = randint(0, (1<<32)-1)
+            yield from check_case_merge(dut, 32, a, i)
+
 def testbench(dut):
     for i in range(32):
         for j in range(1000):
@@ -82,6 +121,8 @@ def testbenchr(dut):
             yield from check_caser(dut, 32, a, i)
 
 if __name__ == '__main__':
+    dut = MultiShiftRMergeMod(width=32)
+    run_simulation(dut, testmerge(dut), vcd_name="test_multishiftmerge.vcd")
     dut = MultiShiftModRMod(width=32)
     run_simulation(dut, testbenchr(dut), vcd_name="test_multishift.vcd")