remove BitwiseXorReduce
authorJacob Lifshay <programmerjake@gmail.com>
Thu, 7 Apr 2022 02:33:10 +0000 (19:33 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Thu, 7 Apr 2022 02:33:10 +0000 (19:33 -0700)
src/nmigen_gf/hdl/clmul.py
src/nmigen_gf/hdl/test/test_clmul.py

index e53c9fa3fe1c846b5033d22c5fc731aaed7d76fc..8f79c5e3aee4b961570f8e9967a96ea8f529e5e2 100644 (file)
@@ -9,46 +9,13 @@
 https://bugs.libre-soc.org/show_bug.cgi?id=784
 """
 
-from functools import reduce
 import operator
 from nmigen.hdl.ir import Elaboratable
-from nmigen.hdl.ast import Signal, Cat, Repl, Value
+from nmigen.hdl.ast import Signal, Repl
 from nmigen.hdl.dsl import Module
 from nmutil.util import treereduce
 
 
-# XXX class to be removed https://bugs.libre-soc.org/show_bug.cgi?id=784
-# functionality covered already entirely by nmutil.util.tree_reduce
-class BitwiseXorReduce(Elaboratable):
-    """Bitwise Xor lots of stuff together by using tree-reduction on each bit.
-
-    Properties:
-    input_values: tuple[Value, ...]
-        input nmigen Values
-    output: Signal
-        output, set to `input_values[0] ^ input_values[1] ^ input_values[2]...`
-    """
-
-    def __init__(self, input_values):
-        self.input_values = tuple(map(Value.cast, input_values))
-        assert len(self.input_values) > 0, "can't xor-reduce nothing"
-        self.output = Signal(reduce(operator.xor, self.input_values).shape())
-
-    def elaborate(self, platform):
-        m = Module()
-        # collect inputs into full-width Signals
-        inputs = []
-        for i, inp_v in enumerate(self.input_values):
-            inp = self.output.like(self.output, name=f"input_{i}")
-            # sign/zero-extend inp_v to full-width
-            m.d.comb += inp.eq(inp_v)
-            inputs.append(inp)
-        for bit in range(self.output.width):
-            # construct a tree-reduction for bit index `bit` of all inputs
-            m.d.comb += self.output[bit].eq(Cat(i[bit] for i in inputs).xor())
-        return m
-
-
 class CLMulAdd(Elaboratable):
     """Carry-less multiply-add. (optional add)
 
index dc0e6fd149758416762ff668a25b3e9067562c97..42c29cf3c0be6c39fb7116204cdbe846675c1f50 100644 (file)
@@ -5,77 +5,17 @@
 # of Horizon 2020 EU Programme 957073.
 
 from functools import reduce
-from operator import xor
+import operator
 import unittest
-from nmigen.hdl.ast import (AnyConst, Assert, Signal, Const, unsigned, signed,
-                            Mux)
+from nmigen.hdl.ast import AnyConst, Assert, Signal, Const, unsigned, Mux
 from nmigen.hdl.dsl import Module
 from nmutil.formaltest import FHDLTestCase
 from nmigen_gf.reference.clmul import clmul
-from nmigen_gf.hdl.clmul import BitwiseXorReduce, CLMulAdd
+from nmigen_gf.hdl.clmul import CLMulAdd
 from nmigen.sim import Delay
 from nmutil.sim_util import do_sim, hash_256
 
 
-class TestBitwiseXorReduce(FHDLTestCase):
-    def tst(self, input_shapes):
-        dut = BitwiseXorReduce(Signal(w, name=f"input_{i}")
-                               for i, w in enumerate(input_shapes))
-        self.assertEqual(reduce(xor, dut.input_values).shape(),
-                         dut.output.shape())
-
-        def case(inputs):
-            expected = reduce(xor, inputs)
-            with self.subTest(inputs=list(map(hex, inputs)),
-                              expected=hex(expected)):
-                for i, inp in enumerate(inputs):
-                    yield dut.input_values[i].eq(inp)
-                yield Delay(1e-6)
-                output = yield dut.output
-                with self.subTest(output=hex(output)):
-                    self.assertEqual(expected, output)
-
-        def process():
-            for i in range(100):
-                inputs = []
-                for inp in dut.input_values:
-                    v = hash_256(f"bxorr input {i} {inp.name}")
-                    inputs.append(Const.normalize(v, inp.shape()))
-                yield from case(inputs)
-
-        with do_sim(self, dut, [*dut.input_values, dut.output]) as sim:
-            sim.add_process(process)
-            sim.run()
-
-    def tst_formal(self, input_shapes):
-        dut = BitwiseXorReduce(Signal(w, name=f"input_{i}")
-                               for i, w in enumerate(input_shapes))
-        m = Module()
-        m.submodules.dut = dut
-        for i in dut.input_values:
-            m.d.comb += i.eq(AnyConst(i.shape()))
-        m.d.comb += Assert(dut.output == reduce(xor, dut.input_values))
-        self.assertFormal(m)
-
-    def test_65_of_u64(self):
-        self.tst([64] * 65)
-
-    def test_formal_65_of_u64(self):
-        self.tst_formal([64] * 65)
-
-    def test_5_of_u6(self):
-        self.tst([6] * 5)
-
-    def test_formal_5_of_u6(self):
-        self.tst_formal([6] * 5)
-
-    def test_u5_i6_u3_i10(self):
-        self.tst([unsigned(5), signed(6), unsigned(3), signed(10)])
-
-    def test_formal_u5_i6_u3_i10(self):
-        self.tst_formal([unsigned(5), signed(6), unsigned(3), signed(10)])
-
-
 class TestCLMulAdd(FHDLTestCase):
     def tst(self, factor_width, terms_width):
         dut = CLMulAdd(factor_width, terms_width)
@@ -83,7 +23,7 @@ class TestCLMulAdd(FHDLTestCase):
                          max((factor_width * 2 - 1, *terms_width)))
 
         def case(factor1, factor2, terms):
-            expected = reduce(xor, terms, clmul(factor1, factor2))
+            expected = reduce(operator.xor, terms, clmul(factor1, factor2))
             with self.subTest(factor1=hex(factor1),
                               factor2=bin(factor2),
                               terms=list(map(hex, terms)),
@@ -145,8 +85,8 @@ class TestCLMulAdd(FHDLTestCase):
             sig = Signal(reduce_inputs[i].shape(), name=f"reduce_input_{i}")
             m.d.comb += sig.eq(reduce_inputs[i])
             reduce_inputs[i] = sig
-        expected = Signal(reduce(xor, reduce_inputs).shape())
-        m.d.comb += expected.eq(reduce(xor, reduce_inputs))
+        expected = Signal(reduce(operator.xor, reduce_inputs).shape())
+        m.d.comb += expected.eq(reduce(operator.xor, reduce_inputs))
         m.d.comb += Assert(dut.output == expected)
         self.assertFormal(m)