From 81452677a105ac366e6ca8f2021f8b3d6feb353e Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Thu, 22 Aug 2019 08:04:06 +0100 Subject: [PATCH] move part modules into FinalOut --- src/ieee754/part_mul_add/multiply.py | 96 +++++++++++++++------------- 1 file changed, 51 insertions(+), 45 deletions(-) diff --git a/src/ieee754/part_mul_add/multiply.py b/src/ieee754/part_mul_add/multiply.py index 044b92c6..09922caa 100644 --- a/src/ieee754/part_mul_add/multiply.py +++ b/src/ieee754/part_mul_add/multiply.py @@ -939,22 +939,51 @@ class FinalOut(Elaboratable): that some partitions requested 8-bit computation whilst others requested 16 or 32 bit. """ - def __init__(self, out_wid): - # inputs - self.d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)] - self.d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)] - self.d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)] - - self.i8 = Signal(out_wid, reset_less=True) - self.i16 = Signal(out_wid, reset_less=True) - self.i32 = Signal(out_wid, reset_less=True) - self.i64 = Signal(out_wid, reset_less=True) - + def __init__(self, output_width, n_parts, partition_points): + self.expanded_part_points = partition_points + self.i = IntermediateData(partition_points, output_width, n_parts) + self.out_wid = output_width//2 # output - self.out = Signal(out_wid, reset_less=True) + self.out = Signal(self.out_wid, reset_less=True) + self.intermediate_output = Signal(output_width, reset_less=True) def elaborate(self, platform): m = Module() + + eps = self.expanded_part_points + m.submodules.p_8 = p_8 = Parts(8, eps, 8) + m.submodules.p_16 = p_16 = Parts(8, eps, 4) + m.submodules.p_32 = p_32 = Parts(8, eps, 2) + m.submodules.p_64 = p_64 = Parts(8, eps, 1) + + out_part_pts = self.i.reg_partition_points + + # temporaries + d8 = [Signal(name=f"d8_{i}", reset_less=True) for i in range(8)] + d16 = [Signal(name=f"d16_{i}", reset_less=True) for i in range(4)] + d32 = [Signal(name=f"d32_{i}", reset_less=True) for i in range(2)] + + i8 = Signal(self.out_wid, reset_less=True) + i16 = Signal(self.out_wid, reset_less=True) + i32 = Signal(self.out_wid, reset_less=True) + i64 = Signal(self.out_wid, reset_less=True) + + m.d.comb += p_8.epps.eq(out_part_pts) + m.d.comb += p_16.epps.eq(out_part_pts) + m.d.comb += p_32.epps.eq(out_part_pts) + m.d.comb += p_64.epps.eq(out_part_pts) + + for i in range(len(p_8.parts)): + m.d.comb += d8[i].eq(p_8.parts[i]) + for i in range(len(p_16.parts)): + m.d.comb += d16[i].eq(p_16.parts[i]) + for i in range(len(p_32.parts)): + m.d.comb += d32[i].eq(p_32.parts[i]) + m.d.comb += i8.eq(self.i.outputs[0]) + m.d.comb += i16.eq(self.i.outputs[1]) + m.d.comb += i32.eq(self.i.outputs[2]) + m.d.comb += i64.eq(self.i.outputs[3]) + ol = [] for i in range(8): # select one of the outputs: d8 selects i8, d16 selects i16 @@ -964,13 +993,12 @@ class FinalOut(Elaboratable): # if neither d8 nor d16 are set, d32 selects either i32 or i64. op = Signal(8, reset_less=True, name="op_%d" % i) m.d.comb += op.eq( - Mux(self.d8[i] | self.d16[i // 2], - Mux(self.d8[i], self.i8.part(i * 8, 8), - self.i16.part(i * 8, 8)), - Mux(self.d32[i // 4], self.i32.part(i * 8, 8), - self.i64.part(i * 8, 8)))) + Mux(d8[i] | d16[i // 2], + Mux(d8[i], i8.part(i * 8, 8), i16.part(i * 8, 8)), + Mux(d32[i // 4], i32.part(i * 8, 8), i64.part(i * 8, 8)))) ol.append(op) m.d.comb += self.out.eq(Cat(*ol)) + m.d.comb += self.intermediate_output.eq(self.i.intermediate_output) return m @@ -1016,6 +1044,7 @@ class Signs(Elaboratable): return m + class IntermediateData: def __init__(self, ppoints, output_width, n_parts): @@ -1031,7 +1060,7 @@ class IntermediateData: part_ops): return [self.reg_partition_points.eq(reg_partition_points)] + \ [self.intermediate_output.eq(intermediate_output)] + \ - [self.outputs.eq(outputs) + [self.outputs[i].eq(outputs[i]) for i in range(4)] + \ [self.part_ops[i].eq(part_ops[i]) for i in range(len(self.part_ops))] @@ -1045,8 +1074,7 @@ class Intermediates(Elaboratable): """ Intermediate output modules """ - def __init__(self, output_width, n_parts, register_levels, - partition_points): + def __init__(self, output_width, n_parts, partition_points): self.i = FinalReduceData(partition_points, output_width, n_parts) self.o = IntermediateData(partition_points, output_width, n_parts) @@ -1226,35 +1254,13 @@ class Mul8_16_32_64(Elaboratable): m.submodules.add_reduce = add_reduce m.d.comb += self.intermediate_output.eq(add_reduce.o.output) - interm = Intermediates(128, 8, self.register_levels, - expanded_part_pts) + interm = Intermediates(128, 8, expanded_part_pts) m.submodules.intermediates = interm m.d.comb += interm.i.eq(add_reduce.o) - m.submodules.p_8 = p_8 = Parts(8, eps, 8) - m.submodules.p_16 = p_16 = Parts(8, eps, 4) - m.submodules.p_32 = p_32 = Parts(8, eps, 2) - m.submodules.p_64 = p_64 = Parts(8, eps, 1) - - out_part_pts = interm.o.reg_partition_points - - m.d.comb += p_8.epps.eq(out_part_pts) - m.d.comb += p_16.epps.eq(out_part_pts) - m.d.comb += p_32.epps.eq(out_part_pts) - m.d.comb += p_64.epps.eq(out_part_pts) - # final output - m.submodules.finalout = finalout = FinalOut(64) - for i in range(len(p_8.parts)): - m.d.comb += finalout.d8[i].eq(p_8.parts[i]) - for i in range(len(p_16.parts)): - m.d.comb += finalout.d16[i].eq(p_16.parts[i]) - for i in range(len(p_32.parts)): - m.d.comb += finalout.d32[i].eq(p_32.parts[i]) - m.d.comb += finalout.i8.eq(interm.o.outputs[0]) - m.d.comb += finalout.i16.eq(interm.o.outputs[1]) - m.d.comb += finalout.i32.eq(interm.o.outputs[2]) - m.d.comb += finalout.i64.eq(interm.o.outputs[3]) + m.submodules.finalout = finalout = FinalOut(128, 8, expanded_part_pts) + m.d.comb += finalout.i.eq(interm.o) m.d.comb += self.output.eq(finalout.out) return m -- 2.30.2