format fpbase.py
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 11 Jul 2019 09:46:58 +0000 (02:46 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 11 Jul 2019 09:47:43 +0000 (02:47 -0700)
src/ieee754/fpcommon/fpbase.py

index 0e057f9ee5dbc2c184292b5e296b011d7ba152be..71b2d9ecc58c81684029ae1b9a2b97b5868f7eb6 100644 (file)
@@ -46,8 +46,8 @@ class MultiShift:
         return res[:len(op)]
         res = op
         for i in range(self.smax):
-            zeros = [0] * (1<<i)
-            res = Mux(s & (1<<i), Cat(zeros, res[0:-(1<<i)]), res)
+            zeros = [0] * (1 << i)
+            res = Mux(s & (1 << i), Cat(zeros, res[0:-(1 << i)]), res)
         return res
 
     def rshift(self, op, s):
@@ -55,20 +55,21 @@ class MultiShift:
         return res[:len(op)]
         res = op
         for i in range(self.smax):
-            zeros = [0] * (1<<i)
-            res = Mux(s & (1<<i), Cat(res[(1<<i):], zeros), res)
+            zeros = [0] * (1 << i)
+            res = Mux(s & (1 << i), Cat(res[(1 << i):], zeros), res)
         return res
 
 
 class FPNumBaseRecord:
     """ Floating-point Base Number Class
     """
+
     def __init__(self, width, m_extra=True, e_extra=False):
         self.width = width
-        m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
-        e_width = {16: 7,  32: 10, 64: 13}[width] # 2 extra bits (overflow)
-        e_max = 1<<(e_width-3)
-        self.rmw = m_width - 1 # real mantissa width (not including extras)
+        m_width = {16: 11, 32: 24, 64: 53}[width]  # 1 extra bit (overflow)
+        e_width = {16: 7,  32: 10, 64: 13}[width]  # 2 extra bits (overflow)
+        e_max = 1 << (e_width-3)
+        self.rmw = m_width - 1  # real mantissa width (not including extras)
         self.e_max = e_max
         if m_extra:
             # mantissa extra bits (top,guard,round)
@@ -77,19 +78,19 @@ class FPNumBaseRecord:
         else:
             self.m_extra = 0
         if e_extra:
-            self.e_extra = 6 # enough to cover FP64 when converting to FP16
+            self.e_extra = 6  # enough to cover FP64 when converting to FP16
             e_width += self.e_extra
         else:
             self.e_extra = 0
-        #print (m_width, e_width, e_max, self.rmw, self.m_extra)
+        # print (m_width, e_width, e_max, self.rmw, self.m_extra)
         self.m_width = m_width
         self.e_width = e_width
         self.e_start = self.rmw
-        self.e_end = self.rmw + self.e_width - 2 # for decoding
+        self.e_end = self.rmw + self.e_width - 2  # for decoding
 
         self.v = Signal(width, reset_less=True)      # Latched copy of value
         self.m = Signal(m_width, reset_less=True)    # Mantissa
-        self.e = Signal((e_width, True), reset_less=True) # exp+2 bits, signed
+        self.e = Signal((e_width, True), reset_less=True)  # exp+2 bits, signed
         self.s = Signal(reset_less=True)           # Sign bit
 
         self.fp = self
@@ -114,7 +115,7 @@ class FPNumBaseRecord:
         e_width = self.e_width
 
         self.mzero = Const(0, (m_width, False))
-        m_msb = 1<<(self.m_width-2)
+        m_msb = 1 << (self.m_width-2)
         self.msb1 = Const(m_msb, (m_width, False))
         self.m1s = Const(-1, (m_width, False))
         self.P128 = Const(e_max, (e_width, True))
@@ -132,12 +133,12 @@ class FPNumBaseRecord:
         """
         return [
           self.v[0:self.e_start].eq(m),        # mantissa
-          self.v[self.e_start:self.e_end].eq(e + self.fp.P127), # (add on bias)
+          self.v[self.e_start:self.e_end].eq(e + self.fp.P127),  # (add bias)
           self.v[-1].eq(s),          # sign
         ]
 
     def _nan(self, s):
-        return (s, self.fp.P128, 1<<(self.e_start-1))
+        return (s, self.fp.P128, 1 << (self.e_start-1))
 
     def _inf(self, s):
         return (s, self.fp.P128, 0)
@@ -159,7 +160,7 @@ class FPNumBaseRecord:
 
             bias is added here, to the exponent
         """
-        e = e + self.P127 # exp (add on bias)
+        e = e + self.P127  # exp (add on bias)
         return Cat(m[0:self.e_start],
                    e[0:self.e_end-self.e_start],
                    s)
@@ -185,6 +186,7 @@ class FPNumBaseRecord:
 class FPNumBase(FPNumBaseRecord, Elaboratable):
     """ Floating-point Base Number Class
     """
+
     def __init__(self, fp):
         fp.drop_in(self)
         self.fp = fp
@@ -252,6 +254,7 @@ class FPNumOut(FPNumBase):
         (m[-1]) is effectively a carry-overflow.  The other three are
         guard (m[2]), round (m[1]), and sticky (m[0])
     """
+
     def __init__(self, fp):
         FPNumBase.__init__(self, fp)
 
@@ -265,6 +268,7 @@ class MultiShiftRMerge(Elaboratable):
     """ shifts down (right) and merges lower bits into m[0].
         m[0] is the "sticky" bit, basically
     """
+
     def __init__(self, width, s_max=None):
         if s_max is None:
             s_max = int(log(width) / log(2))
@@ -289,7 +293,7 @@ class MultiShiftRMerge(Elaboratable):
         mw = Const(self.width-1, len(self.diff))
         m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
                      maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
-                    ]
+                     ]
 
         m.d.comb += [
                 # shift mantissa by maxslen, mask by inverse
@@ -307,6 +311,7 @@ class MultiShiftRMerge(Elaboratable):
 class FPNumShift(FPNumBase, Elaboratable):
     """ Floating-point Number Class for shifting
     """
+
     def __init__(self, mainm, op, inv, width, m_extra=True):
         FPNumBase.__init__(self, width, m_extra)
         self.latch_in = Signal()
@@ -335,7 +340,7 @@ class FPNumShift(FPNumBase, Elaboratable):
         """
         return [self.e.eq(inp.e + 1),
                 self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
-               ]
+                ]
 
     def shift_down_multi(self, diff):
         """ shifts a mantissa down. exponent is increased to compensate
@@ -357,12 +362,12 @@ class FPNumShift(FPNumBase, Elaboratable):
         maxslen = Mux(diff > mw, mw, diff)
         rs = sm.rshift(self.m[1:], maxslen)
         maxsleni = mw - maxslen
-        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
+        m_mask = sm.rshift(self.m1s[1:], maxsleni)  # shift and invert
 
         stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
         return [self.e.eq(self.e + diff),
                 self.m.eq(Cat(stickybits, rs))
-               ]
+                ]
 
     def shift_up_multi(self, diff):
         """ shifts a mantissa up. exponent is decreased to compensate
