move clmul files into nmigen-gf.git
[nmutil.git] / src / nmutil / test / test_deduped.py
1 import unittest
2 from nmutil.deduped import deduped
3
4
class TestDeduped(unittest.TestCase):
    """Tests for the ``@deduped()`` memoizing decorator."""

    def test_deduped1(self):
        """Check that repeated calls are served from the cache.

        Every decorated method includes the index of the real (uncached)
        invocation in its result, so seeing an index again proves the
        value came from the cache rather than a fresh call.  The
        ``global_keys`` callable's current value must also take part in
        the cache key.
        """
        # Single-element lists act as mutable cells shared with the
        # closures below, instead of rebinding names via ``nonlocal``.
        key_cell = [1]    # value returned by the ``global_keys`` callable
        calls_made = [0]  # number of real (non-cached) invocations so far

        def next_call_index():
            # Hand out the current index, then advance the counter.
            index = calls_made[0]
            calls_made[0] = index + 1
            return index

        class C:
            def __init__(self, name):
                self.name = name

            @deduped()
            def method(self, a, *, b=1):
                return self, a, b, next_call_index()

            @deduped(global_keys=[lambda: key_cell[0]])
            def method_with_global(self, a, *, b=1):
                return self, a, b, next_call_index(), key_cell[0]

            # @deduped goes *under* @staticmethod/@classmethod; the other
            # order is rejected (see test_bad_methods).
            @staticmethod
            @deduped()
            def smethod(a, *, b=1):
                return a, b, next_call_index()

            @classmethod
            @deduped()
            def cmethod(cls, a, *, b=1):
                return cls, a, b, next_call_index()

            def __repr__(self):
                return f"{self.__class__.__name__}({self.name})"

        class D(C):
            pass

        c1 = C("c1")
        c2 = C("c2")
        check = self.assertEqual

        def check_shared_calls():
            # Same four calls, issued once under key 2 and once under
            # key 1.  None of them use global_keys, so the key change
            # must not affect their results.
            check(c1.cmethod(1, b=5), (C, 1, 5, 2))
            check(c1.smethod(1, b=5), (1, 5, 4))
            check(c1.method(1, b=5), (c1, 1, 5, 6))
            check(c2.method(1, b=5), (c2, 1, 5, 7))

        # The second pass must be answered entirely from the cache: every
        # expected index repeats and the call total stays at 10.
        for which_pass in ("first", "second"):
            with self.subTest(which_pass=which_pass):
                check(C.cmethod(1), (C, 1, 1, 0))
                check(C.cmethod(2), (C, 2, 1, 1))
                check(C.cmethod(1, b=5), (C, 1, 5, 2))
                # D is a different cls, so it gets its own entry.
                check(D.cmethod(1, b=5), (D, 1, 5, 3))
                # smethod has no receiver: C reuses D's cached result.
                check(D.smethod(1, b=5), (1, 5, 4))
                check(C.smethod(1, b=5), (1, 5, 4))
                check(c1.method(None), (c1, None, 1, 5))
                key_cell[0] = 2
                check_shared_calls()
                check(c1.method_with_global(1), (c1, 1, 1, 8, 2))
                key_cell[0] = 1
                check_shared_calls()
                check(c1.method_with_global(1), (c1, 1, 1, 9, 1))
                check(calls_made[0], 10)

    def test_bad_methods(self):
        """Reject ``@deduped()`` stacked on top of static/class methods.

        Each case expects a ``TypeError`` while the class is being
        defined.
        """
        misuse_cases = (
            (staticmethod, ".*@staticmethod.*applied.*@deduped.*"),
            (classmethod, ".*@classmethod.*applied.*@deduped.*"),
        )
        for wrapper, message_pattern in misuse_cases:
            with self.assertRaisesRegex(TypeError, message_pattern):
                class C:
                    @deduped()
                    @wrapper
                    def f():
                        pass
87
88
if __name__ == "__main__":
    # Run this module's tests when it is executed directly as a script.
    unittest.main()