X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Fnmigen_add_experiment.py;h=8c691868142d2d856cb9ae40d3d7f45d6d0f0855;hb=e1859a9abdb55c218b2272609a05750c11ab8634;hp=795bd7c98ec1c578910b8d91c1659de47d4308a2;hpb=aeb0c772885d60535ed20f09449092ae837c42ee;p=ieee754fpu.git diff --git a/src/add/nmigen_add_experiment.py b/src/add/nmigen_add_experiment.py index 795bd7c9..8c691868 100644 --- a/src/add/nmigen_add_experiment.py +++ b/src/add/nmigen_add_experiment.py @@ -2,234 +2,45 @@ # Copyright (C) Jonathan P Dawson 2013 # 2013-12-12 -from nmigen import Module, Signal, Cat, Const +from nmigen import Module, Signal, Cat from nmigen.cli import main, verilog +from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase -class FPNum: - """ Floating-point Number Class, variable-width TODO (currently 32-bit) +class FPState(FPBase): + def __init__(self, state_from, state_to): + self.state_from = state_from + self.state_to = state_to - Contains signals for an incoming copy of the value, decoded into - sign / exponent / mantissa. - Also contains encoding functions, creation and recognition of - zero, NaN and inf (all signed) + def set_inputs(self, inputs): + self.inputs = inputs + for k,v in inputs.items(): + setattr(self, k, v) - Four extra bits are included in the mantissa: the top bit - (m[-1]) is effectively a carry-overflow. The other three are - guard (m[2]), round (m[1]), and sticky (m[0]) - """ - def __init__(self, width, m_width=None): - self.width = width - if m_width is None: - m_width = width - 5 # mantissa extra bits (top,guard,round) - self.m_width = m_width - self.v = Signal(width) # Latched copy of value - self.m = Signal(m_width) # Mantissa - self.e = Signal((10, True)) # Exponent: 10 bits, signed - self.s = Signal() # Sign bit - - self.mzero = Const(0, (m_width, False)) - self.m1s = Const(-1, (m_width, False)) - self.P128 = Const(128, (10, True)) - self.P127 = Const(127, (10, True)) - self.N127 = Const(-127, (10, True)) - self.N126 = Const(-126, (10, True)) - - def decode(self, v): - """ decodes a latched value into sign / exponent / mantissa - - bias is subtracted here, from the exponent. exponent - is extended to 10 bits so that subtract 127 is done on - a 10-bit number - """ - args = [0] * (self.m_width-24) + [v[0:23]] # pad with extra zeros - return [self.m.eq(Cat(*args)), # mantissa - self.e.eq(v[23:31] - self.P127), # exp (minus bias) - self.s.eq(v[31]), # sign - ] - - def create(self, s, e, m): - """ creates a value from sign / exponent / mantissa - - bias is added here, to the exponent - """ - return [ - self.v[31].eq(s), # sign - self.v[23:31].eq(e + self.P127), # exp (add on bias) - self.v[0:23].eq(m) # mantissa - ] - - def shift_down(self): - """ shifts a mantissa down by one. exponent is increased to compensate - - accuracy is lost as a result in the mantissa however there are 3 - guard bits (the latter of which is the "sticky" bit) - """ - return [self.e.eq(self.e + 1), - self.m.eq(Cat(self.m[0] | self.m[1], self.m[2:], 0)) - ] - - def nan(self, s): - return self.create(s, self.P128, 1<<22) - - def inf(self, s): - return self.create(s, self.P128, 0) - - def zero(self, s): - return self.create(s, self.N127, 0) + def set_outputs(self, outputs): + self.outputs = outputs + for k,v in outputs.items(): + setattr(self, k, v) - def is_nan(self): - return (self.e == self.P128) & (self.m != 0) - def is_inf(self): - return (self.e == self.P128) & (self.m == 0) +class FPGetOpA(FPState): - def is_zero(self): - return (self.e == self.N127) & (self.m == self.mzero) + def action(self, m): + self.get_op(m, self.in_a, self.a, self.state_to) - def is_overflowed(self): - return (self.e > self.P127) - def is_denormalised(self): - return (self.e == self.N126) & (self.m[23] == 0) +class FPGetOpB(FPState): - -class FPOp: - def __init__(self, width): - self.width = width - - self.v = Signal(width) - self.stb = Signal() - self.ack = Signal() - - def ports(self): - return [self.v, self.stb, self.ack] - - -class Overflow: - def __init__(self): - self.guard = Signal() # tot[2] - self.round_bit = Signal() # tot[1] - self.sticky = Signal() # tot[0] - - -class FPBase: - """ IEEE754 Floating Point Base Class - - contains common functions for FP manipulation, such as - extracting and packing operands, normalisation, denormalisation, - rounding etc. - """ - - def get_op(self, m, op, v, next_state): - """ this function moves to the next state and copies the operand - when both stb and ack are 1. - acknowledgement is sent by setting ack to ZERO. - """ - with m.If((op.ack) & (op.stb)): - m.next = next_state - m.d.sync += [ - v.decode(op.v), - op.ack.eq(0) - ] - with m.Else(): - m.d.sync += op.ack.eq(1) - - def denormalise(self, m, a): - """ denormalises a number - """ - with m.If(a.e == a.N127): - m.d.sync += a.e.eq(-126) # limit a exponent - with m.Else(): - m.d.sync += a.m[-1].eq(1) # set top mantissa bit - - def normalise_1(self, m, z, of, next_state): - """ first stage normalisation - - NOTE: just like "align", this one keeps going round every clock - until the result's exponent is within acceptable "range" - NOTE: the weirdness of reassigning guard and round is due to - the extra mantissa bits coming from tot[0..2] - """ - with m.If((z.m[-1] == 0) & (z.e > z.N126)): - m.d.sync +=[ - z.e.eq(z.e - 1), # DECREASE exponent - z.m.eq(z.m << 1), # shift mantissa UP - z.m[0].eq(of.guard), # steal guard bit (was tot[2]) - of.guard.eq(of.round_bit), # steal round_bit (was tot[1]) - of.round_bit.eq(0), # reset round bit - ] - with m.Else(): - m.next = next_state - - def normalise_2(self, m, z, of, next_state): - """ second stage normalisation - - NOTE: just like "align", this one keeps going round every clock - until the result's exponent is within acceptable "range" - NOTE: the weirdness of reassigning guard and round is due to - the extra mantissa bits coming from tot[0..2] - """ - with m.If(z.e < z.N126): - m.d.sync +=[ - z.e.eq(z.e + 1), # INCREASE exponent - z.m.eq(z.m >> 1), # shift mantissa DOWN - of.guard.eq(z.m[0]), - of.round_bit.eq(of.guard), - of.sticky.eq(of.sticky | of.round_bit) - ] - with m.Else(): - m.next = next_state - - def roundz(self, m, z, of, next_state): - """ performs rounding on the output. TODO: different kinds of rounding - """ - m.next = next_state - with m.If(of.guard & (of.round_bit | of.sticky | z.m[0])): - m.d.sync += z.m.eq(z.m + 1) # mantissa rounds up - with m.If(z.m == z.m1s): # all 1s - m.d.sync += z.e.eq(z.e + 1) # exponent rounds up - - def corrections(self, m, z, next_state): - """ denormalisation and sign-bug corrections - """ - m.next = next_state - # denormalised, correct exponent to zero - with m.If(z.is_denormalised()): - m.d.sync += z.m.eq(-127) - # FIX SIGN BUG: -a + a = +0. - with m.If((z.e == z.N126) & (z.m[0:] == 0)): - m.d.sync += z.s.eq(0) - - def pack(self, m, z, next_state): - """ packs the result into the output (detects overflow->Inf) - """ - m.next = next_state - # if overflow occurs, return inf - with m.If(z.is_overflowed()): - m.d.sync += z.inf(0) - with m.Else(): - m.d.sync += z.create(z.s, z.e, z.m) - - def put_z(self, m, z, out_z, next_state): - """ put_z: stores the result in the output. raises stb and waits - for ack to be set to 1 before moving to the next state. - resets stb back to zero when that occurs, as acknowledgement. - """ - m.d.sync += [ - out_z.stb.eq(1), - out_z.v.eq(z.v) - ] - with m.If(out_z.stb & out_z.ack): - m.d.sync += out_z.stb.eq(0) - m.next = next_state + def action(self, m): + self.get_op(m, self.in_b, self.b, self.state_to) class FPADD(FPBase): - def __init__(self, width): + def __init__(self, width, single_cycle=False): FPBase.__init__(self) self.width = width + self.single_cycle = single_cycle self.in_a = FPOp(width) self.in_b = FPOp(width) @@ -241,13 +52,29 @@ class FPADD(FPBase): m = Module() # Latches - a = FPNum(self.width) - b = FPNum(self.width) - z = FPNum(self.width, 24) + a = FPNumIn(self.in_a, self.width) + b = FPNumIn(self.in_b, self.width) + z = FPNumOut(self.width, False) - tot = Signal(28) # sticky/round/guard bits, 23 result, 1 overflow + m.submodules.fpnum_a = a + m.submodules.fpnum_b = b + m.submodules.fpnum_z = z + + w = z.m_width + 4 + tot = Signal(w, reset_less=True) # sticky/round/guard, {mantissa} result, 1 overflow of = Overflow() + m.submodules.overflow = of + + geta = FPGetOpA("get_a", "get_b") + geta.set_inputs({"in_a": self.in_a}) + geta.set_outputs({"a": a}) + m.d.comb += a.v.eq(self.in_a.v) # links in_a to a + + getb = FPGetOpB("get_b", "special_cases") + getb.set_inputs({"in_b": self.in_b}) + getb.set_outputs({"b": b}) + m.d.comb += b.v.eq(self.in_b.v) # links in_b to b with m.FSM() as fsm: @@ -255,13 +82,14 @@ class FPADD(FPBase): # gets operand a with m.State("get_a"): - self.get_op(m, self.in_a, a, "get_b") + geta.action(m) # ****** # gets operand b with m.State("get_b"): - self.get_op(m, self.in_b, b, "special_cases") + #self.get_op(m, self.in_b, b, "special_cases") + getb.action(m) # ****** # special cases: NaNs, infs, zeros, denormalised @@ -270,58 +98,120 @@ class FPADD(FPBase): with m.State("special_cases"): + s_nomatch = Signal() + m.d.comb += s_nomatch.eq(a.s != b.s) + + m_match = Signal() + m.d.comb += m_match.eq(a.m == b.m) + # if a is NaN or b is NaN return NaN - with m.If(a.is_nan() | b.is_nan()): + with m.If(a.is_nan | b.is_nan): m.next = "put_z" m.d.sync += z.nan(1) + # XXX WEIRDNESS for FP16 non-canonical NaN handling + # under review + + ## if a is zero and b is NaN return -b + #with m.If(a.is_zero & (a.s==0) & b.is_nan): + # m.next = "put_z" + # m.d.sync += z.create(b.s, b.e, Cat(b.m[3:-2], ~b.m[0])) + + ## if b is zero and a is NaN return -a + #with m.Elif(b.is_zero & (b.s==0) & a.is_nan): + # m.next = "put_z" + # m.d.sync += z.create(a.s, a.e, Cat(a.m[3:-2], ~a.m[0])) + + ## if a is -zero and b is NaN return -b + #with m.Elif(a.is_zero & (a.s==1) & b.is_nan): + # m.next = "put_z" + # m.d.sync += z.create(a.s & b.s, b.e, Cat(b.m[3:-2], 1)) + + ## if b is -zero and a is NaN return -a + #with m.Elif(b.is_zero & (b.s==1) & a.is_nan): + # m.next = "put_z" + # m.d.sync += z.create(a.s & b.s, a.e, Cat(a.m[3:-2], 1)) + # if a is inf return inf (or NaN) - with m.Elif(a.is_inf()): + with m.Elif(a.is_inf): m.next = "put_z" m.d.sync += z.inf(a.s) # if a is inf and signs don't match return NaN - with m.If((b.e == b.P128) & (a.s != b.s)): - m.d.sync += z.nan(b.s) + with m.If(b.exp_128 & s_nomatch): + m.d.sync += z.nan(1) # if b is inf return inf - with m.Elif(b.is_inf()): + with m.Elif(b.is_inf): m.next = "put_z" m.d.sync += z.inf(b.s) # if a is zero and b zero return signed-a/b - with m.Elif(a.is_zero() & b.is_zero()): + with m.Elif(a.is_zero & b.is_zero): m.next = "put_z" - m.d.sync += z.create(a.s & b.s, b.e[0:8], b.m[3:-1]) + m.d.sync += z.create(a.s & b.s, b.e, b.m[3:-1]) # if a is zero return b - with m.Elif(a.is_zero()): + with m.Elif(a.is_zero): m.next = "put_z" - m.d.sync += z.create(b.s, b.e[0:8], b.m[3:-1]) + m.d.sync += z.create(b.s, b.e, b.m[3:-1]) # if b is zero return a - with m.Elif(b.is_zero()): + with m.Elif(b.is_zero): + m.next = "put_z" + m.d.sync += z.create(a.s, a.e, a.m[3:-1]) + + # if a equal to -b return zero (+ve zero) + with m.Elif(s_nomatch & m_match & (a.e == b.e)): m.next = "put_z" - m.d.sync += z.create(a.s, a.e[0:8], a.m[3:-1]) + m.d.sync += z.zero(0) # Denormalised Number checks with m.Else(): - m.next = "align" - self.denormalise(m, a) - self.denormalise(m, b) + m.next = "denormalise" # ****** - # align. NOTE: this does *not* do single-cycle multi-shifting, - # it *STAYS* in the align state until the exponents match + # denormalise. + + with m.State("denormalise"): + # Denormalised Number checks + m.next = "align" + self.denormalise(m, a) + self.denormalise(m, b) + + # ****** + # align. with m.State("align"): - # exponent of a greater than b: increment b exp, shift b mant - with m.If(a.e > b.e): - m.d.sync += b.shift_down() - # exponent of b greater than a: increment a exp, shift a mant - with m.Elif(a.e < b.e): - m.d.sync += a.shift_down() - # exponents equal: move to next stage. - with m.Else(): + if not self.single_cycle: + # NOTE: this does *not* do single-cycle multi-shifting, + # it *STAYS* in the align state until exponents match + + # exponent of a greater than b: shift b down + with m.If(a.e > b.e): + m.d.sync += b.shift_down() + # exponent of b greater than a: shift a down + with m.Elif(a.e < b.e): + m.d.sync += a.shift_down() + # exponents equal: move to next stage. + with m.Else(): + m.next = "add_0" + else: + # This one however (single-cycle) will do the shift + # in one go. + + # XXX TODO: the shifter used here is quite expensive + # having only one would be better + + ediff = Signal((len(a.e), True), reset_less=True) + ediffr = Signal((len(a.e), True), reset_less=True) + m.d.comb += ediff.eq(a.e - b.e) + m.d.comb += ediffr.eq(b.e - a.e) + with m.If(ediff > 0): + m.d.sync += b.shift_down_multi(ediff) + # exponent of b greater than a: shift a down + with m.Elif(ediff < 0): + m.d.sync += a.shift_down_multi(ediffr) + m.next = "add_0" # ****** @@ -335,19 +225,19 @@ class FPADD(FPBase): # same-sign (both negative or both positive) add mantissas with m.If(a.s == b.s): m.d.sync += [ - tot.eq(a.m + b.m), + tot.eq(Cat(a.m, 0) + Cat(b.m, 0)), z.s.eq(a.s) ] # a mantissa greater than b, use a with m.Elif(a.m >= b.m): m.d.sync += [ - tot.eq(a.m - b.m), + tot.eq(Cat(a.m, 0) - Cat(b.m, 0)), z.s.eq(a.s) ] # b mantissa greater than a, use b with m.Else(): m.d.sync += [ - tot.eq(b.m - a.m), + tot.eq(Cat(b.m, 0) - Cat(a.m, 0)), z.s.eq(b.s) ] @@ -358,9 +248,10 @@ class FPADD(FPBase): with m.State("add_1"): m.next = "normalise_1" # tot[27] gets set when the sum overflows. shift result down - with m.If(tot[27]): + with m.If(tot[-1]): m.d.sync += [ - z.m.eq(tot[4:28]), + z.m.eq(tot[4:]), + of.m0.eq(tot[4]), of.guard.eq(tot[3]), of.round_bit.eq(tot[2]), of.sticky.eq(tot[1] | tot[0]), @@ -369,7 +260,8 @@ class FPADD(FPBase): # tot[27] zero case with m.Else(): m.d.sync += [ - z.m.eq(tot[3:27]), + z.m.eq(tot[3:]), + of.m0.eq(tot[3]), of.guard.eq(tot[2]), of.round_bit.eq(tot[1]), of.sticky.eq(tot[0])