move get_a and get_b to their own classes
[ieee754fpu.git] / src / add / nmigen_add_experiment.py
index efb0c8bfa1de2bd8526d213e8426f934841f112a..8c691868142d2d856cb9ae40d3d7f45d6d0f0855 100644 (file)
@@ -2,17 +2,45 @@
 # Copyright (C) Jonathan P Dawson 2013
 # 2013-12-12
 
-from nmigen import Module, Signal
+from nmigen import Module, Signal, Cat
 from nmigen.cli import main, verilog
 
-from fpbase import FPNum, FPOp, Overflow, FPBase
+from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase
+
+class FPState(FPBase):
+    def __init__(self, state_from, state_to):
+        self.state_from = state_from
+        self.state_to = state_to
+
+    def set_inputs(self, inputs):
+        self.inputs = inputs
+        for k,v in inputs.items():
+            setattr(self, k, v)
+
+    def set_outputs(self, outputs):
+        self.outputs = outputs
+        for k,v in outputs.items():
+            setattr(self, k, v)
+
+
+class FPGetOpA(FPState):
+
+    def action(self, m):
+        self.get_op(m, self.in_a, self.a, self.state_to)
+
+
+class FPGetOpB(FPState):
+
+    def action(self, m):
+        self.get_op(m, self.in_b, self.b, self.state_to)
 
 
 class FPADD(FPBase):
 
-    def __init__(self, width):
+    def __init__(self, width, single_cycle=False):
         FPBase.__init__(self)
         self.width = width
+        self.single_cycle = single_cycle
 
         self.in_a  = FPOp(width)
         self.in_b  = FPOp(width)
@@ -24,13 +52,29 @@ class FPADD(FPBase):
         m = Module()
 
         # Latches
-        a = FPNum(self.width)
-        b = FPNum(self.width)
-        z = FPNum(self.width, False)
+        a = FPNumIn(self.in_a, self.width)
+        b = FPNumIn(self.in_b, self.width)
+        z = FPNumOut(self.width, False)
 
-        tot = Signal(28)     # sticky/round/guard bits, 23 result, 1 overflow
+        m.submodules.fpnum_a = a
+        m.submodules.fpnum_b = b
+        m.submodules.fpnum_z = z
+
+        w = z.m_width + 4
+        tot = Signal(w, reset_less=True) # sticky/round/guard, {mantissa} result, 1 overflow
 
         of = Overflow()
+        m.submodules.overflow = of
+
+        geta = FPGetOpA("get_a", "get_b")
+        geta.set_inputs({"in_a": self.in_a})
+        geta.set_outputs({"a": a})
+        m.d.comb += a.v.eq(self.in_a.v) # links in_a to a
+
+        getb = FPGetOpB("get_b", "special_cases")
+        getb.set_inputs({"in_b": self.in_b})
+        getb.set_outputs({"b": b})
+        m.d.comb += b.v.eq(self.in_b.v) # links in_b to b
 
         with m.FSM() as fsm:
 
@@ -38,13 +82,14 @@ class FPADD(FPBase):
             # gets operand a
 
             with m.State("get_a"):
-                self.get_op(m, self.in_a, a, "get_b")
+                geta.action(m)
 
             # ******
             # gets operand b
 
             with m.State("get_b"):
-                self.get_op(m, self.in_b, b, "special_cases")
+                #self.get_op(m, self.in_b, b, "special_cases")
+                getb.action(m)
 
             # ******
             # special cases: NaNs, infs, zeros, denormalised
@@ -53,58 +98,120 @@ class FPADD(FPBase):
 
             with m.State("special_cases"):
 
+                s_nomatch = Signal()
+                m.d.comb += s_nomatch.eq(a.s != b.s)
+
+                m_match = Signal()
+                m.d.comb += m_match.eq(a.m == b.m)
+
                 # if a is NaN or b is NaN return NaN
-                with m.If(a.is_nan() | b.is_nan()):
+                with m.If(a.is_nan | b.is_nan):
                     m.next = "put_z"
                     m.d.sync += z.nan(1)
 
+                # XXX WEIRDNESS for FP16 non-canonical NaN handling
+                # under review
+
+                ## if a is zero and b is NaN return -b
+                #with m.If(a.is_zero & (a.s==0) & b.is_nan):
+                #    m.next = "put_z"
+                #    m.d.sync += z.create(b.s, b.e, Cat(b.m[3:-2], ~b.m[0]))
+
+                ## if b is zero and a is NaN return -a
+                #with m.Elif(b.is_zero & (b.s==0) & a.is_nan):
+                #    m.next = "put_z"
+                #    m.d.sync += z.create(a.s, a.e, Cat(a.m[3:-2], ~a.m[0]))
+
+                ## if a is -zero and b is NaN return -b
+                #with m.Elif(a.is_zero & (a.s==1) & b.is_nan):
+                #    m.next = "put_z"
+                #    m.d.sync += z.create(a.s & b.s, b.e, Cat(b.m[3:-2], 1))
+
+                ## if b is -zero and a is NaN return -a
+                #with m.Elif(b.is_zero & (b.s==1) & a.is_nan):
+                #    m.next = "put_z"
+                #    m.d.sync += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1))
+
                 # if a is inf return inf (or NaN)
-                with m.Elif(a.is_inf()):
+                with m.Elif(a.is_inf):
                     m.next = "put_z"
                     m.d.sync += z.inf(a.s)
                     # if a is inf and signs don't match return NaN
