import operator
+from enum import Enum, unique
from fractions import Fraction
from numbers import Rational
-from typing import Callable, Iterable
+from typing import Callable, Generic, Iterable, Iterator, Type, TypeVar
+from bigint_presentation_code.util import final
-class Matrix:
- __slots__ = "__height", "__width", "__data"
+_T = TypeVar("_T")
+
+
+@final
+@unique
+class SpecialMatrix(Enum):
+ Zero = 0
+ Identity = 1
+
+
+@final
+class Matrix(Generic[_T]):
+ __slots__ = "__height", "__width", "__data", "__element_type"
@property
def height(self):
+ # type: () -> int
return self.__height
@property
def width(self):
+ # type: () -> int
return self.__width
- def __init__(self, height, width, data=None):
- # type: (int, int, Iterable[Rational | int] | None) -> None
+ @property
+ def element_type(self):
+ # type: () -> Type[_T]
+ return self.__element_type
+
+ def __init__(self, height, width, data=SpecialMatrix.Zero,
+ element_type=Fraction):
+ # type: (int, int, Iterable[_T | int] | SpecialMatrix, Type[_T]) -> None
if width < 0 or height < 0:
raise ValueError("matrix size must be non-negative")
self.__height = height
self.__width = width
- self.__data = [Fraction()] * (height * width)
- if data is not None:
- data = list(data)
- if len(data) != len(self.__data):
+ self.__element_type = element_type
+ if isinstance(data, SpecialMatrix):
+ self.__data = [element_type(0) for _ in range(height * width)]
+ if data is SpecialMatrix.Identity:
+ for i in range(min(width, height)):
+ self[i, i] = element_type(1)
+ else:
+ assert data is SpecialMatrix.Zero
+ else:
+ self.__data = [element_type(v) for v in data]
+ if len(self.__data) != height * width:
raise ValueError("data has wrong length")
- self.__data[:] = map(Fraction, data)
-
- @staticmethod
- def identity(height, width=None):
- # type: (int, int | None) -> Matrix
- if width is None:
- width = height
- retval = Matrix(height, width)
- for i in range(min(height, width)):
- retval[i, i] = 1
- return retval
def __idx(self, row, col):
# type: (int, int) -> int
raise IndexError()
def __getitem__(self, row_col):
- # type: (tuple[int, int]) -> Fraction
+ # type: (tuple[int, int]) -> _T
row, col = row_col
return self.__data[self.__idx(row, col)]
def __setitem__(self, row_col, value):
- # type: (tuple[int, int], Rational | int) -> None
+ # type: (tuple[int, int], _T | int) -> None
row, col = row_col
- self.__data[self.__idx(row, col)] = Fraction(value)
+ self.__data[self.__idx(row, col)] = self.__element_type(value)
def copy(self):
- retval = Matrix(self.width, self.height)
- retval.__data[:] = self.__data
- return retval
+ # type: () -> Matrix[_T]
+ return Matrix(self.width, self.height, data=self.__data,
+ element_type=self.element_type)
def indexes(self):
+ # type: () -> Iterable[tuple[int, int]]
for row in range(self.height):
for col in range(self.width):
yield row, col
def __mul__(self, rhs):
- # type: (Rational | int) -> Matrix
- rhs = Fraction(rhs)
+ # type: (_T | int) -> Matrix[_T]
retval = self.copy()
for i in self.indexes():
- retval[i] *= rhs
+ retval[i] *= rhs # type: ignore
return retval
def __rmul__(self, lhs):
- # type: (Rational | int) -> Matrix
- return self.__mul__(lhs)
+ # type: (_T | int) -> Matrix[_T]
+ retval = self.copy()
+ for i in self.indexes():
+ retval[i] = lhs * retval[i] # type: ignore
+ return retval
def __truediv__(self, rhs):
# type: (Rational | int) -> Matrix
- rhs = 1 / Fraction(rhs)
retval = self.copy()
for i in self.indexes():
- retval[i] *= rhs
+ retval[i] /= rhs # type: ignore
return retval
def __matmul__(self, rhs):
- # type: (Matrix) -> Matrix
+ # type: (Matrix[_T]) -> Matrix[_T]
if self.width != rhs.height:
raise ValueError(
"lhs width must equal rhs height to multiply matrixes")
- retval = Matrix(self.height, rhs.width)
+ retval = Matrix(self.height, rhs.width, element_type=self.element_type)
for row in range(retval.height):
for col in range(retval.width):
- sum = Fraction()
+ sum = self.element_type()
for i in range(self.width):
- sum += self[row, i] * rhs[i, col]
+ sum += self[row, i] * rhs[i, col] # type: ignore
retval[row, col] = sum
return retval
def __rmatmul__(self, lhs):
- # type: (Matrix) -> Matrix
+ # type: (Matrix[_T]) -> Matrix[_T]
return lhs.__matmul__(self)
def __elementwise_bin_op(self, rhs, op):
- # type: (Matrix, Callable[[Fraction, Fraction], Fraction]) -> Matrix
+ # type: (Matrix, Callable[[_T | int, _T | int], _T | int]) -> Matrix[_T]
if self.height != rhs.height or self.width != rhs.width:
raise ValueError(
"matrix dimensions must match for element-wise operations")
return retval
def __add__(self, rhs):
- # type: (Matrix) -> Matrix
+ # type: (Matrix[_T]) -> Matrix[_T]
return self.__elementwise_bin_op(rhs, operator.add)
def __radd__(self, lhs):
- # type: (Matrix) -> Matrix
+ # type: (Matrix[_T]) -> Matrix[_T]
return lhs.__add__(self)
def __sub__(self, rhs):
- # type: (Matrix) -> Matrix
+ # type: (Matrix[_T]) -> Matrix[_T]
return self.__elementwise_bin_op(rhs, operator.sub)
def __rsub__(self, lhs):
- # type: (Matrix) -> Matrix
+ # type: (Matrix[_T]) -> Matrix[_T]
return lhs.__sub__(self)
def __iter__(self):
+ # type: () -> Iterator[_T]
return iter(self.__data)
def __reversed__(self):
+ # type: () -> Iterator[_T]
return reversed(self.__data)
def __neg__(self):
+ # type: () -> Matrix[_T]
retval = self.copy()
for i in retval.indexes():
- retval[i] = -retval[i]
+ retval[i] = -retval[i] # type: ignore
return retval
def __repr__(self):
+ # type: () -> str
if self.height == 0 or self.width == 0:
return f"Matrix(height={self.height}, width={self.width})"
lines = []
for row in range(self.height):
line.clear()
for col in range(self.width):
- if self[row, col].denominator == 1:
- line.append(str(self[row, col].numerator))
+ el = self[row, col]
+ if isinstance(el, Fraction) and el.denominator == 1:
+ line.append(str(el.numerator))
else:
- line.append(repr(self[row, col]))
+ line.append(repr(el))
lines.append(", ".join(line))
lines = ",\n ".join(lines)
- return (f"Matrix(height={self.height}, width={self.width}, data=[\n"
+ element_type = ""
+ if self.element_type is not Fraction:
+ element_type = f"element_type={self.element_type}, "
+ return (f"Matrix(height={self.height}, width={self.width}, "
+ f"{element_type}data=[\n"
f" {lines},\n])")
def __eq__(self, rhs):
+ # type: (object) -> bool
if not isinstance(rhs, Matrix):
return NotImplemented
return (self.height == rhs.height
and self.width == rhs.width
- and self.__data == rhs.__data)
+ and self.__data == rhs.__data
+ and self.element_type == rhs.element_type)
- def inverse(self):
+ def inverse(self # type: Matrix[Fraction]
+ ):
+ # type: () -> Matrix[Fraction]
size = self.height
if size != self.width:
raise ValueError("can't invert a non-square matrix")
+ if self.element_type is not Fraction:
+ raise TypeError("can't invert a matrix with element_type that "
+ "isn't Fraction")
inp = self.copy()
- retval = Matrix.identity(size)
+ retval = Matrix(size, size, data=SpecialMatrix.Identity)
# the algorithm is adapted from:
# https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C
for k in range(size):