change goldschmidt_div_sqrt to use nmutil.plain_data rather than dataclasses
authorJacob Lifshay <programmerjake@gmail.com>
Tue, 16 Aug 2022 06:43:13 +0000 (23:43 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Tue, 16 Aug 2022 06:43:13 +0000 (23:43 -0700)
src/soc/fu/div/experiment/goldschmidt_div_sqrt.py
src/soc/fu/div/experiment/test/test_goldschmidt_div_sqrt.py

index 6f739c33db8d6a97f55015b795709c5b4ee34b2e..3f7c2480742d6913859461da120099385f99d18a 100644 (file)
@@ -5,17 +5,17 @@
 # of Horizon 2020 EU Programme 957073.
 
 from collections import defaultdict
-from dataclasses import dataclass, field, fields, replace
 import logging
 import math
 import enum
 from fractions import Fraction
 from types import FunctionType
 from functools import lru_cache
-from nmigen.hdl.ast import Signal, unsigned, signed, Const, Cat
+from nmigen.hdl.ast import Signal, unsigned, signed, Const
 from nmigen.hdl.dsl import Module, Elaboratable
 from nmigen.hdl.mem import Memory
 from nmutil.clz import CLZ
+from nmutil.plain_data import plain_data, fields, replace
 
 try:
     from functools import cached_property
@@ -65,13 +65,13 @@ class RoundDir(enum.Enum):
     ERROR_IF_INEXACT = enum.auto()
 
 
-@dataclass(frozen=True)
+@plain_data(frozen=True, eq=False, repr=False)
 class FixedPoint:
-    bits: int
-    frac_wid: int
+    __slots__ = "bits", "frac_wid"
 
-    def __post_init__(self):
-        # called by the autogenerated __init__
+    def __init__(self, bits, frac_wid):
+        self.bits = bits
+        self.frac_wid = frac_wid
         assert isinstance(self.bits, int)
         assert isinstance(self.frac_wid, int) and self.frac_wid >= 0
 
@@ -332,43 +332,47 @@ def _assert_accuracy(condition, msg="not accurate enough"):
     raise ParamsNotAccurateEnough(msg)
 
 
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
 class GoldschmidtDivParamsBase:
     """parameters for a Goldschmidt division algorithm, excluding derived
     parameters.
     """
 
-    io_width: int
-    """bit-width of the input divisor and the result.
-    the input numerator is `2 * io_width`-bits wide.
-    """
+    __slots__ = ("io_width", "extra_precision", "table_addr_bits",
+                 "table_data_bits", "iter_count")
+
+    def __init__(self, io_width, extra_precision, table_addr_bits,
+                 table_data_bits, iter_count):
+        assert isinstance(io_width, int)
+        assert isinstance(extra_precision, int)
+        assert isinstance(table_addr_bits, int)
+        assert isinstance(table_data_bits, int)
+        assert isinstance(iter_count, int)
+        self.io_width = io_width
+        """bit-width of the input divisor and the result.
+        the input numerator is `2 * io_width`-bits wide.
+        """
 
-    extra_precision: int
-    """number of bits of additional precision used inside the algorithm."""
+        self.extra_precision = extra_precision
+        """number of bits of additional precision used inside the algorithm."""
 
-    table_addr_bits: int
-    """the number of address bits used in the lookup-table."""
+        self.table_addr_bits = table_addr_bits
+        """the number of address bits used in the lookup-table."""
 
-    table_data_bits: int
-    """the number of data bits used in the lookup-table."""
+        self.table_data_bits = table_data_bits
+        """the number of data bits used in the lookup-table."""
 
-    iter_count: int
-    """the total number of iterations of the division algorithm's loop"""
+        self.iter_count = iter_count
+        """the total number of iterations of the division algorithm's loop"""
 
 
-@dataclass(frozen=True, unsafe_hash=True)
+@plain_data(frozen=True, unsafe_hash=True)
 class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     """parameters for a Goldschmidt division algorithm.
     Use `GoldschmidtDivParams.get` to find a efficient set of parameters.
     """
 
-    # tuple to be immutable, repr=False so repr() works for debugging even when
-    # __post_init__ hasn't finished running yet
-    table: "tuple[FixedPoint, ...]" = field(init=False, repr=False)
-    """the lookup-table"""
-
-    ops: "tuple[GoldschmidtDivOp, ...]" = field(init=False, repr=False)
-    """the operations needed to perform the goldschmidt division algorithm."""
+    __slots__ = "table", "ops"
 
     def _shrink_bound(self, bound, round_dir):
         """prevent fractions from having huge numerators/denominators by
@@ -445,8 +449,13 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
         # we round down
         return min_value
 
-    def __post_init__(self):
-        # called by the autogenerated __init__
+    def __init__(self, io_width, extra_precision, table_addr_bits,
+                 table_data_bits, iter_count):
+        super().__init__(io_width=io_width,
+                         extra_precision=extra_precision,
+                         table_addr_bits=table_addr_bits,
+                         table_data_bits=table_data_bits,
+                         iter_count=iter_count)
         _assert_accuracy(self.io_width >= 1, "io_width out of range")
         _assert_accuracy(self.extra_precision >= 0,
                          "extra_precision out of range")
@@ -460,9 +469,14 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
             table.append(FixedPoint.with_frac_wid(self.table_exact_value(addr),
                                                   self.table_data_bits,
                                                   RoundDir.DOWN))
-        # we have to use object.__setattr__ since frozen=True
-        object.__setattr__(self, "table", tuple(table))
-        object.__setattr__(self, "ops", tuple(self.__make_ops()))
+
+        self.table = tuple(table)
+        """ the lookup-table.
+        type: tuple[FixedPoint, ...]
+        """
+
+        self.ops = tuple(self.__make_ops())
+        "the operations needed to perform the goldschmidt division algorithm."
 
     @property
     def expanded_width(self):
@@ -800,11 +814,9 @@ class GoldschmidtDivParams(GoldschmidtDivParamsBase):
     @lru_cache(maxsize=1 << 16)
     def __cached_new(base_params):
         assert isinstance(base_params, GoldschmidtDivParamsBase)
-        # can't use dataclasses.asdict, since it's recursive and will also give
-        # child class fields too, which we don't want.
         kwargs = {}
         for field in fields(GoldschmidtDivParamsBase):
-            kwargs[field.name] = getattr(base_params, field.name)
+            kwargs[field] = getattr(base_params, field)
         try:
             return GoldschmidtDivParams(**kwargs), None
         except ParamsNotAccurateEnough as e:
@@ -1139,44 +1151,57 @@ class GoldschmidtDivOp(enum.Enum):
             assert False, f"unimplemented GoldschmidtDivOp: {self}"
 
 
-@dataclass
+@plain_data(repr=False)
 class GoldschmidtDivState:
-    orig_n: int
-    """original numerator"""
-
-    orig_d: int
-    """original denominator"""
-
-    n: FixedPoint
-    """numerator -- N_prime[i] in the paper's algorithm 2"""
-
-    d: FixedPoint
-    """denominator -- D_prime[i] in the paper's algorithm 2"""
-
-    f: "FixedPoint | None" = None
-    """current factor -- F_prime[i] in the paper's algorithm 2"""
-
-    quotient: "int | None" = None
-    """final quotient"""
-
-    remainder: "int | None" = None
-    """final remainder"""
-
-    n_shift: "int | None" = None
-    """amount the numerator needs to be left-shifted at the end of the
-    algorithm.
-    """
+    __slots__ = ("orig_n", "orig_d", "n", "d",
+                 "f", "quotient", "remainder", "n_shift")
+
+    def __init__(self, orig_n, orig_d, n, d,
+                 f=None, quotient=None, remainder=None, n_shift=None):
+        assert isinstance(orig_n, int)
+        assert isinstance(orig_d, int)
+        assert isinstance(n, FixedPoint)
+        assert isinstance(d, FixedPoint)
+        assert f is None or isinstance(f, FixedPoint)
+        assert quotient is None or isinstance(quotient, int)
+        assert remainder is None or isinstance(remainder, int)
+        assert n_shift is None or isinstance(n_shift, int)
+        self.orig_n = orig_n
+        """original numerator"""
+
+        self.orig_d = orig_d
+        """original denominator"""
+
+        self.n = n
+        """numerator -- N_prime[i] in the paper's algorithm 2"""
+
+        self.d = d
+        """denominator -- D_prime[i] in the paper's algorithm 2"""
+
+        self.f = f
+        """current factor -- F_prime[i] in the paper's algorithm 2"""
+
+        self.quotient = quotient
+        """final quotient"""
+
+        self.remainder = remainder
+        """final remainder"""
+
+        self.n_shift = n_shift
+        """amount the numerator needs to be left-shifted at the end of the
+        algorithm.
+        """
 
     def __repr__(self):
         fields_str = []
         for field in fields(GoldschmidtDivState):
-            value = getattr(self, field.name)
+            value = getattr(self, field)
             if value is None:
                 continue
-            if isinstance(value, int) and field.name != "n_shift":
-                fields_str.append(f"{field.name}={hex(value)}")
+            if isinstance(value, int) and field != "n_shift":
+                fields_str.append(f"{field}={hex(value)}")
             else:
-                fields_str.append(f"{field.name}={value!r}")
+                fields_str.append(f"{field}={value!r}")
         return f"GoldschmidtDivState({', '.join(fields_str)})"
 
 
@@ -1230,45 +1255,55 @@ def goldschmidt_div(n, d, params, trace=lambda state: None):
     return state.quotient, state.remainder
 
 
-@dataclass(eq=False)
+@plain_data(eq=False)
 class GoldschmidtDivHDLState:
-    m: Module
-    """The HDL Module"""
+    __slots__ = ("m", "orig_n", "orig_d", "n", "d",
+                 "f", "quotient", "remainder", "n_shift")
 
-    orig_n: Signal
-    """original numerator"""
+    __signal_name_prefix = "state_"
 
-    orig_d: Signal
-    """original denominator"""
+    def __init__(self, m, orig_n, orig_d, n, d,
+                 f=None, quotient=None, remainder=None, n_shift=None):
+        assert isinstance(m, Module)
+        assert isinstance(orig_n, Signal)
+        assert isinstance(orig_d, Signal)
+        assert isinstance(n, Signal)
+        assert isinstance(d, Signal)
+        assert f is None or isinstance(f, Signal)
+        assert quotient is None or isinstance(quotient, Signal)
+        assert remainder is None or isinstance(remainder, Signal)
+        assert n_shift is None or isinstance(n_shift, Signal)
 
-    n: Signal
-    """numerator -- N_prime[i] in the paper's algorithm 2"""
+        self.m = m
+        """The HDL Module"""
 
-    d: Signal
-    """denominator -- D_prime[i] in the paper's algorithm 2"""
+        self.orig_n = orig_n
+        """original numerator"""
 
-    f: "Signal | None" = None
-    """current factor -- F_prime[i] in the paper's algorithm 2"""
+        self.orig_d = orig_d
+        """original denominator"""
 
-    quotient: "Signal | None" = None
-    """final quotient"""
+        self.n = n
+        """numerator -- N_prime[i] in the paper's algorithm 2"""
 
-    remainder: "Signal | None" = None
-    """final remainder"""
+        self.d = d
+        """denominator -- D_prime[i] in the paper's algorithm 2"""
 
-    n_shift: "Signal | None" = None
-    """amount the numerator needs to be left-shifted at the end of the
-    algorithm.
-    """
+        self.f = f
+        """current factor -- F_prime[i] in the paper's algorithm 2"""
 
-    old_signals: "defaultdict[str, list[Signal]]" = field(repr=False,
-                                                          init=False)
+        self.quotient = quotient
+        """final quotient"""
 
-    __signal_name_prefix: "str" = field(default="state_", repr=False,
-                                        init=False)
+        self.remainder = remainder
+        """final remainder"""
+
+        self.n_shift = n_shift
+        """amount the numerator needs to be left-shifted at the end of the
+        algorithm.
+        """
 
-    def __post_init__(self):
-        # called by the autogenerated __init__
+        # old_signals must be set last
         self.old_signals = defaultdict(list)
 
     def __setattr__(self, name, value):
@@ -1293,14 +1328,14 @@ class GoldschmidtDivHDLState:
         old_prefix = self.__signal_name_prefix
         try:
             for field in fields(GoldschmidtDivHDLState):
-                if field.name.startswith("_") or field.name == "m":
+                if field.startswith("_") or field == "m":
                     continue
-                old_sig = getattr(self, field.name, None)
+                old_sig = getattr(self, field, None)
                 if old_sig is None:
                     continue
                 assert isinstance(old_sig, Signal)
                 new_sig = Signal.like(old_sig)
-                setattr(self, field.name, new_sig)
+                setattr(self, field, new_sig)
                 self.m.d.sync += new_sig.eq(old_sig)
         finally:
             self.__signal_name_prefix = old_prefix
index bf999bd850237387120691f16bcda7eb308e21df..28e795f4e8bdd54b93a3ec071ebb93599690a1b3 100644 (file)
@@ -4,7 +4,7 @@
 # Funded by NLnet Assure Programme 2021-02-052, https://nlnet.nl/assure part
 # of Horizon 2020 EU Programme 957073.
 
-from dataclasses import fields, replace
+from nmutil.plain_data import fields, replace
 import math
 import unittest
 from nmutil.formaltest import FHDLTestCase
@@ -13,7 +13,7 @@ from nmigen.sim import Tick, Delay
 from nmigen.hdl.ast import Signal
 from nmigen.hdl.dsl import Module
 from soc.fu.div.experiment.goldschmidt_div_sqrt import (
-    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivOp, GoldschmidtDivParams,
+    GoldschmidtDivHDL, GoldschmidtDivHDLState, GoldschmidtDivParams,
     GoldschmidtDivState, ParamsNotAccurateEnough, goldschmidt_div,
     FixedPoint, RoundDir, goldschmidt_sqrt_rsqrt)
 
@@ -156,15 +156,15 @@ class TestGoldschmidtDiv(FHDLTestCase):
                                   ref_state=repr(ref_state),
                                   last_op=str(last_op)):
                     for field in fields(GoldschmidtDivHDLState):
-                        sig = getattr(state, field.name)
+                        sig = getattr(state, field)
                         if not isinstance(sig, Signal):
                             continue
-                        ref_value = getattr(ref_state, field.name)
+                        ref_value = getattr(ref_state, field)
                         ref_value_str = repr(ref_value)
                         if isinstance(ref_value, int):
                             ref_value_str = hex(ref_value)
                         value = yield sig
-                        with self.subTest(field_name=field.name,
+                        with self.subTest(field_name=field,
                                           sig=repr(sig),
                                           sig_shape=repr(sig.shape()),
                                           value=hex(value),