-                    with m.If((b.e == b.P128) & (a.s != b.s)):
-                        m.d.sync += z.nan(b.s)
+                    with m.If(b.exp_128 & s_nomatch):
+                        m.d.sync += z.nan(1)
 
                 # if b is inf return inf
-                with m.Elif(b.is_inf()):
+                with m.Elif(b.is_inf):
                     m.next = "put_z"
                     m.d.sync += z.inf(b.s)
 
                 # if a is zero and b zero return signed-a/b
-                with m.Elif(a.is_zero() & b.is_zero()):
+                with m.Elif(a.is_zero & b.is_zero):
                     m.next = "put_z"
-                    m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1])
+                    m.d.sync += z.create(a.s & b.s, b.e, b.m[3:-1])
 
                 # if a is zero return b
-                with m.Elif(a.is_zero()):
+                with m.Elif(a.is_zero):
                     m.next = "put_z"
-                    m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1])
+                    m.d.sync += z.create(b.s, b.e, b.m[3:-1])
 
                 # if b is zero return a
-                with m.Elif(b.is_zero()):
+                with m.Elif(b.is_zero):
+                    m.next = "put_z"
+                    m.d.sync += z.create(a.s, a.e, a.m[3:-1])
+
+                # if a equal to -b return zero (+ve zero)
+                with m.Elif(s_nomatch & m_match & (a.e == b.e)):
                     m.next = "put_z"
-                    m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1])
+                    m.d.sync += z.zero(0)
 
                 # Denormalised Number checks
                 with m.Else():
-                    m.next = "align"
-                    self.denormalise(m, a)
-                    self.denormalise(m, b)
+                    m.next = "denormalise"
 
             # ******
-            # align.  NOTE: this does *not* do single-cycle multi-shifting,
-            #         it *STAYS* in the align state until the exponents match
+            # denormalise.
+
+            with m.State("denormalise"):
+                # Denormalised Number checks
+                m.next = "align"
+                self.denormalise(m, a)
+                self.denormalise(m, b)
+
+            # ******
+            # align.
 
             with m.State("align"):
-                # exponent of a greater than b: increment b exp, shift b mant
-                with m.If(a.e > b.e):
-                    m.d.sync += b.shift_down()
-                # exponent of b greater than a: increment a exp, shift a mant
-                with m.Elif(a.e < b.e):
-                    m.d.sync += a.shift_down()
-                # exponents equal: move to next stage.
-                with m.Else():
+                if not self.single_cycle:
+                    # NOTE: this does *not* do single-cycle multi-shifting,
+                    #       it *STAYS* in the align state until exponents match
+
+                    # exponent of a greater than b: shift b down
+                    with m.If(a.e > b.e):
+                        m.d.sync += b.shift_down()
+                    # exponent of b greater than a: shift a down
+                    with m.Elif(a.e < b.e):
+                        m.d.sync += a.shift_down()
+                    # exponents equal: move to next stage.
+                    with m.Else():
+                        m.next = "add_0"
+                else:
+                    # This one however (single-cycle) will do the shift
+                    # in one go.
+
+                    # XXX TODO: the shifter used here is quite expensive
+                    # having only one would be better
+
+                    ediff = Signal((len(a.e), True), reset_less=True)
+                    ediffr = Signal((len(a.e), True), reset_less=True)
+                    m.d.comb += ediff.eq(a.e - b.e)
+                    m.d.comb += ediffr.eq(b.e - a.e)
+                    with m.If(ediff > 0):
+                        m.d.sync += b.shift_down_multi(ediff)
+                    # exponent of b greater than a: shift a down
+                    with m.Elif(ediff < 0):
+                        m.d.sync += a.shift_down_multi(ediffr)
+
                     m.next = "add_0"
 
             # ******
@@ -118,19 +225,19 @@ class FPADD(FPBase):
                 # same-sign (both negative or both positive) add mantissas
                 with m.If(a.s == b.s):
                     m.d.sync += [
-                        tot.eq(a.m + b.m),
+                        tot.eq(Cat(a.m, 0) + Cat(b.m, 0)),
                         z.s.eq(a.s)
                     ]
                 # a mantissa greater than b, use a
                 with m.Elif(a.m >= b.m):
                     m.d.sync += [
-                        tot.eq(a.m - b.m),
+                        tot.eq(Cat(a.m, 0) - Cat(b.m, 0)),
                         z.s.eq(a.s)
                     ]
                 # b mantissa greater than a, use b
                 with m.Else():
                     m.d.sync += [
-                        tot.eq(b.m - a.m),
+                        tot.eq(Cat(b.m, 0) - Cat(a.m, 0)),
                         z.s.eq(b.s)
                 ]
 
@@ -141,9 +248,10 @@ class FPADD(FPBase):
             with m.State("add_1"):
                 m.next = "normalise_1"
                 # tot[27] gets set when the sum overflows. shift result down
-                with m.If(tot[27]):
+                with m.If(tot[-1]):
                     m.d.sync += [
-                        z.m.eq(tot[4:28]),
+                        z.m.eq(tot[4:]),
+                        of.m0.eq(tot[4]),
                         of.guard.eq(tot[3]),
                         of.round_bit.eq(tot[2]),
                         of.sticky.eq(tot[1] | tot[0]),
@@ -152,7 +260,8 @@ class FPADD(FPBase):
                 # tot[27] zero case
                 with m.Else():
                     m.d.sync += [
-                        z.m.eq(tot[3:27]),
+                        z.m.eq(tot[3:]),
+                        of.m0.eq(tot[3]),
                         of.guard.eq(tot[2]),
                         of.round_bit.eq(tot[1]),
                         of.sticky.eq(tot[0])