return res[:len(op)]
res = op
for i in range(self.smax):
- zeros = [0] * (1<<i)
- res = Mux(s & (1<<i), Cat(zeros, res[0:-(1<<i)]), res)
+ zeros = [0] * (1 << i)
+ res = Mux(s & (1 << i), Cat(zeros, res[0:-(1 << i)]), res)
return res
def rshift(self, op, s):
return res[:len(op)]
res = op
for i in range(self.smax):
- zeros = [0] * (1<<i)
- res = Mux(s & (1<<i), Cat(res[(1<<i):], zeros), res)
+ zeros = [0] * (1 << i)
+ res = Mux(s & (1 << i), Cat(res[(1 << i):], zeros), res)
return res
class FPNumBaseRecord:
""" Floating-point Base Number Class
"""
+
def __init__(self, width, m_extra=True, e_extra=False):
self.width = width
- m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
- e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
- e_max = 1<<(e_width-3)
- self.rmw = m_width - 1 # real mantissa width (not including extras)
+ m_width = {16: 11, 32: 24, 64: 53}[width] # 1 extra bit (overflow)
+ e_width = {16: 7, 32: 10, 64: 13}[width] # 2 extra bits (overflow)
+ e_max = 1 << (e_width-3)
+ self.rmw = m_width - 1 # real mantissa width (not including extras)
self.e_max = e_max
if m_extra:
# mantissa extra bits (top,guard,round)
else:
self.m_extra = 0
if e_extra:
- self.e_extra = 6 # enough to cover FP64 when converting to FP16
+ self.e_extra = 6 # enough to cover FP64 when converting to FP16
e_width += self.e_extra
else:
self.e_extra = 0
- #print (m_width, e_width, e_max, self.rmw, self.m_extra)
+ # print (m_width, e_width, e_max, self.rmw, self.m_extra)
self.m_width = m_width
self.e_width = e_width
self.e_start = self.rmw
- self.e_end = self.rmw + self.e_width - 2 # for decoding
+ self.e_end = self.rmw + self.e_width - 2 # for decoding
self.v = Signal(width, reset_less=True) # Latched copy of value
self.m = Signal(m_width, reset_less=True) # Mantissa
- self.e = Signal((e_width, True), reset_less=True) # exp+2 bits, signed
+ self.e = Signal((e_width, True), reset_less=True) # exp+2 bits, signed
self.s = Signal(reset_less=True) # Sign bit
self.fp = self
e_width = self.e_width
self.mzero = Const(0, (m_width, False))
- m_msb = 1<<(self.m_width-2)
+ m_msb = 1 << (self.m_width-2)
self.msb1 = Const(m_msb, (m_width, False))
self.m1s = Const(-1, (m_width, False))
self.P128 = Const(e_max, (e_width, True))
"""
return [
self.v[0:self.e_start].eq(m), # mantissa
- self.v[self.e_start:self.e_end].eq(e + self.fp.P127), # (add on bias)
+ self.v[self.e_start:self.e_end].eq(e + self.fp.P127), # (add bias)
self.v[-1].eq(s), # sign
]
def _nan(self, s):
- return (s, self.fp.P128, 1<<(self.e_start-1))
+ return (s, self.fp.P128, 1 << (self.e_start-1))
def _inf(self, s):
return (s, self.fp.P128, 0)
bias is added here, to the exponent
"""
- e = e + self.P127 # exp (add on bias)
+ e = e + self.P127 # exp (add on bias)
return Cat(m[0:self.e_start],
e[0:self.e_end-self.e_start],
s)
class FPNumBase(FPNumBaseRecord, Elaboratable):
""" Floating-point Base Number Class
"""
+
def __init__(self, fp):
fp.drop_in(self)
self.fp = fp
(m[-1]) is effectively a carry-overflow. The other three are
guard (m[2]), round (m[1]), and sticky (m[0])
"""
+
def __init__(self, fp):
FPNumBase.__init__(self, fp)
""" shifts down (right) and merges lower bits into m[0].
m[0] is the "sticky" bit, basically
"""
+
def __init__(self, width, s_max=None):
if s_max is None:
s_max = int(log(width) / log(2))
mw = Const(self.width-1, len(self.diff))
m.d.comb += [maxslen.eq(Mux(self.diff > mw, mw, self.diff)),
maxsleni.eq(Mux(self.diff > mw, 0, mw-self.diff)),
- ]
+ ]
m.d.comb += [
# shift mantissa by maxslen, mask by inverse
class FPNumShift(FPNumBase, Elaboratable):
""" Floating-point Number Class for shifting
"""
+
def __init__(self, mainm, op, inv, width, m_extra=True):
FPNumBase.__init__(self, width, m_extra)
self.latch_in = Signal()
"""
return [self.e.eq(inp.e + 1),
self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
- ]
+ ]
def shift_down_multi(self, diff):
""" shifts a mantissa down. exponent is increased to compensate
maxslen = Mux(diff > mw, mw, diff)
rs = sm.rshift(self.m[1:], maxslen)
maxsleni = mw - maxslen
- m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
+ m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
stickybits = reduce(or_, self.m[1:] & m_mask) | self.m[0]
return [self.e.eq(self.e + diff),
self.m.eq(Cat(stickybits, rs))
- ]
+ ]
def shift_up_multi(self, diff):
""" shifts a mantissa up. exponent is decreased to compensate
return [self.e.eq(self.e - diff),
self.m.eq(sm.lshift(self.m, maxslen))
- ]
+ ]
class FPNumDecode(FPNumBase):
(m[-1]) is effectively a carry-overflow. The other three are
guard (m[2]), round (m[1]), and sticky (m[0])
"""
+
def __init__(self, op, fp):
FPNumBase.__init__(self, fp)
self.op = op
is extended to 10 bits so that subtract 127 is done on
a 10-bit number
"""
- args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+ args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
#print ("decode", self.e_end)
- return [self.m.eq(Cat(*args)), # mantissa
- self.e.eq(v[self.e_start:self.e_end] - self.fp.P127), # exp
+ return [self.m.eq(Cat(*args)), # mantissa
+ self.e.eq(v[self.e_start:self.e_end] - self.fp.P127), # exp
self.s.eq(v[-1]), # sign
]
+
class FPNumIn(FPNumBase):
""" Floating-point Number Class
(m[-1]) is effectively a carry-overflow. The other three are
guard (m[2]), round (m[1]), and sticky (m[0])
"""
+
def __init__(self, op, fp):
FPNumBase.__init__(self, fp)
self.latch_in = Signal()
a 10-bit number
"""
v = self.v
- args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+ args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
#print ("decode", self.e_end)
res = ObjectProxy(m, pipemode=False)
res.m = Cat(*args) # mantissa
- res.e = v[self.e_start:self.e_end] - self.fp.P127 # exp
+ res.e = v[self.e_start:self.e_end] - self.fp.P127 # exp
res.s = v[-1] # sign
return res
is extended to 10 bits so that subtract 127 is done on
a 10-bit number
"""
- args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
+ args = [0] * self.m_extra + [v[0:self.e_start]] # pad with extra zeros
#print ("decode", self.e_end)
- return [self.m.eq(Cat(*args)), # mantissa
- self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
+ return [self.m.eq(Cat(*args)), # mantissa
+ self.e.eq(v[self.e_start:self.e_end] - self.P127), # exp
self.s.eq(v[-1]), # sign
]
"""
return [self.e.eq(inp.e + 1),
self.m.eq(Cat(inp.m[0] | inp.m[1], inp.m[2:], 0))
- ]
+ ]
def shift_down_multi(self, diff, inp=None):
""" shifts a mantissa down. exponent is increased to compensate
maxslen = Mux(diff > mw, mw, diff)
rs = sm.rshift(inp.m[1:], maxslen)
maxsleni = mw - maxslen
- m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
+ m_mask = sm.rshift(self.m1s[1:], maxsleni) # shift and invert
#stickybit = reduce(or_, inp.m[1:] & m_mask) | inp.m[0]
stickybit = (inp.m[1:] & m_mask).bool() | inp.m[0]
return [self.e.eq(inp.e + diff),
self.m.eq(Cat(stickybit, rs))
- ]
+ ]
def shift_up_multi(self, diff):
""" shifts a mantissa up. exponent is decreased to compensate
return [self.e.eq(self.e - diff),
self.m.eq(sm.lshift(self.m, maxslen))
- ]
+ ]
+
class Trigger(Elaboratable):
def __init__(self):
def eq(self, inp):
return [self.stb.eq(inp.stb),
self.ack.eq(inp.ack)
- ]
+ ]
def ports(self):
return [self.stb, self.ack]
stb = stb & extra
return [self.v.eq(in_op.v), # receive value
self.stb.eq(stb), # receive STB
- in_op.ack.eq(~self.ack), # send ACK
- ]
+ in_op.ack.eq(~self.ack), # send ACK
+ ]
def chain_from(self, in_op, extra=None):
stb = in_op.stb
stb = stb & extra
return [self.v.eq(in_op.v), # receive value
self.stb.eq(stb), # receive STB
- in_op.ack.eq(self.ack), # send ACK
- ]
+ in_op.ack.eq(self.ack), # send ACK
+ ]
class FPOpOut(NextControl):
stb = stb & extra
return [self.v.eq(in_op.v), # receive value
self.stb.eq(stb), # receive STB
- in_op.ack.eq(~self.ack), # send ACK
- ]
+ in_op.ack.eq(~self.ack), # send ACK
+ ]
def chain_from(self, in_op, extra=None):
stb = in_op.stb
stb = stb & extra
return [self.v.eq(in_op.v), # receive value
self.stb.eq(stb), # receive STB
- in_op.ack.eq(self.ack), # send ACK
- ]
+ in_op.ack.eq(self.ack), # send ACK
+ ]
-class Overflow: #(Elaboratable):
+class Overflow: # (Elaboratable):
def __init__(self):
self.guard = Signal(reset_less=True) # tot[2]
- self.round_bit = Signal(reset_less=True) # tot[1]
+ self.round_bit = Signal(reset_less=True) # tot[1]
self.sticky = Signal(reset_less=True) # tot[0]
self.m0 = Signal(reset_less=True) # mantissa zero bit
which has to be taken into account when extracting the result.
"""
with m.If(a.exp_n127):
- m.d.sync += a.e.eq(a.fp.N126) # limit a exponent
+ m.d.sync += a.e.eq(a.fp.N126) # limit a exponent
with m.Else():
- m.d.sync += a.m[-1].eq(1) # set top mantissa bit
+ m.d.sync += a.m[-1].eq(1) # set top mantissa bit
def op_normalise(self, m, op, next_state):
""" operand normalisation
NOTE: just like "align", this one keeps going round every clock
until the result's exponent is within acceptable "range"
"""
- with m.If((op.m[-1] == 0)): # check last bit of mantissa
- m.d.sync +=[
+ with m.If((op.m[-1] == 0)): # check last bit of mantissa
+ m.d.sync += [
op.e.eq(op.e - 1), # DECREASE exponent
- op.m.eq(op.m << 1), # shift mantissa UP
+ op.m.eq(op.m << 1), # shift mantissa UP
]
with m.Else():
m.next = next_state
with m.If((z.m[-1] == 0) & (z.e > z.fp.N126)):
m.d.sync += [
z.e.eq(z.e - 1), # DECREASE exponent
- z.m.eq(z.m << 1), # shift mantissa UP
+ z.m.eq(z.m << 1), # shift mantissa UP
z.m[0].eq(of.guard), # steal guard bit (was tot[2])
- of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
+ of.guard.eq(of.round_bit), # steal round_bit (was tot[1])
of.round_bit.eq(0), # reset round bit
of.m0.eq(of.guard),
]
the extra mantissa bits coming from tot[0..2]
"""
with m.If(z.e < z.fp.N126):
- m.d.sync +=[
+ m.d.sync += [
z.e.eq(z.e + 1), # INCREASE exponent
- z.m.eq(z.m >> 1), # shift mantissa DOWN
+ z.m.eq(z.m >> 1), # shift mantissa DOWN
of.guard.eq(z.m[0]),
of.m0.eq(z.m[1]),
of.round_bit.eq(of.guard),
""" performs rounding on the output. TODO: different kinds of rounding
"""
with m.If(roundz):
- m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
- with m.If(z.m == z.fp.m1s): # all 1s
- m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
+ m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up
+ with m.If(z.m == z.fp.m1s): # all 1s
+ m.d.sync += z.e.eq(z.e + 1) # exponent rounds up
def corrections(self, m, z, next_state):
""" denormalisation and sign-bug corrections
def set_inputs(self, inputs):
self.inputs = inputs
- for k,v in inputs.items():
+ for k, v in inputs.items():
setattr(self, k, v)
def set_outputs(self, outputs):
self.outputs = outputs
- for k,v in outputs.items():
+ for k, v in outputs.items():
setattr(self, k, v)
def idsync(self, m):
if self.id_wid is not None:
m.d.sync += self.out_mid.eq(self.in_mid)
-
-