From 73f2a4b0606065b9ff213bca1bd7c14acadf2a1d Mon Sep 17 00:00:00 2001 From: Luke Kenneth Casson Leighton Date: Fri, 7 Feb 2020 14:47:07 +0000 Subject: [PATCH] fix syntax errors for test_partsig --- src/ieee754/part/partsig.py | 2 ++ src/ieee754/part/test/test_partsig.py | 4 +++- src/ieee754/part_mux/part_mux.py | 24 ++++++++++++++---------- 3 files changed, 19 insertions(+), 11 deletions(-) diff --git a/src/ieee754/part/partsig.py b/src/ieee754/part/partsig.py index b7e6eabb..9e1c3f72 100644 --- a/src/ieee754/part/partsig.py +++ b/src/ieee754/part/partsig.py @@ -49,6 +49,8 @@ class PartitionedSignal: return "%s%d" % (category, self.modnames[category]) def eq(self, val): + if isinstance(val, PartitionedSignal): + return self.sig.eq(val.sig) return self.sig.eq(val) # unary ops that require partitioning diff --git a/src/ieee754/part/test/test_partsig.py b/src/ieee754/part/test/test_partsig.py index eb12b8a6..c9dfcb52 100644 --- a/src/ieee754/part/test/test_partsig.py +++ b/src/ieee754/part/test/test_partsig.py @@ -26,6 +26,7 @@ def create_simulator(module, traces, test_name): class TestAddMod(Elaboratable): def __init__(self, width, partpoints): + self.partpoints = partpoints self.a = PartitionedSignal(partpoints, width) self.b = PartitionedSignal(partpoints, width) self.add_output = Signal(width) @@ -49,7 +50,8 @@ class TestAddMod(Elaboratable): m.d.comb += self.eq_output.eq(self.a == self.b) m.d.comb += self.ge_output.eq(self.a >= self.b) m.d.comb += self.add_output.eq(self.a + self.b) - m.d.comb += self.mux_out.eq(PMux(m, self.a, self.b, self.mux_sel)) + ppts = self.partpoints + m.d.comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b)) return m diff --git a/src/ieee754/part_mux/part_mux.py b/src/ieee754/part_mux/part_mux.py index 54f34812..b0cd90b4 100644 --- a/src/ieee754/part_mux/part_mux.py +++ b/src/ieee754/part_mux/part_mux.py @@ -15,13 +15,17 @@ See: from nmigen import Signal, Module, Elaboratable, Mux from ieee754.part_mul_add.partpoints import PartitionPoints +from ieee754.part_mul_add.partpoints import make_partition modcount = 0 # global for now -def PMux(m, sel, a, b): +def PMux(m, mask, sel, a, b): + global modcount modcount += 1 - pm = PartitionedMux(a.shape()[0]) - m.d.comb += pm.a.eq(a) - m.d.comb += pm.b.eq(b) + width = a.sig.shape()[0] # get width + part_pts = make_partition(mask, width) # create partition points + pm = PartitionedMux(width, part_pts) + m.d.comb += pm.a.eq(a.sig) + m.d.comb += pm.b.eq(b.sig) m.d.comb += pm.sel.eq(sel) setattr(m.submodules, "pmux%d" % modcount, pm) return pm.output @@ -35,7 +39,7 @@ class PartitionedMux(Elaboratable): consequently the incoming selector (sel) can completely ignore what the *actual* partition bits are. """ - def __init__(self, width): + def __init__(self, width, partition_points): self.width = width self.partition_points = PartitionPoints(partition_points) self.mwidth = len(self.partition_points)+1 @@ -43,8 +47,8 @@ class PartitionedMux(Elaboratable): self.b = Signal(width, reset_less=True) self.sel = Signal(self.mwidth, reset_less=True) self.output = Signal(width, reset_less=True) - assert (self.partition_points.fits_in_width(width), - "partition_points doesn't fit in width") + assert self.partition_points.fits_in_width(width), \ + "partition_points doesn't fit in width" def elaborate(self, platform): m = Module() @@ -56,12 +60,12 @@ class PartitionedMux(Elaboratable): start = 0 for i in range(len(keys)): end = keys[i] - mux = output[start:end] - mux.append(self.a[start:end] == self.b[start:end]) + mux = self.output[start:end] + comb += mux.eq(self.a[start:end] == self.b[start:end]) start = end # for next time round loop return m def ports(self): - return [self.a, self.b, self.sel, self.output] + return [self.a.sig, self.b.sig, self.sel, self.output] -- 2.30.2