@@ -373,7 +378,7 @@ class FPNumShift(FPNumBase, Elaboratable):
 
         return [self.e.eq(self.e - diff),
                 self.m.eq(sm.lshift(self.m, maxslen))
-               ]
+                ]
 
 
 class FPNumDecode(FPNumBase):
@@ -388,6 +393,7 @@ class FPNumDecode(FPNumBase):
         (m[-1]) is effectively a carry-overflow.  The other three are
         guard (m[2]), round (m[1]), and sticky (m[0])
     """
+
     def __init__(self, op, fp):
         FPNumBase.__init__(self, fp)
         self.op = op
@@ -406,13 +412,14 @@ class FPNumDecode(FPNumBase):
             is extended to 10 bits so that subtract 127 is done on
             a 10-bit number
         """
-        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
         #print ("decode", self.e_end)
-        return [self.m.eq(Cat(*args)), # mantissa
-                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127), # exp
+        return [self.m.eq(Cat(*args)),  # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.fp.P127),  # exp
                 self.s.eq(v[-1]),                 # sign
                 ]
 
+
 class FPNumIn(FPNumBase):
     """ Floating-point Number Class
 
@@ -425,6 +432,7 @@ class FPNumIn(FPNumBase):
         (m[-1]) is effectively a carry-overflow.  The other three are
         guard (m[2]), round (m[1]), and sticky (m[0])
     """
