From 1fdeca5501de3c4d9c4a98e4a70a8d80aeaf6f78 Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Fri, 23 Aug 2019 00:39:13 +0100 Subject: [PATCH] move part-bytes to AllTerms --- src/ieee754/part_mul_add/multiply.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 7b8782e2..0d65e30f 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -1074,7 +1074,7 @@ class AllTerms(Elaboratable): """Set of terms to be added together """ - def __init__(self, pbwid, n_inputs, output_width, n_parts, register_levels, + def __init__(self, n_inputs, output_width, n_parts, register_levels, partition_points): """Create an ``AddReduce``. @@ -1086,7 +1086,6 @@ class AllTerms(Elaboratable): """ self.epps = partition_points.like() self.register_levels = register_levels - self.pbwid = pbwid self.n_inputs = n_inputs self.n_parts = n_parts self.output_width = output_width @@ -1096,15 +1095,22 @@ class AllTerms(Elaboratable): self.a = Signal(64) self.b = Signal(64) - self.pbs = Signal(pbwid, reset_less=True) self.part_ops = [Signal(2, name=f"part_ops_{i}") for i in range(8)] def elaborate(self, platform): m = Module() - pbs = self.pbs eps = self.epps + # collect part-bytes + pbs = Signal(8, reset_less=True) + tl = [] + for i in range(8): + pb = Signal(name="pb%d" % i, reset_less=True) + m.d.comb += pb.eq(eps.part_byte(i, mfactor=2)) + tl.append(pb) + m.d.comb += pbs.eq(Cat(*tl)) + # local variables signs = [] for i in range(8): @@ -1273,15 +1279,6 @@ class Mul8_16_32_64(Elaboratable): def elaborate(self, platform): m = Module() - # collect part-bytes - pbs = Signal(8, reset_less=True) - tl = [] - for i in range(8): - pb = Signal(name="pb%d" % i, reset_less=True) - m.d.comb += pb.eq(self.part_pts.part_byte(i)) - tl.append(pb) - m.d.comb += pbs.eq(Cat(*tl)) - # create (doubled) PartitionPoints (output is double input width) expanded_part_pts = eps = PartitionPoints() for i, v in self.part_pts.items(): @@ -1291,12 +1288,11 @@ class Mul8_16_32_64(Elaboratable): n_inputs = 64 + 4 n_parts = 8 #len(self.part_pts) - t = AllTerms(8, n_inputs, 128, n_parts, self.register_levels, + t = AllTerms(n_inputs, 128, n_parts, self.register_levels, eps) m.submodules.allterms = t m.d.comb += t.a.eq(self.a) m.d.comb += t.b.eq(self.b) - m.d.comb += t.pbs.eq(pbs) m.d.comb += t.epps.eq(eps) for i in range(8): m.d.comb += t.part_ops[i].eq(self.part_ops[i]) -- 2.30.2