From 56a194ff0fe3d085124e0615143112a942305ead Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Fri, 28 Jun 2019 20:22:10 -0700 Subject: [PATCH] add Fract class --- src/ieee754/div_rem_sqrt_rsqrt/algorithm.py | 265 ++++++++++++++++++ .../div_rem_sqrt_rsqrt/test_algorithm.py | 5 +- 2 files changed, 269 insertions(+), 1 deletion(-) diff --git a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py index 69de47db..dee6dcf4 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/algorithm.py @@ -146,3 +146,268 @@ class DivRem: self.quotient = Const.normalize(quotient, (bit_width, self.signed)) self.remainder = Const.normalize(remainder, (bit_width, self.signed)) return True + + +class Fixed: + """ Fixed-point number. + + the value is bits * 2 ** -fract_width + + :attribute bits: the bits of the fixed-point number + :attribute fract_width: the number of bits in the fractional portion + :attribute bit_width: the total number of bits + :attribute signed: if the type is signed + """ + + @staticmethod + def from_bits(bits, fract_width, bit_width, signed): + """ Create a new Fixed. + + :param bits: the bits of the fixed-point number + :param fract_width: the number of bits in the fractional portion + :param bit_width: the total number of bits + :param signed: if the type is signed + """ + retval = Fixed(0, fract_width, bit_width, signed) + retval.bits = Const.normalize(bits, (bit_width, signed)) + return retval + + def __init__(self, value, fract_width, bit_width, signed): + """ Create a new Fixed. + + :param value: the value of the fixed-point number + :param fract_width: the number of bits in the fractional portion + :param bit_width: the total number of bits + :param signed: if the type is signed + """ + assert fract_width >= 0 + assert bit_width > 0 + if isinstance(value, Fixed): + if fract_width < value.fract_width: + bits = value.bits >> (value.fract_width - fract_width) + else: + bits = value.bits << (fract_width - value.fract_width) + elif isinstance(value, int): + bits = value << fract_width + else: + bits = floor(value * 2 ** fract_width) + self.bits = Const.normalize(bits, (bit_width, signed)) + self.fract_width = fract_width + self.bit_width = bit_width + self.signed = signed + + def __repr__(self): + """ Get representation.""" + return f"Fixed({self.bits}, {self.fract_width}, {self.bit_width})" + + def __trunc__(self): + """ Truncate to integer.""" + if self.bits < 0: + return self.__ceil__() + return self.__floor__() + + def __int__(self): + """ Truncate to integer.""" + return self.__trunc__() + + def __float__(self): + """ Convert to float.""" + return self.bits * 2 ** -self.fract_width + + def __floor__(self): + """ Floor to integer.""" + return self.bits >> self.fract_width + + def __ceil__(self): + """ Ceil to integer.""" + return -((-self.bits) >> self.fract_width) + + def __neg__(self): + """ Negate.""" + return self.from_bits(-self.bits, self.fract_width, + self.bit_width, self.signed) + + def __pos__(self): + """ Unary Positive.""" + return self + + def __abs__(self): + """ Absolute Value.""" + return self.from_bits(abs(self.bits), self.fract_width, + self.bit_width, self.signed) + + def __invert__(self): + """ Inverse.""" + return self.from_bits(~self.bits, self.fract_width, + self.bit_width, self.signed) + + def _binary_op(self, rhs, operation, full=False): + """ Handle binary arithmetic operators. """ + if isinstance(rhs, int): + rhs_fract_width = 0 + rhs_bits = rhs + int_width = self.bit_width - self.fract_width + elif isinstance(rhs, Fixed): + if self.signed != rhs.signed: + return TypeError("signedness must match") + rhs_fract_width = rhs.fract_width + rhs_bits = rhs.bits + int_width = max(self.bit_width - self.fract_width, + rhs.bit_width - rhs.fract_width) + else: + return NotImplemented + fract_width = max(self.fract_width, rhs_fract_width) + rhs_bits <<= fract_width - rhs_fract_width + lhs_bits = self.bits << fract_width - self.fract_width + bit_width = int_width + fract_width + if full: + return operation(lhs_bits, rhs_bits, + fract_width, bit_width, self.signed) + bits = operation(lhs_bits, rhs_bits, + fract_width) + return self.from_bits(bits, fract_width, bit_width, self.signed) + + def __add__(self, rhs): + """ Addition.""" + return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs + rhs) + + def __radd__(self, lhs): + """ Reverse Addition.""" + return self.__add__(lhs) + + def __sub__(self, rhs): + """ Subtraction.""" + return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs - rhs) + + def __rsub__(self, lhs): + """ Reverse Subtraction.""" + # note swapped argument and parameter order + return self._binary_op(lhs, lambda rhs, lhs, fract_width: lhs - rhs) + + def __and__(self, rhs): + """ Bitwise And.""" + return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs & rhs) + + def __rand__(self, lhs): + """ Reverse Bitwise And.""" + return self.__and__(lhs) + + def __or__(self, rhs): + """ Bitwise Or.""" + return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs | rhs) + + def __ror__(self, lhs): + """ Reverse Bitwise Or.""" + return self.__or__(lhs) + + def __xor__(self, rhs): + """ Bitwise Xor.""" + return self._binary_op(rhs, lambda lhs, rhs, fract_width: lhs ^ rhs) + + def __rxor__(self, lhs): + """ Reverse Bitwise Xor.""" + return self.__xor__(lhs) + + def __mul__(self, rhs): + """ Multiplication. """ + if isinstance(rhs, int): + rhs_fract_width = 0 + rhs_bits = rhs + int_width = self.bit_width - self.fract_width + elif isinstance(rhs, Fixed): + if self.signed != rhs.signed: + return TypeError("signedness must match") + rhs_fract_width = rhs.fract_width + rhs_bits = rhs.bits + int_width = (self.bit_width - self.fract_width + + rhs.bit_width - rhs.fract_width) + else: + return NotImplemented + fract_width = self.fract_width + rhs_fract_width + bit_width = int_width + fract_width + bits = self.bits * rhs_bits + return self.from_bits(bits, fract_width, bit_width, self.signed) + + @staticmethod + def _cmp_impl(lhs, rhs, fract_width, bit_width, signed): + if lhs < rhs: + return -1 + elif lhs == rhs: + return 0 + return 1 + + def cmp(self, rhs): + """ Compare self with rhs. + + :returns int: returns -1 if self is less than rhs, 0 if they're equal, + and 1 for greater than. + Returns NotImplemented for unimplemented cases + """ + return self._binary_op(rhs, self._cmp_impl, full=True) + + def __lt__(self, rhs): + """ Less Than.""" + return self.cmp(rhs) < 0 + + def __le__(self, rhs): + """ Less Than or Equal.""" + return self.cmp(rhs) <= 0 + + def __eq__(self, rhs): + """ Equal.""" + return self.cmp(rhs) == 0 + + def __ne__(self, rhs): + """ Not Equal.""" + return self.cmp(rhs) != 0 + + def __gt__(self, rhs): + """ Greater Than.""" + return self.cmp(rhs) > 0 + + def __ge__(self, rhs): + """ Greater Than or Equal.""" + return self.cmp(rhs) >= 0 + + def __bool__(self, rhs): + """ Convert to bool.""" + return bool(self.bits) + + def __str__(self): + """ Get text representation.""" + # don't just use self.__float__() in order to work with numbers more + # than 53 bits wide + retval = "fixed:" + bits = self.bits + if bits < 0: + retval += "-" + bits = -bits + int_part = bits >> self.fract_width + fract_part = bits & ~(-1 << self.fract_width) + # round up fract_width to nearest multiple of 4 + fract_width = (self.fract_width + 3) & ~3 + fract_part <<= (fract_width - self.fract_width) + fract_width_in_hex_digits = fract_width / 4 + retval += f"{int_part:x}." + retval += f"{fract_part:x}".zfill(fract_width_in_hex_digits) + return retval + + +def fract_sqrt(): + # FIXME: finish + raise NotImplementedError() + + +class FractSqrt: + # FIXME: finish + pass + + +def fract_rsqrt(): + # FIXME: finish + raise NotImplementedError() + + +class FractRSqrt: + # FIXME: finish + pass diff --git a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py index ea14c964..23dcbab9 100644 --- a/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py +++ b/src/ieee754/div_rem_sqrt_rsqrt/test_algorithm.py @@ -2,7 +2,8 @@ # See Notices.txt for copyright information from nmigen.hdl.ast import Const -from .algorithm import div_rem, UnsignedDivRem, DivRem +from .algorithm import (div_rem, UnsignedDivRem, DivRem, + Fract, fract_sqrt, FractSqrt, fract_rsqrt, FractRSqrt) import unittest @@ -346,3 +347,5 @@ class TestDivRem(unittest.TestCase): def test_radix_16(self): self.helper(4) + +# FIXME: add tests for Fract, fract_sqrt, FractSqrt, fract_rsqrt, and FractRSqrt -- 2.30.2