+
     def __init__(self, op, fp):
         FPNumBase.__init__(self, fp)
         self.latch_in = Signal()
@@ -438,11 +446,11 @@ class FPNumIn(FPNumBase):
             a 10-bit number
         """
         v = self.v
-        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
         #print ("decode", self.e_end)
         res = ObjectProxy(m, pipemode=False)
         res.m = Cat(*args)                             # mantissa
-        res.e = v[self.e_start:self.e_end] - self.fp.P127 # exp
+        res.e = v[self.e_start:self.e_end] - self.fp.P127  # exp
         res.s = v[-1]                                  # sign
         return res
 
@@ -453,10 +461,10 @@ class FPNumIn(FPNumBase):
             is extended to 10 bits so that subtract 127 is done on
             a 10-bit number
         """
-        args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+        args = [0] * self.m_extra + [v[0:self.e_start]]  # pad with extra zeros
         #print ("decode", self.e_end)
-        return [self.m.eq(Cat(*args)), # mantissa
-                self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+        return [self.m.eq(Cat(*args)),  # mantissa
+                self.e.eq(v[self.e_start:self.e_end] - self.P127),  # exp
                 self.s.eq(v[-1]),                 # sign
                 ]
 
@@ -468,7 +476,7 @@ class FPNumIn(FPNumBase):
         """
         return [self.e.eq(inp.e + 1),
                 self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
-               ]
+                ]
 
     def shift_down_multi(self, diff, inp=None):
         """ shifts a mantissa down. exponent is increased to compensate
@@ -492,13 +500,13 @@ class FPNumIn(FPNumBase):
         maxslen = Mux(diff > mw, mw, diff)
         rs = sm.rshift(inp.m[1:], maxslen)
         maxsleni = mw - maxslen
-        m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
+        m_mask = sm.rshift(self.m1s[1:], maxsleni)  # shift and invert
 
         #stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
         stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
         return [self.e.eq(inp.e + diff),
                 self.m.eq(Cat(stickybit, rs))
-               ]
+                ]
 
     def shift_up_multi(self, diff):
         """ shifts a mantissa up. exponent is decreased to compensate
@@ -509,7 +517,8 @@ class FPNumIn(FPNumBase):
 
         return [self.e.eq(self.e - diff),
                 self.m.eq(sm.lshift(self.m, maxslen))
-               ]
+                ]
+
 
 class Trigger(Elaboratable):
     def __init__(self):
@@ -526,7 +535,7 @@ class Trigger(Elaboratable):
     def eq(self, inp):
         return [self.stb.eq(inp.stb),
                 self.ack.eq(inp.ack)
-               ]
+                ]
 
     def ports(self):
         return [self.stb, self.ack]
@@ -547,8 +556,8 @@ class FPOpIn(PrevControl):
             stb = stb & extra
         return [self.v.eq(in_op.v),          # receive value
                 self.stb.eq(stb),      # receive STB
-                in_op.ack.eq(~self.ack), # send ACK
-               ]
+                in_op.ack.eq(~self.ack),  # send ACK
+                ]
 
     def chain_from(self, in_op, extra=None):
         stb = in_op.stb
@@ -556,8 +565,8 @@ class FPOpIn(PrevControl):
             stb = stb & extra
         return [self.v.eq(in_op.v),          # receive value
                 self.stb.eq(stb),      # receive STB
-                in_op.ack.eq(self.ack), # send ACK
-               ]
+                in_op.ack.eq(self.ack),  # send ACK
+                ]
 
 
 class FPOpOut(NextControl):
@@ -575,8 +584,8 @@ class FPOpOut(NextControl):
             stb = stb & extra
         return [self.v.eq(in_op.v),          # receive value
                 self.stb.eq(stb),      # receive STB
-                in_op.ack.eq(~self.ack), # send ACK
-               ]
+                in_op.ack.eq(~self.ack),  # send ACK
+                ]
 
     def chain_from(self, in_op, extra=None):
         stb = in_op.stb
@@ -584,14 +593,14 @@ class FPOpOut(NextControl):
             stb = stb & extra
         return [self.v.eq(in_op.v),          # receive value
                 self.stb.eq(stb),      # receive STB
-                in_op.ack.eq(self.ack), # send ACK
-               ]
+                in_op.ack.eq(self.ack),  # send ACK
+                ]
 
 
-class Overflow: #(Elaboratable):
+class Overflow:  # (Elaboratable):
     def __init__(self):
         self.guard = Signal(reset_less=True)     # tot[2]
-        self.round_bit = Signal(reset_less=True) # tot[1]
+        self.round_bit = Signal(reset_less=True)  # tot[1]
         self.sticky = Signal(reset_less=True)    # tot[0]
         self.m0 = Signal(reset_less=True)        # mantissa zero bit
 
@@ -648,19 +657,19 @@ class FPBase:
             which has to be taken into account when extracting the result.
         """
         with m.If(a.exp_n127):
