remove uses of BitwiseXorReduce
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 7 Apr 2022 02:29:38 +0000 (19:29 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 7 Apr 2022 02:29:38 +0000 (19:29 -0700)
src/nmigen_gf/hdl/clmul.py

index 6027d82e4de16e17fb11841e61cdd6a2e40f62e8..e53c9fa3fe1c846b5033d22c5fc731aaed7d76fc 100644 (file)
@@ -10,10 +10,11 @@ https://bugs.libre-soc.org/show_bug.cgi?id=784
 """
 
 from functools import reduce
-from operator import xor  # XXX import operator then use operator.xor
+import operator
 from nmigen.hdl.ir import Elaboratable
 from nmigen.hdl.ast import Signal, Cat, Repl, Value
 from nmigen.hdl.dsl import Module
+from nmutil.util import treereduce
 
 
 # XXX class to be removed https://bugs.libre-soc.org/show_bug.cgi?id=784
@@ -31,7 +32,7 @@ class BitwiseXorReduce(Elaboratable):
     def __init__(self, input_values):
         self.input_values = tuple(map(Value.cast, input_values))
         assert len(self.input_values) > 0, "can't xor-reduce nothing"
-        self.output = Signal(reduce(xor, self.input_values).shape())
+        self.output = Signal(reduce(operator.xor, self.input_values).shape())
 
     def elaborate(self, platform):
         m = Module()
@@ -92,18 +93,16 @@ class CLMulAdd(Elaboratable):
         self.output = Signal(max((self.factor_width * 2 - 1,
                                   *self.term_widths)))
 
-    # XXX to create temporary Signals for mask-shifted expression.
-    # terms ok.
-    def __reduce_inputs(self):
+    def elaborate(self, platform):
+        m = Module()
+
+        part_prods = []
         for shift in range(self.factor_width):
+            part_prod = Signal(self.output.width, name=f"part_prod_{shift}")
             mask = Repl(self.factor2[shift], self.factor_width)
-            yield (self.factor1 & mask) << shift
-        yield from self.terms
+            m.d.comb += part_prod.eq((self.factor1 & mask) << shift)
+            part_prods.append(part_prod)
 
-    def elaborate(self, platform):
-        m = Module()
-        # XXX to be replaced with nmutil.util.tree_reduce()
-        xor_reduce = BitwiseXorReduce(self.__reduce_inputs())
-        m.submodules.xor_reduce = xor_reduce
-        m.d.comb += self.output.eq(xor_reduce.output)
+        output = treereduce(part_prods + self.terms, operator.xor)
+        m.d.comb += self.output.eq(output)
         return m