# combine above using Cat
m.d.comb += Cat(*ea).eq(Cat(*al))
m.d.comb += Cat(*eb).eq(Cat(*bl))
- m.d.comb += Cat(*eo).eq(Cat(*ol))
+ m.d.comb += Cat(*ol).eq(Cat(*eo))
# use only one addition to take advantage of look-ahead carry and
# special hardware on FPGAs
m.d.comb += self._expanded_output.eq(
OP_MUL_UNSIGNED_HIGH = 3
+def get_term(value, shift=0, enabled=None):
+    """Return ``value`` optionally gated by ``enabled`` and shifted left.
+
+    :param value: the expression to turn into a term.
+    :param shift: number of zero bits to place below ``value``; must be
+        zero or positive.
+    :param enabled: when not ``None``, used as the ``Mux`` select: the
+        term is ``value`` when it is set and the constant 0 otherwise.
+    :returns: the resulting expression; ``value`` itself is returned
+        unchanged when neither gating nor shifting is requested.
+    :raises ValueError: if ``shift`` is negative.
+    """
+    # Raise explicitly rather than assert: asserts are stripped under
+    # ``python -O``, which would let a negative shift silently produce an
+    # unshifted term.
+    if shift < 0:
+        raise ValueError("shift must be non-negative, got %r" % (shift,))
+    if enabled is not None:
+        value = Mux(enabled, value, 0)
+    if shift > 0:
+        # prepend ``shift`` single-bit zeros in front of the value
+        value = Cat(Repl(C(0, 1), shift), value)
+    return value
+
+
+def g_add_term(m, terms, value, shift=0, enabled=None):
+    """Create a 128-bit term signal, record it, and drive it combinatorially.
+
+    A fresh reset-less 128-bit ``Signal`` is appended to ``terms`` and
+    assigned, in the combinatorial domain of ``m``, the expression built
+    by ``get_term(value, shift, enabled)``.
+
+    :param m: the module whose ``comb`` domain receives the assignment.
+    :param terms: list that collects the created term signals (mutated).
+    :param value: expression forwarded to ``get_term``.
+    :param shift: shift amount forwarded to ``get_term``.
+    :param enabled: optional enable forwarded to ``get_term``.
+    """
+    new_term = Signal(128, reset_less=True)
+    terms.append(new_term)
+    gated_shifted = get_term(value, shift, enabled)
+    m.d.comb += new_term.eq(gated_shifted)
+
+
+class Term(Elaboratable):
+    """One partial product, gated and shifted into its output position.
+
+    Multiplies the ``width``-bit inputs ``a`` and ``b`` and drives ``term``
+    with the product shifted left by ``width * (a_index + b_index)`` bits.
+    The term is forced to zero when any bit of ``pb_en`` in the index range
+    ``[min(a_index, b_index), max(a_index, b_index))`` is set.
+
+    :param width: bit width of each of the ``a`` and ``b`` inputs.
+    :param twidth: bit width of the ``term`` output.
+    :param pbwid: bit width of the ``pb_en`` input.
+    :param a_index: index of the ``a`` chunk this term multiplies.
+    :param b_index: index of the ``b`` chunk this term multiplies.
+    """
+
+    def __init__(self, width, twidth, pbwid, a_index, b_index):
+        self.a_index = a_index
+        self.b_index = b_index
+        self.width = width
+        # inputs: the two operand chunks
+        self.a = Signal(width, reset_less=True)
+        self.b = Signal(width, reset_less=True)
+        # output: the gated, shifted product
+        self.term = Signal(twidth, reset_less=True)
+        # input: enable bits; presumably partition-boundary flags between
+        # chunks -- confirm against the caller that drives it
+        self.pb_en = Signal(pbwid, reset_less=True)
+
+    def elaborate(self, platform):
+        """Build the multiply, the enable logic and the shifted output."""
+        m = Module()
+
+        # full-width product of the two operand chunks
+        product = Signal(self.width*2, reset_less=True)
+        m.d.comb += product.eq(self.a * self.b)
+
+        # pb_en bits lying between the two chunk indices; any one of them
+        # being set disables this term
+        min_index = min(self.a_index, self.b_index)
+        max_index = max(self.a_index, self.b_index)
+        tl = [self.pb_en[i] for i in range(min_index, max_index)]
+
+        if tl:
+            name = "te_%d_%d" % (self.a_index, self.b_index)
+            term_enabled = Signal(name=name, reset_less=True)
+            m.d.comb += term_enabled.eq(~(Cat(*tl).bool()))
+        else:
+            # a_index == b_index: nothing can disable the term, so skip
+            # the gating Mux entirely
+            term_enabled = None
+
+        # position the product by chunk index; uses the configured chunk
+        # width rather than a hard-coded 8 so other widths shift correctly
+        shift = self.width * (self.a_index + self.b_index)
+        m.d.comb += self.term.eq(get_term(product, shift, term_enabled))
+
+        return m
+
+
class Mul8_16_32_64(Elaboratable):
"""Signed/Unsigned 8/16/32/64-bit partitioned integer multiplier.
def elaborate(self, platform):
m = Module()
+ # collect part-bytes
+ pbs = Signal(8, reset_less=True)
+ tl = []
+ for i in range(8):
+ pb = Signal(name="pb%d" % i, reset_less=True)
+ m.d.comb += pb.eq(self._part_byte(i))
+ tl.append(pb)
+ m.d.comb += pbs.eq(Cat(*tl))
+
for i in range(len(self.part_ops)):
m.d.comb += self._delayed_part_ops[0][i].eq(self.part_ops[i])
m.d.sync += [self._delayed_part_ops[j + 1][i]
(self._part_8, self._delayed_part_8)]:
byte_count = 8 // len(parts)
for i in range(len(parts)):
- pb = self._part_byte(i * byte_count - 1)
+ pb = pbs[i * byte_count - 1]
value = add_intermediate_value(pb)
for j in range(i * byte_count, (i + 1) * byte_count - 1):
- pb = add_intermediate_value(~self._part_byte(j))
+ pb = add_intermediate_value(~pbs[j])
value = add_intermediate_value(value & pb)
- pb = self._part_byte((i + 1) * byte_count - 1)
+ pb = pbs[(i + 1) * byte_count - 1]
value = add_intermediate_value(value & pb)
m.d.comb += parts[i].eq(value)
m.d.comb += delayed_parts[0][i].eq(parts[i])
m.d.sync += [delayed_parts[j + 1][i].eq(delayed_parts[j][i])
for j in range(len(self.register_levels))]
- products = [[
- Signal(16, name=f"products_{i}_{j}")
- for j in range(8)]
- for i in range(8)]
+ terms = []
for a_index in range(8):
for b_index in range(8):
- a = self.a.part(a_index * 8, 8)
- b = self.b.part(b_index * 8, 8)
- m.d.comb += products[a_index][b_index].eq(a * b)
+ t = Term(8, 128, 8, a_index, b_index)
+ setattr(m.submodules, "term_%d_%d" % (a_index, b_index), t)
- terms = []
+ m.d.comb += t.a.eq(self.a.bit_select(a_index * 8, 8))
+ m.d.comb += t.b.eq(self.b.bit_select(b_index * 8, 8))
+ m.d.comb += t.pb_en.eq(pbs)
- def add_term(value, shift=0, enabled=None):
- term = Signal(128)
- terms.append(term)
- if enabled is not None:
- value = Mux(enabled, value, 0)
- if shift > 0:
- value = Cat(Repl(C(0, 1), shift), value)
- else:
- assert shift == 0
- m.d.comb += term.eq(value)
+ terms.append(t.term)
- for a_index in range(8):
- for b_index in range(8):
- term_enabled: Value = C(True, 1)
- min_index = min(a_index, b_index)
- max_index = max(a_index, b_index)
- for i in range(min_index, max_index):
- term_enabled &= ~self._part_byte(i)
- add_term(products[a_index][b_index],
- 8 * (a_index + b_index),
- term_enabled)
+ def add_term(value, shift=0, enabled=None):
+ g_add_term(m, terms, value, shift, enabled)
for i in range(8):
a_signed = self.part_ops[i] != OP_MUL_UNSIGNED_HIGH
b_signed = (self.part_ops[i] == OP_MUL_LOW) \
- | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
+ | (self.part_ops[i] == OP_MUL_SIGNED_HIGH)
m.d.comb += self._a_signed[i].eq(a_signed)
m.d.comb += self._b_signed[i].eq(b_signed)
byte_width = 8 // len(parts)
bit_width = 8 * byte_width
for i in range(len(parts)):
- ae = parts[i] & self.a[(i + 1) * bit_width - 1] \
+ be = parts[i] & self.a[(i + 1) * bit_width - 1] \
& self._a_signed[i * byte_width]
- be = parts[i] & self.b[(i + 1) * bit_width - 1] \
+ ae = parts[i] & self.b[(i + 1) * bit_width - 1] \
& self._b_signed[i * byte_width]
- a_enabled = Signal(name="a_enabled_%d" % i, reset_less=True)
- b_enabled = Signal(name="b_enabled_%d" % i, reset_less=True)
+ a_enabled = Signal(name="a_en_%d" % i, reset_less=True)
+ b_enabled = Signal(name="b_en_%d" % i, reset_less=True)
m.d.comb += a_enabled.eq(ae)
m.d.comb += b_enabled.eq(be)
# negation operation is split into a bitwise not and a +1.
# likewise for 16, 32, and 64-bit values.
m.d.comb += [
- not_a_term.part(bit_width * 2 * i, bit_width * 2)
+ not_a_term.bit_select(bit_width * 2 * i, bit_width * 2)
.eq(Mux(a_enabled,
Cat(Repl(0, bit_width),
- ~self.a.part(bit_width * i, bit_width)),
+ ~self.a.bit_select(bit_width * i, bit_width)),
0)),
- neg_lsb_a_term.part(bit_width * 2 * i, bit_width * 2)
+ neg_lsb_a_term.bit_select(bit_width * 2 * i, bit_width * 2)
.eq(Cat(Repl(0, bit_width), a_enabled)),
- not_b_term.part(bit_width * 2 * i, bit_width * 2)
+ not_b_term.bit_select(bit_width * 2 * i, bit_width * 2)
.eq(Mux(b_enabled,
Cat(Repl(0, bit_width),
- ~self.b.part(bit_width * i, bit_width)),
+ ~self.b.bit_select(bit_width * i, bit_width)),
0)),
- neg_lsb_b_term.part(bit_width * 2 * i, bit_width * 2)
+ neg_lsb_b_term.bit_select(bit_width * 2 * i, bit_width * 2)
.eq(Cat(Repl(0, bit_width), b_enabled))]
expanded_part_pts = PartitionPoints()
for i, v in self.part_pts.items():
- signal = Signal(name=f"expanded_part_pts_{i*2}")
+ signal = Signal(name=f"expanded_part_pts_{i*2}", reset_less=True)
expanded_part_pts[i * 2] = signal
m.d.comb += signal.eq(v)
m.d.comb += self._intermediate_output.eq(add_reduce.output)
m.d.comb += self._output_64.eq(
Mux(self._delayed_part_ops[-1][0] == OP_MUL_LOW,
- self._intermediate_output.part(0, 64),
- self._intermediate_output.part(64, 64)))
+ self._intermediate_output.bit_select(0, 64),
+ self._intermediate_output.bit_select(64, 64)))
+
+ # create _output_32
+ ol = []
for i in range(2):
- m.d.comb += self._output_32.part(i * 32, 32).eq(
+ op = Signal(32, reset_less=True, name="op32_%d" % i)
+ m.d.comb += op.eq(
Mux(self._delayed_part_ops[-1][4 * i] == OP_MUL_LOW,
- self._intermediate_output.part(i * 64, 32),
- self._intermediate_output.part(i * 64 + 32, 32)))
+ self._intermediate_output.bit_select(i * 64, 32),
+ self._intermediate_output.bit_select(i * 64 + 32, 32)))
+ ol.append(op)
+ m.d.comb += self._output_32.eq(Cat(*ol))
+
+ # create _output_16
+ ol = []
for i in range(4):
- m.d.comb += self._output_16.part(i * 16, 16).eq(
+ op = Signal(16, reset_less=True, name="op16_%d" % i)
+ m.d.comb += op.eq(
Mux(self._delayed_part_ops[-1][2 * i] == OP_MUL_LOW,
- self._intermediate_output.part(i * 32, 16),
- self._intermediate_output.part(i * 32 + 16, 16)))
+ self._intermediate_output.bit_select(i * 32, 16),
+ self._intermediate_output.bit_select(i * 32 + 16, 16)))
+ ol.append(op)
+ m.d.comb += self._output_16.eq(Cat(*ol))
+
+ # create _output_8
+ ol = []
for i in range(8):
- m.d.comb += self._output_8.part(i * 8, 8).eq(
+ op = Signal(8, reset_less=True, name="op8_%d" % i)
+ m.d.comb += op.eq(
Mux(self._delayed_part_ops[-1][i] == OP_MUL_LOW,
- self._intermediate_output.part(i * 16, 8),
- self._intermediate_output.part(i * 16 + 8, 8)))
+ self._intermediate_output.bit_select(i * 16, 8),
+ self._intermediate_output.bit_select(i * 16 + 8, 8)))
+ ol.append(op)
+ m.d.comb += self._output_8.eq(Cat(*ol))
+
+ # final output
+ ol = []
for i in range(8):
- m.d.comb += self.output.part(i * 8, 8).eq(
+ op = Signal(8, reset_less=True, name="op%d" % i)
+ m.d.comb += op.eq(
Mux(self._delayed_part_8[-1][i]
| self._delayed_part_16[-1][i // 2],
Mux(self._delayed_part_8[-1][i],
- self._output_8.part(i * 8, 8),
- self._output_16.part(i * 8, 8)),
+ self._output_8.bit_select(i * 8, 8),
+ self._output_16.bit_select(i * 8, 8)),
Mux(self._delayed_part_32[-1][i // 4],
- self._output_32.part(i * 8, 8),
- self._output_64.part(i * 8, 8))))
+ self._output_32.bit_select(i * 8, 8),
+ self._output_64.bit_select(i * 8, 8))))
+ ol.append(op)
+ m.d.comb += self.output.eq(Cat(*ol))
return m