from nmigen.hdl.ast import Assign
from abc import ABCMeta, abstractmethod
from nmigen.cli import main
-
+from functools import reduce
+from operator import or_
class PartitionPoints(dict):
"""Partition points and corresponding ``Value``s.
return m
+class Part(Elaboratable):
+ def __init__(self, width, n_parts, n_levels, pbwid):
+
+ # inputs
+ self.a = Signal(64)
+ self.b = Signal(64)
+ self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
+ self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
+ self.pbs = Signal(pbwid, reset_less=True)
+
+ # outputs
+ self.parts = [Signal(name=f"part_{i}") for i in range(n_parts)]
+ self.delayed_parts = [
+ [Signal(name=f"delayed_part_8_{delay}_{i}")
+ for i in range(n_parts)]
+ for delay in range(n_levels)]
+
+ self.not_a_term = Signal(width)
+ self.neg_lsb_a_term = Signal(width)
+ self.not_b_term = Signal(width)
+ self.neg_lsb_b_term = Signal(width)
+
+ def elaborate(self, platform):
+ m = Module()
+
+ pbs, parts, delayed_parts = self.pbs, self.parts, self.delayed_parts
+ byte_count = 8 // len(parts)
+ for i in range(len(parts)):
+ pbl = []
+ pbl.append(~pbs[i * byte_count - 1])
+ for j in range(i * byte_count, (i + 1) * byte_count - 1):
+ pbl.append(pbs[j])
+ pbl.append(~pbs[(i + 1) * byte_count - 1])
+ value = Signal(len(pbl), reset_less=True)
+ m.d.comb += value.eq(Cat(*pbl))
+ m.d.comb += parts[i].eq(~(value).bool())
+ m.d.comb += delayed_parts[0][i].eq(parts[i])
+ m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
+ for j in range(len(delayed_parts)-1)]
+
+ not_a_term, neg_lsb_a_term, not_b_term, neg_lsb_b_term = \
+ self.not_a_term, self.neg_lsb_a_term, \
+ self.not_b_term, self.neg_lsb_b_term
+
+ byte_width = 8 // len(parts)
+ bit_width = 8 * byte_width
+ nat, nbt, nla, nlb = [], [], [], []
+ for i in range(len(parts)):
+ be = parts[i] & self.a[(i + 1) * bit_width - 1] \
+ & self._a_signed[i * byte_width]
+ ae = parts[i] & self.b[(i + 1) * bit_width - 1] \
+ & self._b_signed[i * byte_width]
+ a_enabled = Signal(name="a_en_%d" % i, reset_less=True)
+ b_enabled = Signal(name="b_en_%d" % i, reset_less=True)
+ m.d.comb += a_enabled.eq(ae)
+ m.d.comb += b_enabled.eq(be)
+
+ # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
+ # negation operation is split into a bitwise not and a +1.
+ # likewise for 16, 32, and 64-bit values.
+ nat.append(Mux(a_enabled,
+ Cat(Repl(0, bit_width),
+ ~self.a.bit_select(bit_width * i, bit_width)),
+ 0))
+
+ nla.append(Cat(Repl(0, bit_width), a_enabled,
+ Repl(0, bit_width-1)))
+
+ nbt.append(Mux(b_enabled,
+ Cat(Repl(0, bit_width),
+ ~self.b.bit_select(bit_width * i, bit_width)),
+ 0))
+
+ nlb.append(Cat(Repl(0, bit_width), b_enabled,
+ Repl(0, bit_width-1)))
+
+ m.d.comb += [not_a_term.eq(Cat(*nat)),
+ not_b_term.eq(Cat(*nbt)),
+ neg_lsb_a_term.eq(Cat(*nla)),
+ neg_lsb_b_term.eq(Cat(*nlb)),
+ ]
+
+ return m
+
class Mul8_16_32_64(Elaboratable):
"""Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
[Signal(2, name=f"_delayed_part_ops_{delay}_{i}")
for i in range(8)]
for delay in range(1 + len(self.register_levels))]
- self._part_8 = [Signal(name=f"_part_8_{i}") for i in range(8)]
- self._part_16 = [Signal(name=f"_part_16_{i}") for i in range(4)]
- self._part_32 = [Signal(name=f"_part_32_{i}") for i in range(2)]
- self._part_64 = [Signal(name=f"_part_64")]
- self._delayed_part_8 = [
- [Signal(name=f"_delayed_part_8_{delay}_{i}")
- for i in range(8)]
- for delay in range(1 + len(self.register_levels))]
- self._delayed_part_16 = [
- [Signal(name=f"_delayed_part_16_{delay}_{i}")
- for i in range(4)]
- for delay in range(1 + len(self.register_levels))]
- self._delayed_part_32 = [
- [Signal(name=f"_delayed_part_32_{delay}_{i}")
- for i in range(2)]
- for delay in range(1 + len(self.register_levels))]
- self._delayed_part_64 = [
- [Signal(name=f"_delayed_part_64_{delay}")]
- for delay in range(1 + len(self.register_levels))]
self._output_64 = Signal(64)
self._output_32 = Signal(64)
self._output_16 = Signal(64)
self._output_8 = Signal(64)
self._a_signed = [Signal(name=f"_a_signed_{i}") for i in range(8)]
self._b_signed = [Signal(name=f"_b_signed_{i}") for i in range(8)]
- self._not_a_term_8 = Signal(128)
- self._neg_lsb_a_term_8 = Signal(128)
- self._not_b_term_8 = Signal(128)
- self._neg_lsb_b_term_8 = Signal(128)
- self._not_a_term_16 = Signal(128)
- self._neg_lsb_a_term_16 = Signal(128)
- self._not_b_term_16 = Signal(128)
- self._neg_lsb_b_term_16 = Signal(128)
- self._not_a_term_32 = Signal(128)
- self._neg_lsb_a_term_32 = Signal(128)
- self._not_b_term_32 = Signal(128)
- self._neg_lsb_b_term_32 = Signal(128)
- self._not_a_term_64 = Signal(128)
- self._neg_lsb_a_term_64 = Signal(128)
- self._not_b_term_64 = Signal(128)
- self._neg_lsb_b_term_64 = Signal(128)
def _part_byte(self, index):
if index == -1 or index == 7:
.eq(self._delayed_part_ops[j][i])
for j in range(len(self.register_levels))]
- for parts, delayed_parts in [(self._part_64, self._delayed_part_64),
- (self._part_32, self._delayed_part_32),
- (self._part_16, self._delayed_part_16),
- (self._part_8, self._delayed_part_8)]:
- byte_count = 8 // len(parts)
- for i in range(len(parts)):
- pbl = []
- pbl.append(~pbs[i * byte_count - 1])
- for j in range(i * byte_count, (i + 1) * byte_count - 1):
- pbl.append(pbs[j])
- pbl.append(~pbs[(i + 1) * byte_count - 1])
- value = Signal(len(pbl), reset_less=True)
- m.d.comb += value.eq(Cat(*pbl))
- m.d.comb += parts[i].eq(~(value).bool())
- m.d.comb += delayed_parts[0][i].eq(parts[i])
- m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
- for j in range(len(self.register_levels))]
+ n_levels = len(self.register_levels)+1
+ m.submodules.part_8 = part_8 = Part(128, 8, n_levels, 8)
+ m.submodules.part_16 = part_16 = Part(128, 4, n_levels, 8)
+ m.submodules.part_32 = part_32 = Part(128, 2, n_levels, 8)
+ m.submodules.part_64 = part_64 = Part(128, 1, n_levels, 8)
+ nat_l, nbt_l, nla_l, nlb_l = [], [], [], []
+ for mod in [part_8, part_16, part_32, part_64]:
+ m.d.comb += mod.a.eq(self.a)
+ m.d.comb += mod.b.eq(self.b)
+ for i in range(len(self._a_signed)):
+ m.d.comb += mod._a_signed[i].eq(self._a_signed[i])
+ for i in range(len(self._b_signed)):
+ m.d.comb += mod._b_signed[i].eq(self._b_signed[i])
+ m.d.comb += mod.pbs.eq(pbs)
+ nat_l.append(mod.not_a_term)
+ nbt_l.append(mod.not_b_term)
+ nla_l.append(mod.neg_lsb_a_term)
+ nlb_l.append(mod.neg_lsb_b_term)
terms = []
# it's fine to bitwise-or these together since they are never enabled
# at the same time
+ nat_l = reduce(or_, nat_l)
+ nbt_l = reduce(or_, nbt_l)
+ nla_l = reduce(or_, nla_l)
+ nlb_l = reduce(or_, nlb_l)
m.submodules.nat = nat = Term(128, 128)
m.submodules.nla = nla = Term(128, 128)
m.submodules.nbt = nbt = Term(128, 128)
m.submodules.nlb = nlb = Term(128, 128)
- m.d.comb += nat.ti.eq(self._not_a_term_8 | self._not_a_term_16
- | self._not_a_term_32 | self._not_a_term_64)
- m.d.comb += nbt.ti.eq(self._not_b_term_8 | self._not_b_term_16
- | self._not_b_term_32 | self._not_b_term_64)
- m.d.comb += nla.ti.eq(self._neg_lsb_a_term_8 | self._neg_lsb_a_term_16
- | self._neg_lsb_a_term_32 | self._neg_lsb_a_term_64)
- m.d.comb += nlb.ti.eq(self._neg_lsb_b_term_8 | self._neg_lsb_b_term_16
- | self._neg_lsb_b_term_32 | self._neg_lsb_b_term_64)
+ m.d.comb += nat.ti.eq(nat_l)
+ m.d.comb += nbt.ti.eq(nbt_l)
+ m.d.comb += nla.ti.eq(nla_l)
+ m.d.comb += nlb.ti.eq(nlb_l)
terms.append(nat.term)
terms.append(nla.term)
terms.append(nbt.term)
terms.append(nlb.term)
- for not_a_term, \
- neg_lsb_a_term, \
- not_b_term, \
- neg_lsb_b_term, \
- parts in [
- (self._not_a_term_8,
- self._neg_lsb_a_term_8,
- self._not_b_term_8,
- self._neg_lsb_b_term_8,
- self._part_8),
- (self._not_a_term_16,
- self._neg_lsb_a_term_16,
- self._not_b_term_16,
- self._neg_lsb_b_term_16,
- self._part_16),
- (self._not_a_term_32,
- self._neg_lsb_a_term_32,
- self._not_b_term_32,
- self._neg_lsb_b_term_32,
- self._part_32),
- (self._not_a_term_64,
- self._neg_lsb_a_term_64,
- self._not_b_term_64,
- self._neg_lsb_b_term_64,
- self._part_64),
- ]:
- byte_width = 8 // len(parts)
- bit_width = 8 * byte_width
- nat, nbt, nla, nlb = [], [], [], []
- for i in range(len(parts)):
- be = parts[i] & self.a[(i + 1) * bit_width - 1] \
- & self._a_signed[i * byte_width]
- ae = parts[i] & self.b[(i + 1) * bit_width - 1] \
- & self._b_signed[i * byte_width]
- a_enabled = Signal(name="a_en_%d" % i, reset_less=True)
- b_enabled = Signal(name="b_en_%d" % i, reset_less=True)
- m.d.comb += a_enabled.eq(ae)
- m.d.comb += b_enabled.eq(be)
-
- # for 8-bit values: form a * 0xFF00 by using -a * 0x100, the
- # negation operation is split into a bitwise not and a +1.
- # likewise for 16, 32, and 64-bit values.
- nat.append(Mux(a_enabled,
- Cat(Repl(0, bit_width),
- ~self.a.bit_select(bit_width * i, bit_width)),
- 0))
-
- nla.append(Cat(Repl(0, bit_width), a_enabled,
- Repl(0, bit_width-1)))
-
- nbt.append(Mux(b_enabled,
- Cat(Repl(0, bit_width),
- ~self.b.bit_select(bit_width * i, bit_width)),
- 0))
-
- nlb.append(Cat(Repl(0, bit_width), b_enabled,
- Repl(0, bit_width-1)))
-
- m.d.comb += [not_a_term.eq(Cat(*nat)),
- not_b_term.eq(Cat(*nbt)),
- neg_lsb_a_term.eq(Cat(*nla)),
- neg_lsb_b_term.eq(Cat(*nlb)),
- ]
-
expanded_part_pts = PartitionPoints()
for i, v in self.part_pts.items():
signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
for i in range(8):
op = Signal(8, reset_less=True, name="op%d" % i)
m.d.comb += op.eq(
- Mux(self._delayed_part_8[-1][i]
- | self._delayed_part_16[-1][i // 2],
- Mux(self._delayed_part_8[-1][i],
+ Mux(part_8.delayed_parts[-1][i]
+ | part_16.delayed_parts[-1][i // 2],
+ Mux(part_8.delayed_parts[-1][i],
self._output_8.bit_select(i * 8, 8),
self._output_16.bit_select(i * 8, 8)),
- Mux(self._delayed_part_32[-1][i // 4],
+ Mux(part_32.delayed_parts[-1][i // 4],
self._output_32.bit_select(i * 8, 8),
self._output_64.bit_select(i * 8, 8))))
ol.append(op)