From 89a5b8b6299d71fa1315b9f9b0b3be7af6fe4c5c Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 6 Apr 2022 19:29:38 -0700 Subject: [PATCH] remove uses of BitwiseXorReduce --- src/nmigen_gf/hdl/clmul.py | 25 ++++++++++++------------- 1 file changed, 12 insertions(+), 13 deletions(-) diff --git a/src/nmigen_gf/hdl/clmul.py b/src/nmigen_gf/hdl/clmul.py index 6027d82..e53c9fa 100644 --- a/src/nmigen_gf/hdl/clmul.py +++ b/src/nmigen_gf/hdl/clmul.py @@ -10,10 +10,11 @@ https://bugs.libre-soc.org/show_bug.cgi?id=784 """ from functools import reduce -from operator import xor # XXX import operator then use operator.xor +import operator from nmigen.hdl.ir import Elaboratable from nmigen.hdl.ast import Signal, Cat, Repl, Value from nmigen.hdl.dsl import Module +from nmutil.util import treereduce # XXX class to be removed https://bugs.libre-soc.org/show_bug.cgi?id=784 @@ -31,7 +32,7 @@ class BitwiseXorReduce(Elaboratable): def __init__(self, input_values): self.input_values = tuple(map(Value.cast, input_values)) assert len(self.input_values) > 0, "can't xor-reduce nothing" - self.output = Signal(reduce(xor, self.input_values).shape()) + self.output = Signal(reduce(operator.xor, self.input_values).shape()) def elaborate(self, platform): m = Module() @@ -92,18 +93,16 @@ class CLMulAdd(Elaboratable): self.output = Signal(max((self.factor_width * 2 - 1, *self.term_widths))) - # XXX to create temporary Signals for mask-shifted expression. - # terms ok. - def __reduce_inputs(self): + def elaborate(self, platform): + m = Module() + + part_prods = [] for shift in range(self.factor_width): + part_prod = Signal(self.output.width, name=f"part_prod_{shift}") mask = Repl(self.factor2[shift], self.factor_width) - yield (self.factor1 & mask) << shift - yield from self.terms + m.d.comb += part_prod.eq((self.factor1 & mask) << shift) + part_prods.append(part_prod) - def elaborate(self, platform): - m = Module() - # XXX to be replaced with nmutil.util.tree_reduce() - xor_reduce = BitwiseXorReduce(self.__reduce_inputs()) - m.submodules.xor_reduce = xor_reduce - m.d.comb += self.output.eq(xor_reduce.output) + output = treereduce(part_prods + self.terms, operator.xor) + m.d.comb += self.output.eq(output) return m -- 2.30.2