X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fadd%2Fnmigen_add_experiment.py;h=edcb17dd62461eba301b44e5239d0ea57c67fca8;hb=e768533532bb2035e9cbc78e2db86affc694e290;hp=cbce60a564e67412d0463d07eeb58bd06e236fd3;hpb=ec47adc80a3c803553eff3a72fb454535512bc68;p=ieee754fpu.git diff --git a/src/add/nmigen_add_experiment.py b/src/add/nmigen_add_experiment.py index cbce60a5..edcb17dd 100644 --- a/src/add/nmigen_add_experiment.py +++ b/src/add/nmigen_add_experiment.py @@ -7,7 +7,7 @@ from nmigen.lib.coding import PriorityEncoder from nmigen.cli import main, verilog from fpbase import FPNumIn, FPNumOut, FPOp, Overflow, FPBase, FPNumBase -from fpbase import MultiShiftRMerge +from fpbase import MultiShiftRMerge, Trigger #from fpbase import FPNumShiftMultiRight class FPState(FPBase): @@ -74,6 +74,69 @@ class FPGetOp(FPState): m.d.sync += self.in_op.ack.eq(1) +class FPGet2OpMod(Trigger): + def __init__(self, width): + Trigger.__init__(self) + self.in_op1 = Signal(width, reset_less=True) + self.in_op2 = Signal(width, reset_less=True) + self.out_op1 = FPNumIn(None, width) + self.out_op2 = FPNumIn(None, width) + + def elaborate(self, platform): + m = Trigger.elaborate(self, platform) + #m.submodules.get_op_in = self.in_op + m.submodules.get_op1_out = self.out_op1 + m.submodules.get_op2_out = self.out_op2 + with m.If(self.trigger): + m.d.comb += [ + self.out_op1.decode(self.in_op1), + self.out_op2.decode(self.in_op2), + ] + return m + + +class FPGet2Op(FPState): + """ gets operands + """ + + def __init__(self, in_state, out_state, in_op1, in_op2, width): + FPState.__init__(self, in_state) + self.out_state = out_state + self.mod = FPGet2OpMod(width) + self.in_op1 = in_op1 + self.in_op2 = in_op2 + self.out_op1 = FPNumIn(None, width) + self.out_op2 = FPNumIn(None, width) + self.in_stb = Signal(reset_less=True) + self.out_ack = Signal(reset_less=True) + self.out_decode = Signal(reset_less=True) + + def setup(self, m, in_op1, in_op2, in_stb): + """ links module to inputs and outputs + """ + m.submodules.get_ops = self.mod + m.d.comb += self.mod.in_op1.eq(in_op1) + m.d.comb += self.mod.in_op2.eq(in_op2) + m.d.comb += self.mod.stb.eq(in_stb) + m.d.comb += self.out_ack.eq(self.mod.ack) + m.d.comb += self.out_decode.eq(self.mod.trigger) + #m.d.comb += self.out_op1.v.eq(self.mod.out_op1.v) + #m.d.comb += self.out_op2.v.eq(self.mod.out_op2.v) + + def action(self, m): + with m.If(self.out_decode): + m.next = self.out_state + m.d.sync += [ + self.mod.ack.eq(0), + #self.out_op1.v.eq(self.mod.out_op1.v), + #self.out_op2.v.eq(self.mod.out_op2.v), + self.out_op1.copy(self.mod.out_op1), + self.out_op2.copy(self.mod.out_op2) + ] + with m.Else(): + m.d.sync += self.mod.ack.eq(1) + + class FPAddSpecialCasesMod: """ special cases: NaNs, infs, zeros, denormalised NOTE: some of these are unique to add. see "Special Operations" @@ -175,14 +238,14 @@ class FPID: def __init__(self, id_wid): self.id_wid = id_wid if self.id_wid: - self.in_mid = Signal(width, reset_less) - self.out_mid = Signal(width, reset_less) + self.in_mid = Signal(id_wid, reset_less=True) + self.out_mid = Signal(id_wid, reset_less=True) else: self.in_mid = None self.out_mid = None def idsync(self, m): - if self.id_wid: + if self.id_wid is not None: m.d.sync += self.out_mid.eq(self.in_mid) @@ -207,7 +270,7 @@ class FPAddSpecialCases(FPState, FPID): m.d.comb += self.mod.in_b.copy(in_b) #m.d.comb += self.out_z.v.eq(self.mod.out_z.v) m.d.comb += self.out_do_z.eq(self.mod.out_do_z) - if self.in_mid: + if self.in_mid is not None: m.d.comb += self.in_mid.eq(in_mid) def action(self, m): @@ -249,22 +312,26 @@ class FPAddDeNormMod(FPState): return m -class FPAddDeNorm(FPState): +class FPAddDeNorm(FPState, FPID): - def __init__(self, width): + def __init__(self, width, id_wid): FPState.__init__(self, "denormalise") + FPID.__init__(self, id_wid) self.mod = FPAddDeNormMod(width) self.out_a = FPNumBase(width) self.out_b = FPNumBase(width) - def setup(self, m, in_a, in_b): + def setup(self, m, in_a, in_b, in_mid): """ links module to inputs and outputs """ m.submodules.denormalise = self.mod m.d.comb += self.mod.in_a.copy(in_a) m.d.comb += self.mod.in_b.copy(in_b) + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) def action(self, m): + self.idsync(m) # Denormalised Number checks m.next = "align" m.d.sync += self.out_a.copy(self.mod.out_a) @@ -313,16 +380,17 @@ class FPAddAlignMultiMod(FPState): return m -class FPAddAlignMulti(FPState): +class FPAddAlignMulti(FPState, FPID): - def __init__(self, width): + def __init__(self, width, id_wid): + FPID.__init__(self, id_wid) FPState.__init__(self, "align") self.mod = FPAddAlignMultiMod(width) self.out_a = FPNumIn(None, width) self.out_b = FPNumIn(None, width) self.exp_eq = Signal(reset_less=True) - def setup(self, m, in_a, in_b): + def setup(self, m, in_a, in_b, in_mid): """ links module to inputs and outputs """ m.submodules.align = self.mod @@ -331,8 +399,11 @@ class FPAddAlignMulti(FPState): #m.d.comb += self.out_a.copy(self.mod.out_a) #m.d.comb += self.out_b.copy(self.mod.out_b) m.d.comb += self.exp_eq.eq(self.mod.exp_eq) + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) def action(self, m): + self.idsync(m) m.d.sync += self.out_a.copy(self.mod.out_a) m.d.sync += self.out_b.copy(self.mod.out_b) with m.If(self.exp_eq): @@ -413,22 +484,26 @@ class FPAddAlignSingleMod: return m -class FPAddAlignSingle(FPState): +class FPAddAlignSingle(FPState, FPID): - def __init__(self, width): + def __init__(self, width, id_wid): FPState.__init__(self, "align") + FPID.__init__(self, id_wid) self.mod = FPAddAlignSingleMod(width) self.out_a = FPNumIn(None, width) self.out_b = FPNumIn(None, width) - def setup(self, m, in_a, in_b): + def setup(self, m, in_a, in_b, in_mid): """ links module to inputs and outputs """ m.submodules.align = self.mod m.d.comb += self.mod.in_a.copy(in_a) m.d.comb += self.mod.in_b.copy(in_b) + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) def action(self, m): + self.idsync(m) # NOTE: could be done as comb m.d.sync += self.out_a.copy(self.mod.out_a) m.d.sync += self.out_b.copy(self.mod.out_b) @@ -483,31 +558,34 @@ class FPAddStage0Mod: return m -class FPAddStage0(FPState): +class FPAddStage0(FPState, FPID): """ First stage of add. covers same-sign (add) and subtract special-casing when mantissas are greater or equal, to give greatest accuracy. """ - def __init__(self, width): + def __init__(self, width, id_wid): FPState.__init__(self, "add_0") + FPID.__init__(self, id_wid) self.mod = FPAddStage0Mod(width) self.out_z = FPNumBase(width, False) self.out_tot = Signal(self.out_z.m_width + 4, reset_less=True) - def setup(self, m, in_a, in_b): + def setup(self, m, in_a, in_b, in_mid): """ links module to inputs and outputs """ m.submodules.add0 = self.mod - m.d.comb += self.mod.in_a.copy(in_a) m.d.comb += self.mod.in_b.copy(in_b) + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) def action(self, m): - m.next = "add_1" + self.idsync(m) # NOTE: these could be done as combinatorial (merge add0+add1) m.d.sync += self.out_z.copy(self.mod.out_z) m.d.sync += self.out_tot.eq(self.mod.out_tot) + m.next = "add_1" class FPAddStage1Mod(FPState): @@ -551,27 +629,32 @@ class FPAddStage1Mod(FPState): return m -class FPAddStage1(FPState): +class FPAddStage1(FPState, FPID): - def __init__(self, width): + def __init__(self, width, id_wid): FPState.__init__(self, "add_1") + FPID.__init__(self, id_wid) self.mod = FPAddStage1Mod(width) self.out_z = FPNumBase(width, False) self.out_of = Overflow() self.norm_stb = Signal() - def setup(self, m, in_tot, in_z): + def setup(self, m, in_tot, in_z, in_mid): """ links module to inputs and outputs """ m.submodules.add1 = self.mod + m.submodules.add1_out_overflow = self.out_of m.d.comb += self.mod.in_z.copy(in_z) m.d.comb += self.mod.in_tot.eq(in_tot) m.d.sync += self.norm_stb.eq(0) # sets to zero when not in add1 state + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) + def action(self, m): - m.submodules.add1_out_overflow = self.out_of + self.idsync(m) m.d.sync += self.out_of.copy(self.mod.out_of) m.d.sync += self.out_z.copy(self.mod.out_z) m.d.sync += self.norm_stb.eq(1) @@ -746,9 +829,10 @@ class FPNorm1ModMulti: return m -class FPNorm1(FPState): +class FPNorm1(FPState, FPID): - def __init__(self, width, single_cycle=True): + def __init__(self, width, id_wid, single_cycle=True): + FPID.__init__(self, id_wid) FPState.__init__(self, "normalise_1") if single_cycle: self.mod = FPNorm1ModSingle(width) @@ -763,7 +847,7 @@ class FPNorm1(FPState): self.out_z = FPNumBase(width) self.out_roundz = Signal(reset_less=True) - def setup(self, m, in_z, in_of, norm_stb): + def setup(self, m, in_z, in_of, norm_stb, in_mid): """ links module to inputs and outputs """ m.submodules.normalise_1 = self.mod @@ -781,8 +865,11 @@ class FPNorm1(FPState): m.d.comb += self.stb.eq(norm_stb) m.d.sync += self.ack.eq(0) # sets to zero when not in normalise_1 state - def action(self, m): + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) + def action(self, m): + self.idsync(m) m.d.comb += self.in_accept.eq((~self.ack) & (self.stb)) m.d.sync += self.temp_of.copy(self.mod.out_of) m.d.sync += self.temp_z.copy(self.out_z) @@ -817,22 +904,26 @@ class FPRoundMod: return m -class FPRound(FPState): +class FPRound(FPState, FPID): - def __init__(self, width): + def __init__(self, width, id_wid): FPState.__init__(self, "round") + FPID.__init__(self, id_wid) self.mod = FPRoundMod(width) self.out_z = FPNumBase(width) - def setup(self, m, in_z, roundz): + def setup(self, m, in_z, roundz, in_mid): """ links module to inputs and outputs """ m.submodules.roundz = self.mod m.d.comb += self.mod.in_z.copy(in_z) m.d.comb += self.mod.in_roundz.eq(roundz) + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) def action(self, m): + self.idsync(m) m.d.sync += self.out_z.copy(self.mod.out_z) m.next = "corrections" @@ -853,20 +944,24 @@ class FPCorrectionsMod: return m -class FPCorrections(FPState): +class FPCorrections(FPState, FPID): - def __init__(self, width): + def __init__(self, width, id_wid): FPState.__init__(self, "corrections") + FPID.__init__(self, id_wid) self.mod = FPCorrectionsMod(width) self.out_z = FPNumBase(width) - def setup(self, m, in_z): + def setup(self, m, in_z, in_mid): """ links module to inputs and outputs """ m.submodules.corrections = self.mod m.d.comb += self.mod.in_z.copy(in_z) + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) def action(self, m): + self.idsync(m) m.d.sync += self.out_z.copy(self.mod.out_z) m.next = "pack" @@ -887,42 +982,134 @@ class FPPackMod: return m -class FPPack(FPState): +class FPPack(FPState, FPID): - def __init__(self, width): + def __init__(self, width, id_wid): FPState.__init__(self, "pack") + FPID.__init__(self, id_wid) self.mod = FPPackMod(width) self.out_z = FPNumOut(width, False) - def setup(self, m, in_z): + def setup(self, m, in_z, in_mid): """ links module to inputs and outputs """ m.submodules.pack = self.mod m.d.comb += self.mod.in_z.copy(in_z) + if self.in_mid is not None: + m.d.comb += self.in_mid.eq(in_mid) def action(self, m): + self.idsync(m) m.d.sync += self.out_z.v.eq(self.mod.out_z.v) m.next = "pack_put_z" class FPPutZ(FPState): - def __init__(self, state, in_z, out_z): + def __init__(self, state, in_z, out_z, in_mid, out_mid): FPState.__init__(self, state) self.in_z = in_z self.out_z = out_z + self.in_mid = in_mid + self.out_mid = out_mid def action(self, m): + if self.in_mid is not None: + m.d.sync += self.out_mid.eq(self.in_mid) m.d.sync += [ self.out_z.v.eq(self.in_z.v) ] with m.If(self.out_z.stb & self.out_z.ack): m.d.sync += self.out_z.stb.eq(0) - m.next = "get_a" + m.next = "get_ops" with m.Else(): m.d.sync += self.out_z.stb.eq(1) +class FPADDBase(FPID): + + def __init__(self, width, id_wid=None, single_cycle=False): + """ IEEE754 FP Add + + * width: bit-width of IEEE754. supported: 16, 32, 64 + * id_wid: an identifier that is sync-connected to the input + * single_cycle: True indicates each stage to complete in 1 clock + """ + FPID.__init__(self, id_wid) + self.width = width + self.single_cycle = single_cycle + + self.in_t = Trigger() + self.in_a = Signal(width) + self.in_b = Signal(width) + self.out_z = FPOp(width) + + self.states = [] + + def add_state(self, state): + self.states.append(state) + return state + + def get_fragment(self, platform=None): + """ creates the HDL code-fragment for FPAdd + """ + m = Module() + m.submodules.out_z = self.out_z + m.submodules.in_t = self.in_t + + get = self.add_state(FPGet2Op("get_ops", "special_cases", + self.in_a, self.in_b, self.width)) + get.setup(m, self.in_a, self.in_b, self.in_t.stb) + m.d.comb += self.in_t.ack.eq(get.mod.ack) + a = get.out_op1 + b = get.out_op2 + + sc = self.add_state(FPAddSpecialCases(self.width, self.id_wid)) + sc.setup(m, a, b, self.in_mid) + + dn = self.add_state(FPAddDeNorm(self.width, self.id_wid)) + dn.setup(m, a, b, sc.in_mid) + + if self.single_cycle: + alm = self.add_state(FPAddAlignSingle(self.width, self.id_wid)) + alm.setup(m, dn.out_a, dn.out_b, dn.in_mid) + else: + alm = self.add_state(FPAddAlignMulti(self.width, self.id_wid)) + alm.setup(m, dn.out_a, dn.out_b, dn.in_mid) + + add0 = self.add_state(FPAddStage0(self.width, self.id_wid)) + add0.setup(m, alm.out_a, alm.out_b, alm.in_mid) + + add1 = self.add_state(FPAddStage1(self.width, self.id_wid)) + add1.setup(m, add0.out_tot, add0.out_z, add0.in_mid) + + n1 = self.add_state(FPNorm1(self.width, self.id_wid)) + n1.setup(m, add1.out_z, add1.out_of, add1.norm_stb, add0.in_mid) + + rn = self.add_state(FPRound(self.width, self.id_wid)) + rn.setup(m, n1.out_z, n1.out_roundz, n1.in_mid) + + cor = self.add_state(FPCorrections(self.width, self.id_wid)) + cor.setup(m, rn.out_z, rn.in_mid) + + pa = self.add_state(FPPack(self.width, self.id_wid)) + pa.setup(m, cor.out_z, rn.in_mid) + + ppz = self.add_state(FPPutZ("pack_put_z", pa.out_z, self.out_z, + pa.in_mid, self.out_mid)) + + pz = self.add_state(FPPutZ("put_z", sc.out_z, self.out_z, + pa.in_mid, self.out_mid)) + + with m.FSM() as fsm: + + for state in self.states: + with m.State(state.state_from): + state.action(m) + + return m + + class FPADD(FPID): def __init__(self, width, id_wid=None, single_cycle=False): @@ -964,40 +1151,12 @@ class FPADD(FPID): getb.setup(m, self.in_b) b = getb.out_op - sc = self.add_state(FPAddSpecialCases(self.width, self.id_wid)) - sc.setup(m, a, b, self.in_mid) - - dn = self.add_state(FPAddDeNorm(self.width)) - dn.setup(m, a, b) - - if self.single_cycle: - alm = self.add_state(FPAddAlignSingle(self.width)) - alm.setup(m, dn.out_a, dn.out_b) - else: - alm = self.add_state(FPAddAlignMulti(self.width)) - alm.setup(m, dn.out_a, dn.out_b) - - add0 = self.add_state(FPAddStage0(self.width)) - add0.setup(m, alm.out_a, alm.out_b) - - add1 = self.add_state(FPAddStage1(self.width)) - add1.setup(m, add0.out_tot, add0.out_z) - - n1 = self.add_state(FPNorm1(self.width)) - n1.setup(m, add1.out_z, add1.out_of, add1.norm_stb) - - rn = self.add_state(FPRound(self.width)) - rn.setup(m, n1.out_z, n1.out_roundz) - - cor = self.add_state(FPCorrections(self.width)) - cor.setup(m, rn.out_z) - - pa = self.add_state(FPPack(self.width)) - pa.setup(m, cor.out_z) - - ppz = self.add_state(FPPutZ("pack_put_z", pa.out_z, self.out_z)) + ab = FPADDBase() + #pa = self.add_state(FPPack(self.width, self.id_wid)) + #pa.setup(m, cor.out_z, rn.in_mid) - pz = self.add_state(FPPutZ("put_z", sc.out_z, self.out_z)) + pz = self.add_state(FPPutZ("put_z", sc.out_z, self.out_z, + pa.in_mid, self.out_mid)) with m.FSM() as fsm: @@ -1009,8 +1168,11 @@ class FPADD(FPID): if __name__ == "__main__": - alu = FPADD(width=32, single_cycle=True) - main(alu, ports=alu.in_a.ports() + alu.in_b.ports() + alu.out_z.ports()) + alu = FPADDBase(width=32, id_wid=5, single_cycle=True) + main(alu, ports=[alu.in_a, alu.in_b] + \ + alu.in_t.ports() + \ + alu.out_z.ports() + \ + [alu.in_mid, alu.out_mid]) # works... but don't use, just do "python fname.py convert -t v"