From d5eca4df4b6c1d65b2fe5cb815fde7b1bc53fc02 Mon Sep 17 00:00:00 2001 From: Jacob Lifshay Date: Wed, 19 Oct 2022 20:23:14 -0700 Subject: [PATCH] make Matrix support element types other than Fraction --- src/bigint_presentation_code/matrix.py | 137 ++++++++++++-------- src/bigint_presentation_code/test_matrix.py | 10 +- 2 files changed, 91 insertions(+), 56 deletions(-) diff --git a/src/bigint_presentation_code/matrix.py b/src/bigint_presentation_code/matrix.py index 3e1e154..49acddf 100644 --- a/src/bigint_presentation_code/matrix.py +++ b/src/bigint_presentation_code/matrix.py @@ -1,42 +1,59 @@ import operator +from enum import Enum, unique from fractions import Fraction from numbers import Rational -from typing import Callable, Iterable +from typing import Callable, Generic, Iterable, Iterator, Type, TypeVar +from bigint_presentation_code.util import final -class Matrix: - __slots__ = "__height", "__width", "__data" +_T = TypeVar("_T") + + +@final +@unique +class SpecialMatrix(Enum): + Zero = 0 + Identity = 1 + + +@final +class Matrix(Generic[_T]): + __slots__ = "__height", "__width", "__data", "__element_type" @property def height(self): + # type: () -> int return self.__height @property def width(self): + # type: () -> int return self.__width - def __init__(self, height, width, data=None): - # type: (int, int, Iterable[Rational | int] | None) -> None + @property + def element_type(self): + # type: () -> Type[_T] + return self.__element_type + + def __init__(self, height, width, data=SpecialMatrix.Zero, + element_type=Fraction): + # type: (int, int, Iterable[_T | int] | SpecialMatrix, Type[_T]) -> None if width < 0 or height < 0: raise ValueError("matrix size must be non-negative") self.__height = height self.__width = width - self.__data = [Fraction()] * (height * width) - if data is not None: - data = list(data) - if len(data) != len(self.__data): + self.__element_type = element_type + if isinstance(data, SpecialMatrix): + self.__data = [element_type(0) for _ in range(height * width)] + if data is SpecialMatrix.Identity: + for i in range(min(width, height)): + self[i, i] = element_type(1) + else: + assert data is SpecialMatrix.Zero + else: + self.__data = [element_type(v) for v in data] + if len(self.__data) != height * width: raise ValueError("data has wrong length") - self.__data[:] = map(Fraction, data) - - @staticmethod - def identity(height, width=None): - # type: (int, int | None) -> Matrix - if width is None: - width = height - retval = Matrix(height, width) - for i in range(min(height, width)): - retval[i, i] = 1 - return retval def __idx(self, row, col): # type: (int, int) -> int @@ -45,65 +62,67 @@ class Matrix: raise IndexError() def __getitem__(self, row_col): - # type: (tuple[int, int]) -> Fraction + # type: (tuple[int, int]) -> _T row, col = row_col return self.__data[self.__idx(row, col)] def __setitem__(self, row_col, value): - # type: (tuple[int, int], Rational | int) -> None + # type: (tuple[int, int], _T | int) -> None row, col = row_col - self.__data[self.__idx(row, col)] = Fraction(value) + self.__data[self.__idx(row, col)] = self.__element_type(value) def copy(self): - retval = Matrix(self.width, self.height) - retval.__data[:] = self.__data - return retval + # type: () -> Matrix[_T] + return Matrix(self.width, self.height, data=self.__data, + element_type=self.element_type) def indexes(self): + # type: () -> Iterable[tuple[int, int]] for row in range(self.height): for col in range(self.width): yield row, col def __mul__(self, rhs): - # type: (Rational | int) -> Matrix - rhs = Fraction(rhs) + # type: (_T | int) -> Matrix[_T] retval = self.copy() for i in self.indexes(): - retval[i] *= rhs + retval[i] *= rhs # type: ignore return retval def __rmul__(self, lhs): - # type: (Rational | int) -> Matrix - return self.__mul__(lhs) + # type: (_T | int) -> Matrix[_T] + retval = self.copy() + for i in self.indexes(): + retval[i] = lhs * retval[i] # type: ignore + return retval def __truediv__(self, rhs): # type: (Rational | int) -> Matrix - rhs = 1 / Fraction(rhs) retval = self.copy() for i in self.indexes(): - retval[i] *= rhs + retval[i] /= rhs # type: ignore return retval def __matmul__(self, rhs): - # type: (Matrix) -> Matrix + # type: (Matrix[_T]) -> Matrix[_T] if self.width != rhs.height: raise ValueError( "lhs width must equal rhs height to multiply matrixes") - retval = Matrix(self.height, rhs.width) + retval = Matrix(self.height, rhs.width, element_type=self.element_type) for row in range(retval.height): for col in range(retval.width): - sum = Fraction() + sum = self.element_type() for i in range(self.width): - sum += self[row, i] * rhs[i, col] + sum += self[row, i] * rhs[i, col] # type: ignore retval[row, col] = sum return retval def __rmatmul__(self, lhs): - # type: (Matrix) -> Matrix + # type: (Matrix[_T]) -> Matrix[_T] return lhs.__matmul__(self) def __elementwise_bin_op(self, rhs, op): - # type: (Matrix, Callable[[Fraction, Fraction], Fraction]) -> Matrix + # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T] if self.height != rhs.height or self.width != rhs.width: raise ValueError( "matrix dimensions must match for element-wise operations") @@ -113,34 +132,38 @@ class Matrix: return retval def __add__(self, rhs): - # type: (Matrix) -> Matrix + # type: (Matrix[_T]) -> Matrix[_T] return self.__elementwise_bin_op(rhs, operator.add) def __radd__(self, lhs): - # type: (Matrix) -> Matrix + # type: (Matrix[_T]) -> Matrix[_T] return lhs.__add__(self) def __sub__(self, rhs): - # type: (Matrix) -> Matrix + # type: (Matrix[_T]) -> Matrix[_T] return self.__elementwise_bin_op(rhs, operator.sub) def __rsub__(self, lhs): - # type: (Matrix) -> Matrix + # type: (Matrix[_T]) -> Matrix[_T] return lhs.__sub__(self) def __iter__(self): + # type: () -> Iterator[_T] return iter(self.__data) def __reversed__(self): + # type: () -> Iterator[_T] return reversed(self.__data) def __neg__(self): + # type: () -> Matrix[_T] retval = self.copy() for i in retval.indexes(): - retval[i] = -retval[i] + retval[i] = -retval[i] # type: ignore return retval def __repr__(self): + # type: () -> str if self.height == 0 or self.width == 0: return f"Matrix(height={self.height}, width={self.width})" lines = [] @@ -148,28 +171,40 @@ class Matrix: for row in range(self.height): line.clear() for col in range(self.width): - if self[row, col].denominator == 1: - line.append(str(self[row, col].numerator)) + el = self[row, col] + if isinstance(el, Fraction) and el.denominator == 1: + line.append(str(el.numerator)) else: - line.append(repr(self[row, col])) + line.append(repr(el)) lines.append(", ".join(line)) lines = ",\n ".join(lines) - return (f"Matrix(height={self.height}, width={self.width}, data=[\n" + element_type = "" + if self.element_type is not Fraction: + element_type = f"element_type={self.element_type}, " + return (f"Matrix(height={self.height}, width={self.width}, " + f"{element_type}data=[\n" f" {lines},\n])") def __eq__(self, rhs): + # type: (object) -> bool if not isinstance(rhs, Matrix): return NotImplemented return (self.height == rhs.height and self.width == rhs.width - and self.__data == rhs.__data) + and self.__data == rhs.__data + and self.element_type == rhs.element_type) - def inverse(self): + def inverse(self # type: Matrix[Fraction] + ): + # type: () -> Matrix[Fraction] size = self.height if size != self.width: raise ValueError("can't invert a non-square matrix") + if self.element_type is not Fraction: + raise TypeError("can't invert a matrix with element_type that " + "isn't Fraction") inp = self.copy() - retval = Matrix.identity(size) + retval = Matrix(size, size, data=SpecialMatrix.Identity) # the algorithm is adapted from: # https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C for k in range(size): diff --git a/src/bigint_presentation_code/test_matrix.py b/src/bigint_presentation_code/test_matrix.py index ef39742..1a56df0 100644 --- a/src/bigint_presentation_code/test_matrix.py +++ b/src/bigint_presentation_code/test_matrix.py @@ -1,7 +1,7 @@ import unittest from fractions import Fraction -from bigint_presentation_code.matrix import Matrix +from bigint_presentation_code.matrix import Matrix, SpecialMatrix class TestMatrix(unittest.TestCase): @@ -36,15 +36,15 @@ class TestMatrix(unittest.TestCase): Matrix(2, 2, [41, 32, 23, 14])) def test_identity(self): - self.assertEqual(Matrix.identity(2, 2), + self.assertEqual(Matrix(2, 2, data=SpecialMatrix.Identity), Matrix(2, 2, [1, 0, 0, 1])) - self.assertEqual(Matrix.identity(1, 3), + self.assertEqual(Matrix(1, 3, data=SpecialMatrix.Identity), Matrix(1, 3, [1, 0, 0])) - self.assertEqual(Matrix.identity(2, 3), + self.assertEqual(Matrix(2, 3, data=SpecialMatrix.Identity), Matrix(2, 3, [1, 0, 0, 0, 1, 0])) - self.assertEqual(Matrix.identity(3), + self.assertEqual(Matrix(3, 3, data=SpecialMatrix.Identity), Matrix(3, 3, [1, 0, 0, 0, 1, 0, 0, 0, 1])) -- 2.30.2