use MultiShiftRMerge module instead of shift_down_multi function
authorLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 4 Mar 2019 04:13:02 +0000 (04:13 +0000)
committerLuke Kenneth Casson Leighton <lkcl@lkcl.net>
Mon, 4 Mar 2019 04:13:02 +0000 (04:13 +0000)
src/add/fpbase.py
src/add/nmigen_add_experiment.py
src/add/test_add.py

index 4430f9c6a3fa235214d8191651109c9341167408..dd2179f577ea22c8a56d319cd4f0c2384e4fdbb5 100644 (file)
@@ -188,11 +188,13 @@ class MultiShiftRMerge:
     """ shifts down (right) and merges lower bits into m[0].
         m[0] is the "sticky" bit, basically
     """
     """ shifts down (right) and merges lower bits into m[0].
         m[0] is the "sticky" bit, basically
     """
-    def __init__(self, width):
-        self.smax = int(log(width) / log(2))
+    def __init__(self, width, s_max=None):
+        if s_max is None:
+            s_max = int(log(width) / log(2))
+        self.smax = s_max
         self.m = Signal(width, reset_less=True)
         self.inp = Signal(width, reset_less=True)
         self.m = Signal(width, reset_less=True)
         self.inp = Signal(width, reset_less=True)
-        self.diff = Signal(self.smax, reset_less=True)
+        self.diff = Signal(s_max, reset_less=True)
         self.width = width
 
     def elaborate(self, platform):
         self.width = width
 
     def elaborate(self, platform):
@@ -202,12 +204,16 @@ class MultiShiftRMerge:
         m_mask = Signal(self.width, reset_less=True)
         smask = Signal(self.width, reset_less=True)
         stickybit = Signal(reset_less=True)
         m_mask = Signal(self.width, reset_less=True)
         smask = Signal(self.width, reset_less=True)
         stickybit = Signal(reset_less=True)
+        maxslen = Signal(self.smax, reset_less=True)
+        maxsleni = Signal(self.smax, reset_less=True)
 
         sm = MultiShift(self.width-1)
         m0s = Const(0, self.width-1)
         mw = Const(self.width-1, len(self.diff))
 
         sm = MultiShift(self.width-1)
         m0s = Const(0, self.width-1)
         mw = Const(self.width-1, len(self.diff))
-        maxslen = Mux(self.diff > mw, mw, self.diff)
-        maxsleni = mw - maxslen
+        m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
+                     maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
+                    ]
+
         m.d.comb += [
                 # shift mantissa by maxslen, mask by inverse
                 rs.eq(sm.rshift(self.inp[1:], maxslen)),
         m.d.comb += [
                 # shift mantissa by maxslen, mask by inverse
                 rs.eq(sm.rshift(self.inp[1:], maxslen)),
index ee5959694b2839fa14cdc2a05d2d2a055df715e1..73e1d7a8277475eaa7b6ef092ba2ba4601bb114a 100644 (file)
@@ -7,6 +7,7 @@ from nmigen.lib.coding import PriorityEncoder
 from nmigen.cli import main, verilog
 
 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
 from nmigen.cli import main, verilog
 
 from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase
+from fpbase import MultiShiftRMerge
 #from fpbase import FPNumShiftMultiRight
 
 class FPState(FPBase):
 #from fpbase import FPNumShiftMultiRight
 
 class FPState(FPBase):
@@ -366,15 +367,25 @@ class FPAddAlignSingleMod:
         # temporary (muxed) input and output to be shifted
         t_inp = FPNumBase(self.width)
         t_out = FPNumIn(None, self.width)
         # temporary (muxed) input and output to be shifted
         t_inp = FPNumBase(self.width)
         t_out = FPNumIn(None, self.width)
+        espec = (len(self.in_a.e), True)
+        msr = MultiShiftRMerge(self.in_a.m_width, espec)
         m.submodules.align_t_in = t_inp
         m.submodules.align_t_out = t_out
         m.submodules.align_t_in = t_inp
         m.submodules.align_t_out = t_out
+        m.submodules.multishift_r = msr
 
 
-        ediff = Signal((len(self.in_a.e), True), reset_less=True)
-        ediffr = Signal((len(self.in_a.e), True), reset_less=True)
-        tdiff = Signal((len(self.in_a.e), True), reset_less=True)
+        ediff = Signal(espec, reset_less=True)
+        ediffr = Signal(espec, reset_less=True)
+        tdiff = Signal(espec, reset_less=True)
         elz = Signal(reset_less=True)
         egz = Signal(reset_less=True)
 
         elz = Signal(reset_less=True)
         egz = Signal(reset_less=True)
 
+        # connect multi-shifter to t_inp/out mantissa (and tdiff)
+        m.d.comb += msr.inp.eq(t_inp.m)
+        m.d.comb += msr.diff.eq(tdiff)
+        m.d.comb += t_out.m.eq(msr.m)
+        m.d.comb += t_out.e.eq(t_inp.e + tdiff)
+        m.d.comb += t_out.s.eq(t_inp.s)
+
         m.d.comb += ediff.eq(self.in_a.e - self.in_b.e)
         m.d.comb += ediffr.eq(self.in_b.e - self.in_a.e)
         m.d.comb += elz.eq(self.in_a.e < self.in_b.e)
         m.d.comb += ediff.eq(self.in_a.e - self.in_b.e)
         m.d.comb += ediffr.eq(self.in_b.e - self.in_a.e)
         m.d.comb += elz.eq(self.in_a.e < self.in_b.e)
@@ -384,7 +395,7 @@ class FPAddAlignSingleMod:
         m.d.comb += self.out_a.copy(self.in_a)
         m.d.comb += self.out_b.copy(self.in_b)
         # only one shifter (muxed)
         m.d.comb += self.out_a.copy(self.in_a)
         m.d.comb += self.out_b.copy(self.in_b)
         # only one shifter (muxed)
-        m.d.comb += t_out.shift_down_multi(tdiff, t_inp)
+        #m.d.comb += t_out.shift_down_multi(tdiff, t_inp)
         # exponent of a greater than b: shift b down
         with m.If(egz):
             m.d.comb += [t_inp.copy(self.in_b),
         # exponent of a greater than b: shift b down
         with m.If(egz):
             m.d.comb += [t_inp.copy(self.in_b),
index 1f143beb7215bda3100112a7f70a438663959f3d..dece89619efe889f1d512176bcd247b2e629afa5 100644 (file)
@@ -11,6 +11,7 @@ from unit_test_single import (get_mantissa, get_exponent, get_sign, is_nan,
                                 run_edge_cases, run_corner_cases)
 
 def testbench(dut):
                                 run_edge_cases, run_corner_cases)
 
 def testbench(dut):
+    yield from check_case(dut, 0x36093399, 0x7f6a12f1, 0x7f6a12f1)
     yield from check_case(dut, 0x006CE3EE, 0x806CE3EC, 0x00000002)
     yield from check_case(dut, 0x00000047, 0x80000048, 0x80000001)
     yield from check_case(dut, 0x000116C2, 0x8001170A, 0x80000048)
     yield from check_case(dut, 0x006CE3EE, 0x806CE3EC, 0x00000002)
     yield from check_case(dut, 0x00000047, 0x80000048, 0x80000001)
     yield from check_case(dut, 0x000116C2, 0x8001170A, 0x80000048)