--- /dev/null
+import operator
+from typing import Callable, Iterable
+from fractions import Fraction
+from numbers import Rational
+
+
+class Matrix:
+ __slots__ = "__height", "__width", "__data"
+
+ @property
+ def height(self):
+ return self.__height
+
+ @property
+ def width(self):
+ return self.__width
+
+ def __init__(self, height, width, data=None):
+ # type: (int, int, Iterable[Rational | int] | None) -> None
+ if width < 0 or height < 0:
+ raise ValueError("matrix size must be non-negative")
+ self.__height = height
+ self.__width = width
+ self.__data = [Fraction()] * (height * width)
+ if data is not None:
+ data = list(data)
+ if len(data) != len(self.__data):
+ raise ValueError("data has wrong length")
+ self.__data[:] = map(Fraction, data)
+
+ @staticmethod
+ def identity(height, width=None):
+ # type: (int, int | None) -> Matrix
+ if width is None:
+ width = height
+ retval = Matrix(height, width)
+ for i in range(min(height, width)):
+ retval[i, i] = 1
+ return retval
+
+ def __idx(self, row, col):
+ # type: (int, int) -> int
+ if 0 <= col < self.width and 0 <= row < self.height:
+ return row * self.width + col
+ raise IndexError()
+
+ def __getitem__(self, row_col):
+ # type: (tuple[int, int]) -> Fraction
+ row, col = row_col
+ return self.__data[self.__idx(row, col)]
+
+ def __setitem__(self, row_col, value):
+ # type: (tuple[int, int], Rational | int) -> None
+ row, col = row_col
+ self.__data[self.__idx(row, col)] = Fraction(value)
+
+ def copy(self):
+ retval = Matrix(self.width, self.height)
+ retval.__data[:] = self.__data
+ return retval
+
+ def indexes(self):
+ for row in range(self.height):
+ for col in range(self.width):
+ yield row, col
+
+ def __mul__(self, rhs):
+ # type: (Rational | int) -> Matrix
+ rhs = Fraction(rhs)
+ retval = self.copy()
+ for i in self.indexes():
+ retval[i] *= rhs
+ return retval
+
+ def __rmul__(self, lhs):
+ # type: (Rational | int) -> Matrix
+ return self.__mul__(lhs)
+
+ def __truediv__(self, rhs):
+ # type: (Rational | int) -> Matrix
+ rhs = 1 / Fraction(rhs)
+ retval = self.copy()
+ for i in self.indexes():
+ retval[i] *= rhs
+ return retval
+
+ def __matmul__(self, rhs):
+ # type: (Matrix) -> Matrix
+ if self.width != rhs.height:
+ raise ValueError(
+ "lhs width must equal rhs height to multiply matrixes")
+ retval = Matrix(self.height, rhs.width)
+ for row in range(retval.height):
+ for col in range(retval.width):
+ sum = Fraction()
+ for i in range(self.width):
+ sum += self[row, i] * rhs[i, col]
+ retval[row, col] = sum
+ return retval
+
+ def __rmatmul__(self, lhs):
+ # type: (Matrix) -> Matrix
+ return lhs.__matmul__(self)
+
+ def __elementwise_bin_op(self, rhs, op):
+ # type: (Matrix, Callable[[Fraction, Fraction], Fraction]) -> Matrix
+ if self.height != rhs.height or self.width != rhs.width:
+ raise ValueError(
+ "matrix dimensions must match for element-wise operations")
+ retval = self.copy()
+ for i in retval.indexes():
+ retval[i] = op(retval[i], rhs[i])
+ return retval
+
+ def __add__(self, rhs):
+ # type: (Matrix) -> Matrix
+ return self.__elementwise_bin_op(rhs, operator.add)
+
+ def __radd__(self, lhs):
+ # type: (Matrix) -> Matrix
+ return lhs.__add__(self)
+
+ def __sub__(self, rhs):
+ # type: (Matrix) -> Matrix
+ return self.__elementwise_bin_op(rhs, operator.sub)
+
+ def __rsub__(self, lhs):
+ # type: (Matrix) -> Matrix
+ return lhs.__sub__(self)
+
+ def __iter__(self):
+ return iter(self.__data)
+
+ def __reversed__(self):
+ return reversed(self.__data)
+
+ def __neg__(self):
+ retval = self.copy()
+ for i in retval.indexes():
+ retval[i] = -retval[i]
+ return retval
+
+ def __repr__(self):
+ if self.height == 0 or self.width == 0:
+ return f"Matrix(height={self.height}, width={self.width})"
+ lines = []
+ line = []
+ for row in range(self.height):
+ line.clear()
+ for col in range(self.width):
+ if self[row, col].denominator == 1:
+ line.append(str(self[row, col].numerator))
+ else:
+ line.append(repr(self[row, col]))
+ lines.append(", ".join(line))
+ lines = ",\n ".join(lines)
+ return (f"Matrix(height={self.height}, width={self.width}, data=[\n"
+ f" {lines},\n])")
+
+ def __eq__(self, rhs):
+ if not isinstance(rhs, Matrix):
+ return NotImplemented
+ return (self.height == rhs.height
+ and self.width == rhs.width
+ and self.__data == rhs.__data)
+
+ def inverse(self):
+ size = self.height
+ if size != self.width:
+ raise ValueError("can't invert a non-square matrix")
+ inp = self.copy()
+ retval = Matrix.identity(size)
+ # the algorithm is adapted from:
+ # https://rosettacode.org/wiki/Gauss-Jordan_matrix_inversion#C
+ for k in range(size):
+ f = abs(inp[k, k]) # Find pivot.
+ p = k
+ for i in range(k + 1, size):
+ g = abs(inp[k, i])
+ if g > f:
+ f = g
+ p = i
+ if f == 0:
+ raise ZeroDivisionError("Matrix is singular")
+ if p != k: # Swap rows.
+ for j in range(k, size):
+ f = inp[j, k]
+ inp[j, k] = inp[j, p]
+ inp[j, p] = f
+ for j in range(size):
+ f = retval[j, k]
+ retval[j, k] = retval[j, p]
+ retval[j, p] = f
+ f = 1 / inp[k, k] # Scale row so pivot is 1.
+ for j in range(k, size):
+ inp[j, k] *= f
+ for j in range(size):
+ retval[j, k] *= f
+ for i in range(size): # Subtract to get zeros.
+ if i == k:
+ continue
+ f = inp[k, i]
+ for j in range(k, size):
+ inp[j, i] -= inp[j, k] * f
+ for j in range(size):
+ retval[j, i] -= retval[j, k] * f
+ return retval
--- /dev/null
+import unittest
+from fractions import Fraction
+
+from bigint_presentation_code.matrix import Matrix
+
+
+class TestMatrix(unittest.TestCase):
+ def test_repr(self):
+ self.assertEqual(repr(Matrix(2, 3, [0, 1, 2,
+ 3, 4, 5])),
+ 'Matrix(height=2, width=3, data=[\n'
+ ' 0, 1, 2,\n'
+ ' 3, 4, 5,\n'
+ '])')
+ self.assertEqual(repr(Matrix(2, 3, [0, 1, Fraction(2) / 3,
+ 3, 4, 5])),
+ 'Matrix(height=2, width=3, data=[\n'
+ ' 0, 1, Fraction(2, 3),\n'
+ ' 3, 4, 5,\n'
+ '])')
+ self.assertEqual(repr(Matrix(0, 3)), 'Matrix(height=0, width=3)')
+ self.assertEqual(repr(Matrix(2, 0)), 'Matrix(height=2, width=0)')
+
+ def test_eq(self):
+ self.assertFalse(Matrix(1, 1) == 5)
+ self.assertFalse(5 == Matrix(1, 1))
+ self.assertFalse(Matrix(2, 1) == Matrix(1, 1))
+ self.assertFalse(Matrix(1, 2) == Matrix(1, 1))
+ self.assertTrue(Matrix(1, 1) == Matrix(1, 1))
+ self.assertTrue(Matrix(1, 1, [1]) == Matrix(1, 1, [1]))
+ self.assertFalse(Matrix(1, 1, [2]) == Matrix(1, 1, [1]))
+
+ def test_add(self):
+ self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+ + Matrix(2, 2, [40, 30, 20, 10]),
+ Matrix(2, 2, [41, 32, 23, 14]))
+
+ def test_identity(self):
+ self.assertEqual(Matrix.identity(2, 2),
+ Matrix(2, 2, [1, 0,
+ 0, 1]))
+ self.assertEqual(Matrix.identity(1, 3),
+ Matrix(1, 3, [1, 0, 0]))
+ self.assertEqual(Matrix.identity(2, 3),
+ Matrix(2, 3, [1, 0, 0,
+ 0, 1, 0]))
+ self.assertEqual(Matrix.identity(3),
+ Matrix(3, 3, [1, 0, 0,
+ 0, 1, 0,
+ 0, 0, 1]))
+
+ def test_sub(self):
+ self.assertEqual(Matrix(2, 2, [40, 30, 20, 10])
+ - Matrix(2, 2, [-1, -2, -3, -4]),
+ Matrix(2, 2, [41, 32, 23, 14]))
+
+ def test_neg(self):
+ self.assertEqual(-Matrix(2, 2, [40, 30, 20, 10]),
+ Matrix(2, 2, [-40, -30, -20, -10]))
+
+ def test_mul(self):
+ self.assertEqual(Matrix(2, 2, [1, 2, 3, 4]) * Fraction(3, 2),
+ Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+ self.assertEqual(Fraction(3, 2) * Matrix(2, 2, [1, 2, 3, 4]),
+ Matrix(2, 2, [Fraction(3, 2), 3, Fraction(9, 2), 6]))
+
+ def test_matmul(self):
+ self.assertEqual(Matrix(2, 2, [1, 2, 3, 4])
+ @ Matrix(2, 2, [4, 3, 2, 1]),
+ Matrix(2, 2, [8, 5, 20, 13]))
+ self.assertEqual(Matrix(3, 2, [6, 5, 4, 3, 2, 1])
+ @ Matrix(2, 1, [1, 2]),
+ Matrix(3, 1, [16, 10, 4]))
+
+ def test_inverse(self):
+ self.assertEqual(Matrix(0, 0).inverse(), Matrix(0, 0))
+ self.assertEqual(Matrix(1, 1, [2]).inverse(),
+ Matrix(1, 1, [Fraction(1, 2)]))
+ self.assertEqual(Matrix(1, 1, [1]).inverse(),
+ Matrix(1, 1, [1]))
+ self.assertEqual(Matrix(2, 2, [1, 0, 1, 1]).inverse(),
+ Matrix(2, 2, [1, 0, -1, 1]))
+ self.assertEqual(Matrix(3, 3, [0, 1, 0,
+ 1, 0, 0,
+ 0, 0, 1]).inverse(),
+ Matrix(3, 3, [0, 1, 0,
+ 1, 0, 0,
+ 0, 0, 1]))
+ _1_2 = Fraction(1, 2)
+ _1_3 = Fraction(1, 3)
+ _1_6 = Fraction(1, 6)
+ self.assertEqual(Matrix(5, 5, [1, 0, 0, 0, 0,
+ 1, 1, 1, 1, 1,
+ 1, -1, 1, -1, 1,
+ 1, -2, 4, -8, 16,
+ 0, 0, 0, 0, 1]).inverse(),
+ Matrix(5, 5, [1, 0, 0, 0, 0,
+ _1_2, _1_3, -1, _1_6, -2,
+ -1, _1_2, _1_2, 0, -1,
+ -_1_2, _1_6, _1_2, -_1_6, 2,
+ 0, 0, 0, 0, 1]))
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(1, 1, [0]).inverse()
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(2, 2, [0, 0, 1, 1]).inverse()
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(2, 2, [1, 0, 1, 0]).inverse()
+ with self.assertRaisesRegex(ZeroDivisionError, "Matrix is singular"):
+ Matrix(2, 2, [1, 1, 1, 1]).inverse()
+
+
+if __name__ == "__main__":
+ unittest.main()