From 68ba5fe8e3ec9f4bd322110e1fc0e4f75906c6df Mon Sep 17 00:00:00 2001
From: Jacob Lifshay
Date: Thu, 16 May 2024 22:02:45 -0700
Subject: [PATCH] hdl/clmul: split clmuladd hdl out into separate function for
 ease of use in formal proofs

---
 src/nmigen_gf/hdl/clmul.py | 44 +++++++++++++++++++++++++++-----------
 1 file changed, 31 insertions(+), 13 deletions(-)

diff --git a/src/nmigen_gf/hdl/clmul.py b/src/nmigen_gf/hdl/clmul.py
index ad43173..b82d2c9 100644
--- a/src/nmigen_gf/hdl/clmul.py
+++ b/src/nmigen_gf/hdl/clmul.py
@@ -11,19 +11,40 @@ https://bugs.libre-soc.org/show_bug.cgi?id=784
 
 import operator
 from nmigen.hdl.ir import Elaboratable
-from nmigen.hdl.ast import Signal, Repl
+from nmigen.hdl.ast import Signal, Repl, Value
 from nmigen.hdl.dsl import Module
 from nmutil.util import treereduce
 
 
+def clmuladd(factor1, factor2, *terms, mktmp=lambda v, name: v):
+    factor1 = Value.cast(factor1)
+    factor1_width = factor1.shape().width
+    factor2 = Value.cast(factor2)
+    factor2_width = factor2.shape().width
+    all_terms = [*terms]
+    for shift in range(factor2_width):
+        # construct partial product term
+        mask = Repl(factor2[shift], factor1_width)
+        part_prod = mktmp((factor1 & mask) << shift,
+                          name="part_prod_%s" % shift)
+        all_terms.append(part_prod)
+
+    # merge all terms together
+    return treereduce(all_terms, operator.xor)
+
+
+def clmul(factor1, factor2, mktmp=lambda v, name: v):
+    return clmuladd(factor1, factor2, mktmp=mktmp)
+
+
 class CLMulAdd(Elaboratable):
     """Carry-less multiply-add. (optional add)
 
     Computes:
     ```
-    self.output = (clmul(self.factor1, self.factor2) ^ 
-                   self.terms[0] ^ 
-                   self.terms[1] ^ 
+    self.output = (clmul(self.factor1, self.factor2) ^
+                   self.terms[0] ^
+                   self.terms[1] ^
                    self.terms[2] ...)
     ```
 
@@ -63,15 +84,12 @@ class CLMulAdd(Elaboratable):
 
     def elaborate(self, platform):
         m = Module()
-        all_terms = self.terms.copy()  # create copy to avoid modifying self
-        for shift in range(self.factor_width):
-            part_prod = Signal(self.output.width, name=f"part_prod_{shift}")
-            # construct partial product term
-            mask = Repl(self.factor2[shift], self.factor_width)
-            m.d.comb += part_prod.eq((self.factor1 & mask) << shift)
-            all_terms.append(part_prod)
+        def mktmp(v, name):
+            s = Signal.like(v, name=name)
+            m.d.comb += s.eq(v)
+            return s
+
+        output = clmuladd(self.factor1, self.factor2, *self.terms, mktmp=mktmp)
 
-        # merge all terms together
-        output = treereduce(all_terms, operator.xor)
         m.d.comb += self.output.eq(output)
         return m
-- 
2.30.2
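
Note (not part of the patch): a minimal usage sketch of the split-out helpers, assuming nmigen and nmutil are installed and that the module is importable as nmigen_gf.hdl.clmul (matching the file path above). The signal widths and the bare Module are illustrative only.

```
from nmigen.hdl.ast import Signal
from nmigen.hdl.dsl import Module
from nmigen_gf.hdl.clmul import clmul, clmuladd

m = Module()
a = Signal(8)        # factor1
b = Signal(8)        # factor2
c = Signal(15)       # extra XOR term
result = Signal(15)  # 8x8 carry-less product is at most 15 bits wide
product = Signal(15)

# With the default mktmp (an identity lambda), clmuladd() returns a single
# combinatorial expression tree; passing a callback like the one in
# CLMulAdd.elaborate() would instead break each partial product out into a
# named intermediate Signal.
m.d.comb += result.eq(clmuladd(a, b, c))

# clmul() is the same carry-less multiply with no extra XOR terms.
m.d.comb += product.eq(clmul(a, b))
```

Keeping clmuladd() as a plain function over Values, with Signal creation delegated to the optional mktmp callback, appears to be what the commit subject refers to: formal-proof code can build the same expression directly without instantiating the CLMulAdd Elaboratable.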