add a variable-length single-cycle shift_down of mantissa, and unit test
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 17 Feb 2019 13:07:06 +0000 (13:07 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Sun, 17 Feb 2019 13:07:06 +0000 (13:07 +0000)
src/add/fpbase.py
src/add/test_fpnum.py [new file with mode: 0644]

index 8c0b54b0800ac9fcbddc67d95361955fbe297c36..cc5184df63a7910b35c6e8039bc5629f8f4b9dcf 100644 (file)
@@ -4,6 +4,8 @@
 
 from nmigen import Signal, Cat, Const, Mux
 from math import log
+from operator import or_
+from functools import reduce
 
 class MultiShift:
     """ Generates variable-length single-cycle shifter from a series
@@ -53,6 +55,7 @@ class FPNum:
         e_width = {32: 10, 64: 13}[width]
         e_max = 1<<(e_width-3)
         self.rmw = m_width # real mantissa width (not including extras)
+        self.e_max = e_max
         if m_extra:
             # mantissa extra bits (top,guard,round)
             self.m_extra = 3
@@ -85,7 +88,7 @@ class FPNum:
             a 10-bit number
         """
         args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
-        print (self.e_end)
+        print ("decode", self.e_end)
         return [self.m.eq(Cat(*args)), # mantissa
                 self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
                 self.s.eq(v[-1]),                 # sign
@@ -112,6 +115,33 @@ class FPNum:
                 self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0))
                ]
 
+    def shift_down_multi(self, diff):
+        """ shifts a mantissa down. exponent is increased to compensate
+
+            accuracy is lost as a result in the mantissa however there are 3
+            guard bits (the latter of which is the "sticky" bit)
+
+            this code works by variable-shifting the mantissa by up to
+            its maximum bit-length: no point doing more (it'll still be
+            zero).
+
+            the sticky bit is computed by shifting a batch of 1s by
+            the same amount, which will introduce zeros.  it's then
+            inverted and used as a mask to get the LSBs of the mantissa.
+            those are then |'d into the sticky bit.
+        """
+        sm = MultiShift(self.width)
+        mw = Const(self.m_width-1, len(diff))
+        maxslen = Mux(diff > mw, mw, diff)
+        rs = sm.rshift(self.m[1:], maxslen)
+        maxsleni = mw - maxslen
+        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
+
+        stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
+        return [self.e.eq(self.e + diff),
+                self.m.eq(Cat(stickybits, rs))
+               ]
+
     def nan(self, s):
         return self.create(s, self.P128, 1<<(self.e_start-1))
 
diff --git a/src/add/test_fpnum.py b/src/add/test_fpnum.py
new file mode 100644 (file)
index 0000000..2003658
--- /dev/null
@@ -0,0 +1,60 @@
+from random import randint
+from nmigen import Module, Signal
+from nmigen.compat.sim import run_simulation
+
+from fpbase import FPNum
+
+class FPNumModShiftMulti:
+    def __init__(self, width):
+        self.a = FPNum(width)
+        self.ediff = Signal((self.a.e_width, True))
+
+    def get_fragment(self, platform=None):
+
+        m = Module()
+        #m.d.sync += self.a.decode(self.a.v)
+        m.d.sync += self.a.shift_down_multi(self.ediff)
+
+        return m
+
+def check_case(dut, width, e_width, m, e, i):
+    yield dut.a.m.eq(m)
+    yield dut.a.e.eq(e)
+    yield dut.ediff.eq(i)
+    yield
+    yield
+
+    out_m = yield dut.a.m
+    out_e = yield dut.a.e
+    ed = yield dut.ediff
+    calc_e = (e + i) 
+    print (e, bin(m), out_e, calc_e, bin(out_m), i, ed)
+
+    calc_m = ((m >> (i+1)) << 1) | (m & 1)
+    for l in range(i):
+        if m & (1<<(l+1)):
+            calc_m |= 1
+
+    assert out_e == calc_e, "Output e 0x%x != expected 0x%x" % (out_e, calc_e)
+    assert out_m == calc_m, "Output m 0x%x != expected 0x%x" % (out_m, calc_m)
+
+def testbench(dut):
+    m_width = dut.a.m_width
+    e_width = dut.a.e_width
+    e_max = dut.a.e_max
+    for j in range(200):
+        m = randint(0, (1<<m_width)-1)
+        zeros = randint(0, 31)
+        for i in range(zeros):
+            m &= ~(1<<i)
+        e = randint(-e_max, e_max)
+        for i in range(32):
+            yield from check_case(dut, m_width, e_width, m, e, i)
+
+if __name__ == '__main__':
+    dut = FPNumModShiftMulti(width=32)
+    run_simulation(dut, testbench(dut), vcd_name="test_multishift.vcd")
+
+    #dut = MultiShiftModL(width=32)
+    #run_simulation(dut, testbench(dut), vcd_name="test_multishift.vcd")
+