X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fieee754%2Fpart_mul_add%2Fmultiply.py;h=a5c5e7064bfafdbeb37acf679d9f852447454549;hb=be1b8a0b886346d9cc7118d833145e3d9d76060d;hp=64994ad36c9feff26d84cd6b52acd05193456fa9;hpb=a750daabd40ac4fe2054bb88b632ab720d985c4e;p=ieee754fpu.git diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 64994ad3..a5c5e706 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -50,15 +50,17 @@ class PartitionPoints(dict): raise ValueError("point must be a non-negative integer") self[point] = Value.wrap(enabled) - def like(self, name=None, src_loc_at=0): + def like(self, name=None, src_loc_at=0, mul=1): """Create a new ``PartitionPoints`` with ``Signal``s for all values. :param name: the base name for the new ``Signal``s. + :param mul: a multiplication factor on the indices """ if name is None: name = Signal(src_loc_at=1+src_loc_at).name # get variable name retval = PartitionPoints() for point, enabled in self.items(): + point *= mul retval[point] = Signal(enabled.shape(), name=f"{name}_{point}") return retval @@ -762,7 +764,10 @@ class Part(Elaboratable): the extra terms - as separate terms - are then thrown at the AddReduce alongside the multiplication part-results. """ - def __init__(self, width, n_parts, n_levels, pbwid): + def __init__(self, epps, width, n_parts, n_levels, pbwid): + + self.pbwid = pbwid + self.epps = epps # inputs self.a = Signal(64) @@ -773,13 +778,6 @@ class Part(Elaboratable): # outputs self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)] - self.delayed_parts = [ - [Signal(name=f"delayed_part_{delay}_{i}") - for i in range(n_parts)] - for delay in range(n_levels)] - # XXX REALLY WEIRD BUG - have to take a copy of the last delayed_parts - self.dplast = [Signal(name=f"dplast_{i}") - for i in range(n_parts)] self.not_a_term = Signal(width) self.neg_lsb_a_term = Signal(width) @@ -789,24 +787,14 @@ class Part(Elaboratable): def elaborate(self, platform): m = Module() - pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts - # negated-temporary copy of partition bits + pbs, parts = self.pbs, self.parts + epps = self.epps + m.submodules.p = p = Parts(self.pbwid, epps, len(parts)) + m.d.comb += p.epps.eq(epps) + parts = p.parts + npbs = Signal.like(pbs, reset_less=True) - m.d.comb += npbs.eq(~pbs) byte_count = 8 // len(parts) - for i in range(len(parts)): - pbl = [] - pbl.append(npbs[i * byte_count - 1]) - for j in range(i * byte_count, (i + 1) * byte_count - 1): - pbl.append(pbs[j]) - pbl.append(npbs[(i + 1) * byte_count - 1]) - value = Signal(len(pbl), name="value_%di" % i, reset_less=True) - m.d.comb += value.eq(Cat(*pbl)) - m.d.comb += parts[i].eq(~(value).bool()) - m.d.comb += delayed_parts[0][i].eq(parts[i]) - m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i]) - for j in range(len(delayed_parts)-1)] - m.d.comb += self.dplast[i].eq(delayed_parts[-1][i]) not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \ self.not_a_term, self.neg_lsb_a_term, \ @@ -1024,7 +1012,7 @@ class Mul8_16_32_64(Elaboratable): m.d.comb += pbs.eq(Cat(*tl)) # create (doubled) PartitionPoints (output is double input width) - expanded_part_pts = PartitionPoints() + expanded_part_pts = eps = PartitionPoints() for i, v in self.part_pts.items(): ep = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True) expanded_part_pts[i * 2] = ep @@ -1039,10 +1027,10 @@ class Mul8_16_32_64(Elaboratable): m.d.comb += s.part_ops.eq(self.part_ops[i]) n_levels = len(self.register_levels)+1 - m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8) - m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8) - m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8) - m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8) + m.submodules.part_8 = part_8 = Part(eps, 128, 8, n_levels, 8) + m.submodules.part_16 = part_16 = Part(eps, 128, 4, n_levels, 8) + m.submodules.part_32 = part_32 = Part(eps, 128, 2, n_levels, 8) + m.submodules.part_64 = part_64 = Part(eps, 128, 1, n_levels, 8) nat_l, nbt_l, nla_l, nlb_l = [], [], [], [] for mod in [part_8, part_16, part_32, part_64]: m.d.comb += mod.a.eq(self.a) @@ -1090,6 +1078,7 @@ class Mul8_16_32_64(Elaboratable): self.part_ops) out_part_ops = add_reduce.levels[-1].out_part_ops + out_part_pts = add_reduce.levels[-1]._reg_partition_points m.submodules.add_reduce = add_reduce m.d.comb += self._intermediate_output.eq(add_reduce.output) @@ -1117,14 +1106,24 @@ class Mul8_16_32_64(Elaboratable): for i in range(8): m.d.comb += io8.part_ops[i].eq(out_part_ops[i]) + m.submodules.p_8 = p_8 = Parts(8, eps, len(part_8.parts)) + m.submodules.p_16 = p_16 = Parts(8, eps, len(part_16.parts)) + m.submodules.p_32 = p_32 = Parts(8, eps, len(part_32.parts)) + m.submodules.p_64 = p_64 = Parts(8, eps, len(part_64.parts)) + + m.d.comb += p_8.epps.eq(out_part_pts) + m.d.comb += p_16.epps.eq(out_part_pts) + m.d.comb += p_32.epps.eq(out_part_pts) + m.d.comb += p_64.epps.eq(out_part_pts) + # final output m.submodules.finalout = finalout = FinalOut(64) - for i in range(len(part_8.delayed_parts[-1])): - m.d.comb += finalout.d8[i].eq(part_8.dplast[i]) - for i in range(len(part_16.delayed_parts[-1])): - m.d.comb += finalout.d16[i].eq(part_16.dplast[i]) - for i in range(len(part_32.delayed_parts[-1])): - m.d.comb += finalout.d32[i].eq(part_32.dplast[i]) + for i in range(len(part_8.parts)): + m.d.comb += finalout.d8[i].eq(p_8.parts[i]) + for i in range(len(part_16.parts)): + m.d.comb += finalout.d16[i].eq(p_16.parts[i]) + for i in range(len(part_32.parts)): + m.d.comb += finalout.d32[i].eq(p_32.parts[i]) m.d.comb += finalout.i8.eq(io8.output) m.d.comb += finalout.i16.eq(io16.output) m.d.comb += finalout.i32.eq(io32.output)