class TestAddMod(Elaboratable):
def __init__(self, width, partpoints):
+ self.partpoints = partpoints
self.a = PartitionedSignal(partpoints, width)
self.b = PartitionedSignal(partpoints, width)
self.add_output = Signal(width)
m.d.comb += self.eq_output.eq(self.a == self.b)
m.d.comb += self.ge_output.eq(self.a >= self.b)
m.d.comb += self.add_output.eq(self.a + self.b)
- m.d.comb += self.mux_out.eq(PMux(m, self.a, self.b, self.mux_sel))
+ ppts = self.partpoints
+ m.d.comb += self.mux_out.eq(PMux(m, ppts, self.mux_sel, self.a, self.b))
return m
from nmigen import Signal, Module, Elaboratable, Mux
from ieee754.part_mul_add.partpoints import PartitionPoints
+from ieee754.part_mul_add.partpoints import make_partition
modcount = 0 # global for now
-def PMux(m, sel, a, b):
+def PMux(m, mask, sel, a, b):
+ global modcount
modcount += 1
- pm = PartitionedMux(a.shape()[0])
- m.d.comb += pm.a.eq(a)
- m.d.comb += pm.b.eq(b)
+ width = a.sig.shape()[0] # get width
+ part_pts = make_partition(mask, width) # create partition points
+ pm = PartitionedMux(width, part_pts)
+ m.d.comb += pm.a.eq(a.sig)
+ m.d.comb += pm.b.eq(b.sig)
m.d.comb += pm.sel.eq(sel)
setattr(m.submodules, "pmux%d" % modcount, pm)
return pm.output
consequently the incoming selector (sel) can completely
ignore what the *actual* partition bits are.
"""
- def __init__(self, width):
+ def __init__(self, width, partition_points):
self.width = width
self.partition_points = PartitionPoints(partition_points)
self.mwidth = len(self.partition_points)+1
self.b = Signal(width, reset_less=True)
self.sel = Signal(self.mwidth, reset_less=True)
self.output = Signal(width, reset_less=True)
- assert (self.partition_points.fits_in_width(width),
- "partition_points doesn't fit in width")
+ assert self.partition_points.fits_in_width(width), \
+ "partition_points doesn't fit in width"
def elaborate(self, platform):
m = Module()
start = 0
for i in range(len(keys)):
end = keys[i]
- mux = output[start:end]
- mux.append(self.a[start:end] == self.b[start:end])
+ mux = self.output[start:end]
+ comb += mux.eq(self.a[start:end] == self.b[start:end])
start = end # for next time round loop
return m
def ports(self):
- return [self.a, self.b, self.sel, self.output]
+ return [self.a.sig, self.b.sig, self.sel, self.output]