add Fract class
authorJacob Lifshay <programmerjake@gmail.com>
Sat, 29 Jun 2019 03:22:10 +0000 (20:22 -0700)
committerJacob Lifshay <programmerjake@gmail.com>
Sat, 29 Jun 2019 03:22:10 +0000 (20:22 -0700)
src/ieee754/div_rem_sqrt_rsqrt/algorithm.py
src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py

index 69de47db4f4c4cfcc11ff483e5dcfb5ebcc655d9..dee6dcf4fa7c8bb631c1e0b9cd71c3eeaa366882 100644 (file)
@@ -146,3 +146,268 @@ class DivRem:
         self.quotient = Const.normalize(quotient, (bit_width, self.signed))
         self.remainder = Const.normalize(remainder, (bit_width, self.signed))
         return True
+
+
+class Fixed:
+    """ Fixed-point number.
+
+    the value is bits * 2 ** -fract_width
+
+    :attribute bits: the bits of the fixed-point number
+    :attribute fract_width: the number of bits in the fractional portion
+    :attribute bit_width: the total number of bits
+    :attribute signed: if the type is signed
+    """
+
+    @staticmethod
+    def from_bits(bits, fract_width, bit_width, signed):
+        """ Create a new Fixed.
+
+        :param bits: the bits of the fixed-point number
+        :param fract_width: the number of bits in the fractional portion
+        :param bit_width: the total number of bits
+        :param signed: if the type is signed
+        """
+        retval = Fixed(0, fract_width, bit_width, signed)
+        retval.bits = Const.normalize(bits, (bit_width, signed))
+        return retval
+
+    def __init__(self, value, fract_width, bit_width, signed):
+        """ Create a new Fixed.
+
+        :param value: the value of the fixed-point number
+        :param fract_width: the number of bits in the fractional portion
+        :param bit_width: the total number of bits
+        :param signed: if the type is signed
+        """
+        assert fract_width >= 0
+        assert bit_width > 0
+        if isinstance(value, Fixed):
+            if fract_width < value.fract_width:
+                bits = value.bits >> (value.fract_width - fract_width)
+            else:
+                bits = value.bits << (fract_width - value.fract_width)
+        elif isinstance(value, int):
+            bits = value << fract_width
+        else:
+            bits = floor(value * 2 ** fract_width)
+        self.bits = Const.normalize(bits, (bit_width, signed))
+        self.fract_width = fract_width
+        self.bit_width = bit_width
+        self.signed = signed
+
+    def __repr__(self):
+        """ Get representation."""
+        return f"Fixed({self.bits}, {self.fract_width}, {self.bit_width})"
+
+    def __trunc__(self):
+        """ Truncate to integer."""
+        if self.bits < 0:
+            return self.__ceil__()
+        return self.__floor__()
+
+    def __int__(self):
+        """ Truncate to integer."""
+        return self.__trunc__()
+
+    def __float__(self):
+        """ Convert to float."""
+        return self.bits * 2 ** -self.fract_width
+
+    def __floor__(self):
+        """ Floor to integer."""
+        return self.bits >> self.fract_width
+
+    def __ceil__(self):
+        """ Ceil to integer."""
+        return -((-self.bits) >> self.fract_width)
+
+    def __neg__(self):
+        """ Negate."""
+        return self.from_bits(-self.bits, self.fract_width,
+                              self.bit_width, self.signed)
+
+    def __pos__(self):
+        """ Unary Positive."""
+        return self
+
+    def __abs__(self):
+        """ Absolute Value."""
+        return self.from_bits(abs(self.bits), self.fract_width,
+                              self.bit_width, self.signed)
+
+    def __invert__(self):
+        """ Inverse."""
+        return self.from_bits(~self.bits, self.fract_width,
+                              self.bit_width, self.signed)
+
+    def _binary_op(self, rhs, operation, full=False):
+        """ Handle binary arithmetic operators. """
+        if isinstance(rhs, int):
+            rhs_fract_width = 0
+            rhs_bits = rhs
+            int_width = self.bit_width - self.fract_width
+        elif isinstance(rhs, Fixed):
+            if self.signed != rhs.signed:
+                return TypeError("signedness must match")
+            rhs_fract_width = rhs.fract_width
+            rhs_bits = rhs.bits
+            int_width = max(self.bit_width - self.fract_width,
+                            rhs.bit_width - rhs.fract_width)
+        else:
+            return NotImplemented
+        fract_width = max(self.fract_width, rhs_fract_width)
+        rhs_bits <<= fract_width - rhs_fract_width
+        lhs_bits = self.bits << fract_width - self.fract_width
+        bit_width = int_width + fract_width
+        if full:
+            return operation(lhs_bits, rhs_bits,
+                             fract_width, bit_width, self.signed)
+        bits = operation(lhs_bits, rhs_bits,
+                         fract_width)
+        return self.from_bits(bits, fract_width, bit_width, self.signed)
+
+    def __add__(self, rhs):
+        """ Addition."""
+        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs + rhs)
+
+    def __radd__(self, lhs):
+        """ Reverse Addition."""
+        return self.__add__(lhs)
+
+    def __sub__(self, rhs):
+        """ Subtraction."""
+        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs - rhs)
+
+    def __rsub__(self, lhs):
+        """ Reverse Subtraction."""
+        # note swapped argument and parameter order
+        return self._binary_op(lhs, lambda rhs, lhs, fract_width: lhs - rhs)
+
+    def __and__(self, rhs):
+        """ Bitwise And."""
+        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs & rhs)
+
+    def __rand__(self, lhs):
+        """ Reverse Bitwise And."""
+        return self.__and__(lhs)
+
+    def __or__(self, rhs):
+        """ Bitwise Or."""
+        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs | rhs)
+
+    def __ror__(self, lhs):
+        """ Reverse Bitwise Or."""
+        return self.__or__(lhs)
+
+    def __xor__(self, rhs):
+        """ Bitwise Xor."""
+        return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs ^ rhs)
+
+    def __rxor__(self, lhs):
+        """ Reverse Bitwise Xor."""
+        return self.__xor__(lhs)
+
+    def __mul__(self, rhs):
+        """ Multiplication. """
+        if isinstance(rhs, int):
+            rhs_fract_width = 0
+            rhs_bits = rhs
+            int_width = self.bit_width - self.fract_width
+        elif isinstance(rhs, Fixed):
+            if self.signed != rhs.signed:
+                return TypeError("signedness must match")
+            rhs_fract_width = rhs.fract_width
+            rhs_bits = rhs.bits
+            int_width = (self.bit_width - self.fract_width
+                         + rhs.bit_width - rhs.fract_width)
+        else:
+            return NotImplemented
+        fract_width = self.fract_width + rhs_fract_width
+        bit_width = int_width + fract_width
+        bits = self.bits * rhs_bits
+        return self.from_bits(bits, fract_width, bit_width, self.signed)
+
+    @staticmethod
+    def _cmp_impl(lhs, rhs, fract_width, bit_width, signed):
+        if lhs < rhs:
+            return -1
+        elif lhs == rhs:
+            return 0
+        return 1
+
+    def cmp(self, rhs):
+        """ Compare self with rhs.
+
+        :returns int: returns -1 if self is less than rhs, 0 if they're equal,
+            and 1 for greater than.
+            Returns NotImplemented for unimplemented cases
+        """
+        return self._binary_op(rhs, self._cmp_impl, full=True)
+
+    def __lt__(self, rhs):
+        """ Less Than."""
+        return self.cmp(rhs) < 0
+
+    def __le__(self, rhs):
+        """ Less Than or Equal."""
+        return self.cmp(rhs) <= 0
+
+    def __eq__(self, rhs):
+        """ Equal."""
+        return self.cmp(rhs) == 0
+
+    def __ne__(self, rhs):
+        """ Not Equal."""
+        return self.cmp(rhs) != 0
+
+    def __gt__(self, rhs):
+        """ Greater Than."""
+        return self.cmp(rhs) > 0
+
+    def __ge__(self, rhs):
+        """ Greater Than or Equal."""
+        return self.cmp(rhs) >= 0
+
+    def __bool__(self, rhs):
+        """ Convert to bool."""
+        return bool(self.bits)
+
+    def __str__(self):
+        """ Get text representation."""
+        # don't just use self.__float__() in order to work with numbers more
+        # than 53 bits wide
+        retval = "fixed:"
+        bits = self.bits
+        if bits < 0:
+            retval += "-"
+            bits = -bits
+        int_part = bits >> self.fract_width
+        fract_part = bits & ~(-1 << self.fract_width)
+        # round up fract_width to nearest multiple of 4
+        fract_width = (self.fract_width + 3) & ~3
+        fract_part <<= (fract_width - self.fract_width)
+        fract_width_in_hex_digits = fract_width / 4
+        retval += f"{int_part:x}."
+        retval += f"{fract_part:x}".zfill(fract_width_in_hex_digits)
+        return retval
+
+
+def fract_sqrt():
+    # FIXME: finish
+    raise NotImplementedError()
+
+
+class FractSqrt:
+    # FIXME: finish
+    pass
+
+
+def fract_rsqrt():
+    # FIXME: finish
+    raise NotImplementedError()
+
+
+class FractRSqrt:
+    # FIXME: finish
+    pass
index ea14c964792cd37d6fd9e4f21e1467d0b28819ff..23dcbab9e81d7f10d93ef1e0094059f210f11a40 100644 (file)
@@ -2,7 +2,8 @@
 # See Notices.txt for copyright information
 
 from nmigen.hdl.ast import Const
-from .algorithm import div_rem, UnsignedDivRem, DivRem
+from .algorithm import (div_rem, UnsignedDivRem, DivRem,
+                        Fract, fract_sqrt, FractSqrt, fract_rsqrt,  FractRSqrt)
 import unittest
 
 
@@ -346,3 +347,5 @@ class TestDivRem(unittest.TestCase):
 
     def test_radix_16(self):
         self.helper(4)
+
+# FIXME: add tests for Fract, fract_sqrt, FractSqrt, fract_rsqrt, and FractRSqrt