From 1e72693ec8b5bd42d1c772e927caf9c535bfe3c2 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Fri, 23 Aug 2019 15:34:23 +0100 Subject: [PATCH] add Stage API setup/process to AddReduceInternal --- src/ieee754/part_mul_add/multiply.py | 35 +++++++++++++++++++++------- 1 file changed, 27 insertions(+), 8 deletions(-) diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 92ff7d30..8aecf8ae 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -346,8 +346,9 @@ class FinalAdd(Elaboratable): """ Final stage of add reduce """ - def __init__(self, n_inputs, output_width, n_parts, partition_points, + def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points, partition_step=1): + self.lidx = lidx self.partition_step = partition_step self.output_width = output_width self.n_inputs = n_inputs @@ -367,6 +368,13 @@ class FinalAdd(Elaboratable): return FinalReduceData(self.partition_points, self.output_width, self.n_parts) + def setup(self, m, i): + m.submodules.finaladd = self + m.d.comb += self.i.eq(i) + + def process(self, i): + return self.o + def elaborate(self, platform): """Elaborate this module.""" m = Module() @@ -408,7 +416,7 @@ class AddReduceSingle(Elaboratable): supported, except for by ``Signal.eq``. """ - def __init__(self, n_inputs, output_width, n_parts, partition_points, + def __init__(self, lidx, n_inputs, output_width, n_parts, partition_points, partition_step=1): """Create an ``AddReduce``. @@ -416,6 +424,7 @@ class AddReduceSingle(Elaboratable): :param output_width: bit-width of ``output``. :param partition_points: the input partition points. """ + self.lidx = lidx self.partition_step = partition_step self.n_inputs = n_inputs self.n_parts = n_parts @@ -438,6 +447,13 @@ class AddReduceSingle(Elaboratable): return AddReduceData(self.partition_points, self.n_terms, self.output_width, self.n_parts) + def setup(self, m, i): + setattr(m.submodules, "addreduce_%d" % self.lidx, self) + m.d.comb += self.i.eq(i) + + def process(self, i): + return self.o + @staticmethod def calc_n_inputs(n_inputs, groups): retval = len(groups)*2 @@ -577,7 +593,8 @@ class AddReduceInternal: groups = AddReduceSingle.full_adder_groups(len(inputs)) if len(groups) == 0: break - next_level = AddReduceSingle(ilen, self.output_width, n_parts, + lidx = len(mods) + next_level = AddReduceSingle(lidx, ilen, self.output_width, n_parts, partition_points, self.partition_step) mods.append(next_level) @@ -586,7 +603,8 @@ class AddReduceInternal: ilen = len(inputs) part_ops = next_level.i.part_ops - next_level = FinalAdd(ilen, self.output_width, n_parts, + lidx = len(mods) + next_level = FinalAdd(lidx, ilen, self.output_width, n_parts, partition_points, self.partition_step) mods.append(next_level) @@ -1414,12 +1432,13 @@ class Mul8_16_32_64(Elaboratable): i = at.i for idx in range(len(at.levels)): mcur = at.levels[idx] - setattr(m.submodules, "addreduce_%d" % idx, mcur) + mcur.setup(m, i) + o = mcur.ospec() if idx in self.register_levels: - m.d.sync += mcur.i.eq(i) + m.d.sync += o.eq(mcur.process(i)) else: - m.d.comb += mcur.i.eq(i) - i = mcur.o # for next loop + m.d.comb += o.eq(mcur.process(i)) + i = o # for next loop interm = Intermediates(128, 8, part_pts) interm.setup(m, i) -- 2.30.2