hdl/clmul: split clmuladd hdl out into separate function for ease of use in formal...
authorJacob Lifshay <programmerjake@gmail.com>
Fri, 17 May 2024 05:02:45 +0000 (22:02 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Fri, 17 May 2024 05:02:45 +0000 (22:02 -0700)
src/nmigen_gf/hdl/clmul.py

index ad4317316b3a5a3e5655e908e7f9e13564918373..b82d2c91e1850dcac4afdc239c796e811a025a66 100644 (file)
@@ -11,19 +11,40 @@ https://bugs.libre-soc.org/show_bug.cgi?id=784
 
 import operator
 from nmigen.hdl.ir import Elaboratable
-from nmigen.hdl.ast import Signal, Repl
+from nmigen.hdl.ast import Signal, Repl, Value
 from nmigen.hdl.dsl import Module
 from nmutil.util import treereduce
 
 
+def clmuladd(factor1, factor2, *terms, mktmp=lambda v, name: v):
+    factor1 = Value.cast(factor1)
+    factor1_width = factor1.shape().width
+    factor2 = Value.cast(factor2)
+    factor2_width = factor2.shape().width
+    all_terms = [*terms]
+    for shift in range(factor2_width):
+        # construct partial product term
+        mask = Repl(factor2[shift], factor1_width)
+        part_prod = mktmp((factor1 & mask) << shift,
+                          name="part_prod_%s" % shift)
+        all_terms.append(part_prod)
+
+    # merge all terms together
+    return treereduce(all_terms, operator.xor)
+
+
+def clmul(factor1, factor2, mktmp=lambda v, name: v):
+    return clmuladd(factor1, factor2, mktmp=mktmp)
+
+
 class CLMulAdd(Elaboratable):
     """Carry-less multiply-add. (optional add)
 
         Computes:
         ```
-        self.output = (clmul(self.factor1, self.factor2) ^ 
-                       self.terms[0] ^ 
-                       self.terms[1] ^ 
+        self.output = (clmul(self.factor1, self.factor2) ^
+                       self.terms[0] ^
+                       self.terms[1] ^
                        self.terms[2] ...)
         ```
 
@@ -63,15 +84,12 @@ class CLMulAdd(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        all_terms = self.terms.copy()  # create copy to avoid modifying self
-        for shift in range(self.factor_width):
-            part_prod = Signal(self.output.width, name=f"part_prod_{shift}")
-            # construct partial product term
-            mask = Repl(self.factor2[shift], self.factor_width)
-            m.d.comb += part_prod.eq((self.factor1 & mask) << shift)
-            all_terms.append(part_prod)
+        def mktmp(v, name):
+            s = Signal.like(v, name=name)
+            m.d.comb += s.eq(v)
+            return s
+
+        output = clmuladd(self.factor1, self.factor2, *self.terms, mktmp=mktmp)
 
-        # merge all terms together
-        output = treereduce(all_terms, operator.xor)
         m.d.comb += self.output.eq(output)
         return m