-            m.d.sync += a.e.eq(a.fp.N126) # limit a exponent
+            m.d.sync += a.e.eq(a.fp.N126)  # limit a exponent
         with m.Else():
-            m.d.sync += a.m[-1].eq(1) # set top mantissa bit
+            m.d.sync += a.m[-1].eq(1)  # set top mantissa bit
 
     def op_normalise(self, m, op, next_state):
         """ operand normalisation
             NOTE: just like "align", this one keeps going round every clock
                   until the result's exponent is within acceptable "range"
         """
-        with m.If((op.m[-1] == 0)): # check last bit of mantissa
-            m.d.sync +=[
+        with m.If((op.m[-1] == 0)):  # check last bit of mantissa
+            m.d.sync += [
                 op.e.eq(op.e - 1),  # DECREASE exponent
-                op.m.eq(op.m << 1), # shift mantissa UP
+                op.m.eq(op.m << 1),  # shift mantissa UP
             ]
         with m.Else():
             m.next = next_state
@@ -676,9 +685,9 @@ class FPBase:
         with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
             m.d.sync += [
                 z.e.eq(z.e - 1),  # DECREASE exponent
-                z.m.eq(z.m << 1), # shift mantissa UP
+                z.m.eq(z.m << 1),  # shift mantissa UP
                 z.m[0].eq(of.guard),       # steal guard bit (was tot[2])
-                of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
+                of.guard.eq(of.round_bit),  # steal round_bit (was tot[1])
                 of.round_bit.eq(0),        # reset round bit
                 of.m0.eq(of.guard),
             ]
@@ -694,9 +703,9 @@ class FPBase:
                   the extra mantissa bits coming from tot[0..2]
         """
         with m.If(z.e < z.fp.N126):
-            m.d.sync +=[
+            m.d.sync += [
                 z.e.eq(z.e + 1),  # INCREASE exponent
-                z.m.eq(z.m >> 1), # shift mantissa DOWN
+                z.m.eq(z.m >> 1),  # shift mantissa DOWN
                 of.guard.eq(z.m[0]),
                 of.m0.eq(z.m[1]),
                 of.round_bit.eq(of.guard),
@@ -709,9 +718,9 @@ class FPBase:
         """ performs rounding on the output.  TODO: different kinds of rounding
         """
         with m.If(roundz):
-            m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
-            with m.If(z.m == z.fp.m1s): # all 1s
-                m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
+            m.d.sync += z.m.eq(z.m + 1)  # mantissa rounds up
+            with m.If(z.m == z.fp.m1s):  # all 1s
+                m.d.sync += z.e.eq(z.e + 1)  # exponent rounds up
 
     def corrections(self, m, z, next_state):
         """ denormalisation and sign-bug corrections
@@ -752,12 +761,12 @@ class FPState(FPBase):
 
     def set_inputs(self, inputs):
         self.inputs = inputs
-        for k,v in inputs.items():
+        for k, v in inputs.items():
             setattr(self, k, v)
 
     def set_outputs(self, outputs):
         self.outputs = outputs
-        for k,v in outputs.items():
+        for k, v in outputs.items():
             setattr(self, k, v)
 
 
@@ -774,5 +783,3 @@ class FPID:
     def idsync(self, m):
         if self.id_wid is not None:
             m.d.sync += self.out_mid.eq(self.in_mid)
-
-