fhdl.xfrm: implement DomainRenamer.
authorwhitequark <whitequark@whitequark.org>
Thu, 13 Dec 2018 08:57:14 +0000 (08:57 +0000)
committerwhitequark <whitequark@whitequark.org>
Thu, 13 Dec 2018 08:57:14 +0000 (08:57 +0000)
nmigen/fhdl/ast.py
nmigen/fhdl/xfrm.py
nmigen/test/test_fhdl_dsl.py
nmigen/test/test_fhdl_value.py [new file with mode: 0644]
nmigen/test/test_fhdl_values.py [deleted file]
nmigen/test/test_fhdl_xfrm.py
nmigen/test/tools.py [new file with mode: 0644]

index 604bb9e29d4f3654ae0cd2a94fc36749035ef294..9d3af7deafc970dd611cd8fb1d3ee57b5dd8876a 100644 (file)
@@ -573,7 +573,7 @@ class ClockSignal(Value):
     Parameters
     ----------
     domain : str
-        Clock domain to obtain a clock signal for. Defaults to `"sync"`.
+        Clock domain to obtain a clock signal for. Defaults to ``"sync"``.
     """
     def __init__(self, domain="sync"):
         super().__init__()
@@ -588,13 +588,13 @@ class ClockSignal(Value):
 class ResetSignal(Value):
     """Reset signal for a given clock domain
 
-    `ResetSignal` s for a given clock domain can be retrieved multiple
+    ``ResetSignal`` s for a given clock domain can be retrieved multiple
     times. They all ultimately refer to the same signal.
 
     Parameters
     ----------
     domain : str
-        Clock domain to obtain a reset signal for. Defaults to `"sync"`.
+        Clock domain to obtain a reset signal for. Defaults to ``"sync"``.
     """
     def __init__(self, domain="sync"):
         super().__init__()
@@ -603,7 +603,7 @@ class ResetSignal(Value):
         self.domain = domain
 
     def __repr__(self):
-        return "(reset {})".format(self.domain)
+        return "(rst {})".format(self.domain)
 
 
 class _StatementList(list):
index 8a20aab12ac457e4d7ad19048b159440d28f79d4..4277af63628abec977fba7890cad7f0c49f45a39 100644 (file)
@@ -4,7 +4,8 @@ from .ast import *
 from .ir import *
 
 
-__all__ = ["ValueTransformer", "StatementTransformer", "ResetInserter", "CEInserter"]
+__all__ = ["ValueTransformer", "StatementTransformer", "FragmentTransformer",
+           "DomainRenamer", "ResetInserter", "CEInserter"]
 
 
 class ValueTransformer:
@@ -116,6 +117,30 @@ class FragmentTransformer:
         return self.on_fragment(value)
 
 
+class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer):
+    def __init__(self, domains):
+        if isinstance(domains, str):
+            domains = {"sync": domains}
+        self.domains = OrderedDict(domains)
+
+    def on_ClockSignal(self, value):
+        if value.domain in self.domains:
+            return ClockSignal(self.domains[value.domain])
+        return value
+
+    def on_ResetSignal(self, value):
+        if value.domain in self.domains:
+            return ResetSignal(self.domains[value.domain])
+        return value
+
+    def map_drivers(self, fragment, new_fragment):
+        for cd_name, signals in fragment.iter_domains():
+            if cd_name in self.domains:
+                cd_name = self.domains[cd_name]
+            for signal in signals:
+                new_fragment.drive(signal, cd_name)
+
+
 class _ControlInserter(FragmentTransformer):
     def __init__(self, controls):
         if isinstance(controls, Value):
index 22019b1a028c003a109b60ad458d3d48cdf0bc68..7da25cf348765f4f1645c97f239061546edb60d0 100644 (file)
@@ -1,12 +1,12 @@
-import re
 import unittest
 from contextlib import contextmanager
 
-from nmigen.fhdl.ast import *
-from nmigen.fhdl.dsl import *
+from ..fhdl.ast import *
+from ..fhdl.dsl import *
+from .tools import *
 
 
-class DSLTestCase(unittest.TestCase):
+class DSLTestCase(FHDLTestCase):
     def setUp(self):
         self.s1 = Signal()
         self.s2 = Signal()
@@ -24,13 +24,6 @@ class DSLTestCase(unittest.TestCase):
             # WTF? unittest.assertRaises is completely broken.
             self.assertEqual(str(cm.exception), msg)
 
-    def assertRepr(self, obj, repr_str):
-        obj = Statement.wrap(obj)
-        repr_str = re.sub(r"\s+",   " ",  repr_str)
-        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
-        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
-        self.assertEqual(repr(obj), repr_str.strip())
-
     def test_d_comb(self):
         m = Module()
         m.d.comb += self.c1.eq(1)
diff --git a/nmigen/test/test_fhdl_value.py b/nmigen/test/test_fhdl_value.py
new file mode 100644 (file)
index 0000000..9c7dbde
--- /dev/null
@@ -0,0 +1,358 @@
+import unittest
+
+from nmigen.fhdl.ast import *
+
+
+class ValueTestCase(unittest.TestCase):
+    def test_wrap(self):
+        self.assertIsInstance(Value.wrap(0), Const)
+        self.assertIsInstance(Value.wrap(True), Const)
+        c = Const(0)
+        self.assertIs(Value.wrap(c), c)
+        with self.assertRaises(TypeError):
+            Value.wrap("str")
+
+    def test_bool(self):
+        with self.assertRaises(TypeError):
+            if Const(0):
+                pass
+
+    def test_len(self):
+        self.assertEqual(len(Const(10)), 4)
+
+    def test_getitem_int(self):
+        s1 = Const(10)[0]
+        self.assertIsInstance(s1, Slice)
+        self.assertEqual(s1.start, 0)
+        self.assertEqual(s1.end, 1)
+        s2 = Const(10)[-1]
+        self.assertIsInstance(s2, Slice)
+        self.assertEqual(s2.start, 3)
+        self.assertEqual(s2.end, 4)
+        with self.assertRaises(IndexError):
+            Const(10)[5]
+
+    def test_getitem_slice(self):
+        s1 = Const(10)[1:3]
+        self.assertIsInstance(s1, Slice)
+        self.assertEqual(s1.start, 1)
+        self.assertEqual(s1.end, 3)
+        s2 = Const(10)[1:-2]
+        self.assertIsInstance(s2, Slice)
+        self.assertEqual(s2.start, 1)
+        self.assertEqual(s2.end, 2)
+        s3 = Const(31)[::2]
+        self.assertIsInstance(s3, Cat)
+        self.assertIsInstance(s3.operands[0], Slice)
+        self.assertEqual(s3.operands[0].start, 0)
+        self.assertEqual(s3.operands[0].end, 1)
+        self.assertIsInstance(s3.operands[1], Slice)
+        self.assertEqual(s3.operands[1].start, 2)
+        self.assertEqual(s3.operands[1].end, 3)
+        self.assertIsInstance(s3.operands[2], Slice)
+        self.assertEqual(s3.operands[2].start, 4)
+        self.assertEqual(s3.operands[2].end, 5)
+
+    def test_getitem_wrong(self):
+        with self.assertRaises(TypeError):
+            Const(31)["str"]
+
+
+class ConstTestCase(unittest.TestCase):
+    def test_shape(self):
+        self.assertEqual(Const(0).shape(),   (0, False))
+        self.assertEqual(Const(1).shape(),   (1, False))
+        self.assertEqual(Const(10).shape(),  (4, False))
+        self.assertEqual(Const(-10).shape(), (4, True))
+
+        self.assertEqual(Const(1, 4).shape(),         (4, False))
+        self.assertEqual(Const(1, (4, True)).shape(), (4, True))
+
+        with self.assertRaises(TypeError):
+            Const(1, -1)
+
+    def test_value(self):
+        self.assertEqual(Const(10).value, 10)
+
+    def test_repr(self):
+        self.assertEqual(repr(Const(10)),  "(const 4'd10)")
+        self.assertEqual(repr(Const(-10)), "(const 4'sd-10)")
+
+    def test_hash(self):
+        with self.assertRaises(TypeError):
+            hash(Const(0))
+
+
+class OperatorTestCase(unittest.TestCase):
+    def test_invert(self):
+        v = ~Const(0, 4)
+        self.assertEqual(repr(v), "(~ (const 4'd0))")
+        self.assertEqual(v.shape(), (4, False))
+
+    def test_neg(self):
+        v1 = -Const(0, (4, False))
+        self.assertEqual(repr(v1), "(- (const 4'd0))")
+        self.assertEqual(v1.shape(), (5, True))
+        v2 = -Const(0, (4, True))
+        self.assertEqual(repr(v2), "(- (const 4'sd0))")
+        self.assertEqual(v2.shape(), (4, True))
+
+    def test_add(self):
+        v1 = Const(0, (4, False)) + Const(0, (6, False))
+        self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (7, False))
+        v2 = Const(0, (4, True)) + Const(0, (6, True))
+        self.assertEqual(v2.shape(), (7, True))
+        v3 = Const(0, (4, True)) + Const(0, (4, False))
+        self.assertEqual(v3.shape(), (6, True))
+        v4 = Const(0, (4, False)) + Const(0, (4, True))
+        self.assertEqual(v4.shape(), (6, True))
+        v5 = 10 + Const(0, 4)
+        self.assertEqual(v5.shape(), (5, False))
+
+    def test_sub(self):
+        v1 = Const(0, (4, False)) - Const(0, (6, False))
+        self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (7, False))
+        v2 = Const(0, (4, True)) - Const(0, (6, True))
+        self.assertEqual(v2.shape(), (7, True))
+        v3 = Const(0, (4, True)) - Const(0, (4, False))
+        self.assertEqual(v3.shape(), (6, True))
+        v4 = Const(0, (4, False)) - Const(0, (4, True))
+        self.assertEqual(v4.shape(), (6, True))
+        v5 = 10 - Const(0, 4)
+        self.assertEqual(v5.shape(), (5, False))
+
+    def test_mul(self):
+        v1 = Const(0, (4, False)) * Const(0, (6, False))
+        self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (10, False))
+        v2 = Const(0, (4, True)) * Const(0, (6, True))
+        self.assertEqual(v2.shape(), (9, True))
+        v3 = Const(0, (4, True)) * Const(0, (4, False))
+        self.assertEqual(v3.shape(), (8, True))
+        v5 = 10 * Const(0, 4)
+        self.assertEqual(v5.shape(), (8, False))
+
+    def test_and(self):
+        v1 = Const(0, (4, False)) & Const(0, (6, False))
+        self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Const(0, (4, True)) & Const(0, (6, True))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Const(0, (4, True)) & Const(0, (4, False))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Const(0, (4, False)) & Const(0, (4, True))
+        self.assertEqual(v4.shape(), (5, True))
+        v5 = 10 & Const(0, 4)
+        self.assertEqual(v5.shape(), (4, False))
+
+    def test_or(self):
+        v1 = Const(0, (4, False)) | Const(0, (6, False))
+        self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Const(0, (4, True)) | Const(0, (6, True))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Const(0, (4, True)) | Const(0, (4, False))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Const(0, (4, False)) | Const(0, (4, True))
+        self.assertEqual(v4.shape(), (5, True))
+        v5 = 10 | Const(0, 4)
+        self.assertEqual(v5.shape(), (4, False))
+
+    def test_xor(self):
+        v1 = Const(0, (4, False)) ^ Const(0, (6, False))
+        self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Const(0, (4, True)) ^ Const(0, (6, True))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Const(0, (4, True)) ^ Const(0, (4, False))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Const(0, (4, False)) ^ Const(0, (4, True))
+        self.assertEqual(v4.shape(), (5, True))
+        v5 = 10 ^ Const(0, 4)
+        self.assertEqual(v5.shape(), (4, False))
+
+    def test_lt(self):
+        v = Const(0, 4) < Const(0, 6)
+        self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_le(self):
+        v = Const(0, 4) <= Const(0, 6)
+        self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_gt(self):
+        v = Const(0, 4) > Const(0, 6)
+        self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_ge(self):
+        v = Const(0, 4) >= Const(0, 6)
+        self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_eq(self):
+        v = Const(0, 4) == Const(0, 6)
+        self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_ne(self):
+        v = Const(0, 4) != Const(0, 6)
+        self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_mux(self):
+        s  = Const(0)
+        v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False)))
+        self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True)))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False)))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True)))
+        self.assertEqual(v4.shape(), (5, True))
+
+    def test_bool(self):
+        v = Const(0).bool()
+        self.assertEqual(repr(v), "(b (const 0'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_hash(self):
+        with self.assertRaises(TypeError):
+            hash(Const(0) + Const(0))
+
+
+class SliceTestCase(unittest.TestCase):
+    def test_shape(self):
+        s1 = Const(10)[2]
+        self.assertEqual(s1.shape(), (1, False))
+        s2 = Const(-10)[0:2]
+        self.assertEqual(s2.shape(), (2, False))
+
+    def test_repr(self):
+        s1 = Const(10)[2]
+        self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)")
+
+
+class CatTestCase(unittest.TestCase):
+    def test_shape(self):
+        c1 = Cat(Const(10))
+        self.assertEqual(c1.shape(), (4, False))
+        c2 = Cat(Const(10), Const(1))
+        self.assertEqual(c2.shape(), (5, False))
+        c3 = Cat(Const(10), Const(1), Const(0))
+        self.assertEqual(c3.shape(), (5, False))
+
+    def test_repr(self):
+        c1 = Cat(Const(10), Const(1))
+        self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))")
+
+
+class ReplTestCase(unittest.TestCase):
+    def test_shape(self):
+        r1 = Repl(Const(10), 3)
+        self.assertEqual(r1.shape(), (12, False))
+
+    def test_count_wrong(self):
+        with self.assertRaises(TypeError):
+            Repl(Const(10), -1)
+        with self.assertRaises(TypeError):
+            Repl(Const(10), "str")
+
+    def test_repr(self):
+        r1 = Repl(Const(10), 3)
+        self.assertEqual(repr(r1), "(repl (const 4'd10) 3)")
+
+
+class SignalTestCase(unittest.TestCase):
+    def test_shape(self):
+        s1 = Signal()
+        self.assertEqual(s1.shape(), (1, False))
+        s2 = Signal(2)
+        self.assertEqual(s2.shape(), (2, False))
+        s3 = Signal((2, False))
+        self.assertEqual(s3.shape(), (2, False))
+        s4 = Signal((2, True))
+        self.assertEqual(s4.shape(), (2, True))
+        s5 = Signal(max=16)
+        self.assertEqual(s5.shape(), (4, False))
+        s6 = Signal(min=4, max=16)
+        self.assertEqual(s6.shape(), (4, False))
+        s7 = Signal(min=-4, max=16)
+        self.assertEqual(s7.shape(), (5, True))
+        s8 = Signal(min=-20, max=16)
+        self.assertEqual(s8.shape(), (6, True))
+
+        with self.assertRaises(ValueError):
+            Signal(min=10, max=4)
+        with self.assertRaises(ValueError):
+            Signal(2, min=10)
+        with self.assertRaises(TypeError):
+            Signal(-10)
+
+    def test_name(self):
+        s1 = Signal()
+        self.assertEqual(s1.name, "s1")
+        s2 = Signal(name="sig")
+        self.assertEqual(s2.name, "sig")
+
+    def test_reset(self):
+        s1 = Signal(4, reset=0b111, reset_less=True)
+        self.assertEqual(s1.reset, 0b111)
+        self.assertEqual(s1.reset_less, True)
+
+    def test_attrs(self):
+        s1 = Signal()
+        self.assertEqual(s1.attrs, {})
+        s2 = Signal(attrs={"no_retiming": True})
+        self.assertEqual(s2.attrs, {"no_retiming": True})
+
+    def test_repr(self):
+        s1 = Signal()
+        self.assertEqual(repr(s1), "(sig s1)")
+
+    def test_like(self):
+        s1 = Signal.like(Signal(4))
+        self.assertEqual(s1.shape(), (4, False))
+        s2 = Signal.like(Signal(min=-15))
+        self.assertEqual(s2.shape(), (5, True))
+        s3 = Signal.like(Signal(4, reset=0b111, reset_less=True))
+        self.assertEqual(s3.reset, 0b111)
+        self.assertEqual(s3.reset_less, True)
+        s4 = Signal.like(Signal(attrs={"no_retiming": True}))
+        self.assertEqual(s4.attrs, {"no_retiming": True})
+        s5 = Signal.like(10)
+        self.assertEqual(s5.shape(), (4, False))
+
+
+class ClockSignalTestCase(unittest.TestCase):
+    def test_domain(self):
+        s1 = ClockSignal()
+        self.assertEqual(s1.domain, "sync")
+        s2 = ClockSignal("pix")
+        self.assertEqual(s2.domain, "pix")
+
+        with self.assertRaises(TypeError):
+            ClockSignal(1)
+
+    def test_repr(self):
+        s1 = ClockSignal()
+        self.assertEqual(repr(s1), "(clk sync)")
+
+
+class ResetSignalTestCase(unittest.TestCase):
+    def test_domain(self):
+        s1 = ResetSignal()
+        self.assertEqual(s1.domain, "sync")
+        s2 = ResetSignal("pix")
+        self.assertEqual(s2.domain, "pix")
+
+        with self.assertRaises(TypeError):
+            ResetSignal(1)
+
+    def test_repr(self):
+        s1 = ResetSignal()
+        self.assertEqual(repr(s1), "(rst sync)")
diff --git a/nmigen/test/test_fhdl_values.py b/nmigen/test/test_fhdl_values.py
deleted file mode 100644 (file)
index 16e0970..0000000
+++ /dev/null
@@ -1,358 +0,0 @@
-import unittest
-
-from nmigen.fhdl.ast import *
-
-
-class ValueTestCase(unittest.TestCase):
-    def test_wrap(self):
-        self.assertIsInstance(Value.wrap(0), Const)
-        self.assertIsInstance(Value.wrap(True), Const)
-        c = Const(0)
-        self.assertIs(Value.wrap(c), c)
-        with self.assertRaises(TypeError):
-            Value.wrap("str")
-
-    def test_bool(self):
-        with self.assertRaises(TypeError):
-            if Const(0):
-                pass
-
-    def test_len(self):
-        self.assertEqual(len(Const(10)), 4)
-
-    def test_getitem_int(self):
-        s1 = Const(10)[0]
-        self.assertIsInstance(s1, Slice)
-        self.assertEqual(s1.start, 0)
-        self.assertEqual(s1.end, 1)
-        s2 = Const(10)[-1]
-        self.assertIsInstance(s2, Slice)
-        self.assertEqual(s2.start, 3)
-        self.assertEqual(s2.end, 4)
-        with self.assertRaises(IndexError):
-            Const(10)[5]
-
-    def test_getitem_slice(self):
-        s1 = Const(10)[1:3]
-        self.assertIsInstance(s1, Slice)
-        self.assertEqual(s1.start, 1)
-        self.assertEqual(s1.end, 3)
-        s2 = Const(10)[1:-2]
-        self.assertIsInstance(s2, Slice)
-        self.assertEqual(s2.start, 1)
-        self.assertEqual(s2.end, 2)
-        s3 = Const(31)[::2]
-        self.assertIsInstance(s3, Cat)
-        self.assertIsInstance(s3.operands[0], Slice)
-        self.assertEqual(s3.operands[0].start, 0)
-        self.assertEqual(s3.operands[0].end, 1)
-        self.assertIsInstance(s3.operands[1], Slice)
-        self.assertEqual(s3.operands[1].start, 2)
-        self.assertEqual(s3.operands[1].end, 3)
-        self.assertIsInstance(s3.operands[2], Slice)
-        self.assertEqual(s3.operands[2].start, 4)
-        self.assertEqual(s3.operands[2].end, 5)
-
-    def test_getitem_wrong(self):
-        with self.assertRaises(TypeError):
-            Const(31)["str"]
-
-
-class ConstTestCase(unittest.TestCase):
-    def test_shape(self):
-        self.assertEqual(Const(0).shape(),   (0, False))
-        self.assertEqual(Const(1).shape(),   (1, False))
-        self.assertEqual(Const(10).shape(),  (4, False))
-        self.assertEqual(Const(-10).shape(), (4, True))
-
-        self.assertEqual(Const(1, 4).shape(),         (4, False))
-        self.assertEqual(Const(1, (4, True)).shape(), (4, True))
-
-        with self.assertRaises(TypeError):
-            Const(1, -1)
-
-    def test_value(self):
-        self.assertEqual(Const(10).value, 10)
-
-    def test_repr(self):
-        self.assertEqual(repr(Const(10)),  "(const 4'd10)")
-        self.assertEqual(repr(Const(-10)), "(const 4'sd-10)")
-
-    def test_hash(self):
-        with self.assertRaises(TypeError):
-            hash(Const(0))
-
-
-class OperatorTestCase(unittest.TestCase):
-    def test_invert(self):
-        v = ~Const(0, 4)
-        self.assertEqual(repr(v), "(~ (const 4'd0))")
-        self.assertEqual(v.shape(), (4, False))
-
-    def test_neg(self):
-        v1 = -Const(0, (4, False))
-        self.assertEqual(repr(v1), "(- (const 4'd0))")
-        self.assertEqual(v1.shape(), (5, True))
-        v2 = -Const(0, (4, True))
-        self.assertEqual(repr(v2), "(- (const 4'sd0))")
-        self.assertEqual(v2.shape(), (4, True))
-
-    def test_add(self):
-        v1 = Const(0, (4, False)) + Const(0, (6, False))
-        self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (7, False))
-        v2 = Const(0, (4, True)) + Const(0, (6, True))
-        self.assertEqual(v2.shape(), (7, True))
-        v3 = Const(0, (4, True)) + Const(0, (4, False))
-        self.assertEqual(v3.shape(), (6, True))
-        v4 = Const(0, (4, False)) + Const(0, (4, True))
-        self.assertEqual(v4.shape(), (6, True))
-        v5 = 10 + Const(0, 4)
-        self.assertEqual(v5.shape(), (5, False))
-
-    def test_sub(self):
-        v1 = Const(0, (4, False)) - Const(0, (6, False))
-        self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (7, False))
-        v2 = Const(0, (4, True)) - Const(0, (6, True))
-        self.assertEqual(v2.shape(), (7, True))
-        v3 = Const(0, (4, True)) - Const(0, (4, False))
-        self.assertEqual(v3.shape(), (6, True))
-        v4 = Const(0, (4, False)) - Const(0, (4, True))
-        self.assertEqual(v4.shape(), (6, True))
-        v5 = 10 - Const(0, 4)
-        self.assertEqual(v5.shape(), (5, False))
-
-    def test_mul(self):
-        v1 = Const(0, (4, False)) * Const(0, (6, False))
-        self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (10, False))
-        v2 = Const(0, (4, True)) * Const(0, (6, True))
-        self.assertEqual(v2.shape(), (9, True))
-        v3 = Const(0, (4, True)) * Const(0, (4, False))
-        self.assertEqual(v3.shape(), (8, True))
-        v5 = 10 * Const(0, 4)
-        self.assertEqual(v5.shape(), (8, False))
-
-    def test_and(self):
-        v1 = Const(0, (4, False)) & Const(0, (6, False))
-        self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Const(0, (4, True)) & Const(0, (6, True))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Const(0, (4, True)) & Const(0, (4, False))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Const(0, (4, False)) & Const(0, (4, True))
-        self.assertEqual(v4.shape(), (5, True))
-        v5 = 10 & Const(0, 4)
-        self.assertEqual(v5.shape(), (4, False))
-
-    def test_or(self):
-        v1 = Const(0, (4, False)) | Const(0, (6, False))
-        self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Const(0, (4, True)) | Const(0, (6, True))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Const(0, (4, True)) | Const(0, (4, False))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Const(0, (4, False)) | Const(0, (4, True))
-        self.assertEqual(v4.shape(), (5, True))
-        v5 = 10 | Const(0, 4)
-        self.assertEqual(v5.shape(), (4, False))
-
-    def test_xor(self):
-        v1 = Const(0, (4, False)) ^ Const(0, (6, False))
-        self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Const(0, (4, True)) ^ Const(0, (6, True))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Const(0, (4, True)) ^ Const(0, (4, False))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Const(0, (4, False)) ^ Const(0, (4, True))
-        self.assertEqual(v4.shape(), (5, True))
-        v5 = 10 ^ Const(0, 4)
-        self.assertEqual(v5.shape(), (4, False))
-
-    def test_lt(self):
-        v = Const(0, 4) < Const(0, 6)
-        self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_le(self):
-        v = Const(0, 4) <= Const(0, 6)
-        self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_gt(self):
-        v = Const(0, 4) > Const(0, 6)
-        self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_ge(self):
-        v = Const(0, 4) >= Const(0, 6)
-        self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_eq(self):
-        v = Const(0, 4) == Const(0, 6)
-        self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_ne(self):
-        v = Const(0, 4) != Const(0, 6)
-        self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_mux(self):
-        s  = Const(0)
-        v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False)))
-        self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True)))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False)))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True)))
-        self.assertEqual(v4.shape(), (5, True))
-
-    def test_bool(self):
-        v = Const(0).bool()
-        self.assertEqual(repr(v), "(b (const 0'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_hash(self):
-        with self.assertRaises(TypeError):
-            hash(Const(0) + Const(0))
-
-
-class SliceTestCase(unittest.TestCase):
-    def test_shape(self):
-        s1 = Const(10)[2]
-        self.assertEqual(s1.shape(), (1, False))
-        s2 = Const(-10)[0:2]
-        self.assertEqual(s2.shape(), (2, False))
-
-    def test_repr(self):
-        s1 = Const(10)[2]
-        self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)")
-
-
-class CatTestCase(unittest.TestCase):
-    def test_shape(self):
-        c1 = Cat(Const(10))
-        self.assertEqual(c1.shape(), (4, False))
-        c2 = Cat(Const(10), Const(1))
-        self.assertEqual(c2.shape(), (5, False))
-        c3 = Cat(Const(10), Const(1), Const(0))
-        self.assertEqual(c3.shape(), (5, False))
-
-    def test_repr(self):
-        c1 = Cat(Const(10), Const(1))
-        self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))")
-
-
-class ReplTestCase(unittest.TestCase):
-    def test_shape(self):
-        r1 = Repl(Const(10), 3)
-        self.assertEqual(r1.shape(), (12, False))
-
-    def test_count_wrong(self):
-        with self.assertRaises(TypeError):
-            Repl(Const(10), -1)
-        with self.assertRaises(TypeError):
-            Repl(Const(10), "str")
-
-    def test_repr(self):
-        r1 = Repl(Const(10), 3)
-        self.assertEqual(repr(r1), "(repl (const 4'd10) 3)")
-
-
-class SignalTestCase(unittest.TestCase):
-    def test_shape(self):
-        s1 = Signal()
-        self.assertEqual(s1.shape(), (1, False))
-        s2 = Signal(2)
-        self.assertEqual(s2.shape(), (2, False))
-        s3 = Signal((2, False))
-        self.assertEqual(s3.shape(), (2, False))
-        s4 = Signal((2, True))
-        self.assertEqual(s4.shape(), (2, True))
-        s5 = Signal(max=16)
-        self.assertEqual(s5.shape(), (4, False))
-        s6 = Signal(min=4, max=16)
-        self.assertEqual(s6.shape(), (4, False))
-        s7 = Signal(min=-4, max=16)
-        self.assertEqual(s7.shape(), (5, True))
-        s8 = Signal(min=-20, max=16)
-        self.assertEqual(s8.shape(), (6, True))
-
-        with self.assertRaises(ValueError):
-            Signal(min=10, max=4)
-        with self.assertRaises(ValueError):
-            Signal(2, min=10)
-        with self.assertRaises(TypeError):
-            Signal(-10)
-
-    def test_name(self):
-        s1 = Signal()
-        self.assertEqual(s1.name, "s1")
-        s2 = Signal(name="sig")
-        self.assertEqual(s2.name, "sig")
-
-    def test_reset(self):
-        s1 = Signal(4, reset=0b111, reset_less=True)
-        self.assertEqual(s1.reset, 0b111)
-        self.assertEqual(s1.reset_less, True)
-
-    def test_attrs(self):
-        s1 = Signal()
-        self.assertEqual(s1.attrs, {})
-        s2 = Signal(attrs={"no_retiming": True})
-        self.assertEqual(s2.attrs, {"no_retiming": True})
-
-    def test_repr(self):
-        s1 = Signal()
-        self.assertEqual(repr(s1), "(sig s1)")
-
-    def test_like(self):
-        s1 = Signal.like(Signal(4))
-        self.assertEqual(s1.shape(), (4, False))
-        s2 = Signal.like(Signal(min=-15))
-        self.assertEqual(s2.shape(), (5, True))
-        s3 = Signal.like(Signal(4, reset=0b111, reset_less=True))
-        self.assertEqual(s3.reset, 0b111)
-        self.assertEqual(s3.reset_less, True)
-        s4 = Signal.like(Signal(attrs={"no_retiming": True}))
-        self.assertEqual(s4.attrs, {"no_retiming": True})
-        s5 = Signal.like(10)
-        self.assertEqual(s5.shape(), (4, False))
-
-
-class ClockSignalTestCase(unittest.TestCase):
-    def test_domain(self):
-        s1 = ClockSignal()
-        self.assertEqual(s1.domain, "sync")
-        s2 = ClockSignal("pix")
-        self.assertEqual(s2.domain, "pix")
-
-        with self.assertRaises(TypeError):
-            ClockSignal(1)
-
-    def test_repr(self):
-        s1 = ClockSignal()
-        self.assertEqual(repr(s1), "(clk sync)")
-
-
-class ResetSignalTestCase(unittest.TestCase):
-    def test_domain(self):
-        s1 = ResetSignal()
-        self.assertEqual(s1.domain, "sync")
-        s2 = ResetSignal("pix")
-        self.assertEqual(s2.domain, "pix")
-
-        with self.assertRaises(TypeError):
-            ResetSignal(1)
-
-    def test_repr(self):
-        s1 = ResetSignal()
-        self.assertEqual(repr(s1), "(reset sync)")
index b5696013303123930d03d4afa13ba274c887e526..aad45258495601e3bb5665af27e4e56fd10270b6 100644 (file)
@@ -1,25 +1,72 @@
 import re
 import unittest
 
-from nmigen.fhdl.ast import *
-from nmigen.fhdl.ir import *
-from nmigen.fhdl.xfrm import *
+from ..fhdl.ast import *
+from ..fhdl.ir import *
+from ..fhdl.xfrm import *
+from .tools import *
 
 
-class ResetInserterTestCase(unittest.TestCase):
+class DomainRenamerTestCase(FHDLTestCase):
+    def setUp(self):
+        self.s1 = Signal()
+        self.s2 = Signal()
+        self.s3 = Signal()
+        self.s4 = Signal()
+        self.s5 = Signal()
+        self.c1 = Signal()
+
+    def test_rename_signals(self):
+        f = Fragment()
+        f.add_statements(
+            self.s1.eq(ClockSignal()),
+            ResetSignal().eq(self.s2),
+            self.s3.eq(0),
+            self.s4.eq(ClockSignal("other")),
+            self.s5.eq(ResetSignal("other")),
+        )
+        f.drive(self.s1, None)
+        f.drive(self.s2, None)
+        f.drive(self.s3, "sync")
+
+        f = DomainRenamer("pix")(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s1) (clk pix))
+            (eq (rst pix) (sig s2))
+            (eq (sig s3) (const 0'd0))
+            (eq (sig s4) (clk other))
+            (eq (sig s5) (rst other))
+        )
+        """)
+        self.assertEqual(f.drivers, {
+            None: ValueSet((self.s1, self.s2)),
+            "pix": ValueSet((self.s3,)),
+        })
+
+    def test_rename_multi(self):
+        f = Fragment()
+        f.add_statements(
+            self.s1.eq(ClockSignal()),
+            self.s2.eq(ResetSignal("other")),
+        )
+
+        f = DomainRenamer({"sync": "pix", "other": "pix2"})(f)
+        self.assertRepr(f.statements, """
+        (
+            (eq (sig s1) (clk pix))
+            (eq (sig s2) (rst pix2))
+        )
+        """)
+
+
+class ResetInserterTestCase(FHDLTestCase):
     def setUp(self):
         self.s1 = Signal()
         self.s2 = Signal(reset=1)
         self.s3 = Signal(reset=1, reset_less=True)
         self.c1 = Signal()
 
-    def assertRepr(self, obj, repr_str):
-        obj = Statement.wrap(obj)
-        repr_str = re.sub(r"\s+",   " ",  repr_str)
-        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
-        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
-        self.assertEqual(repr(obj), repr_str.strip())
-
     def test_reset_default(self):
         f = Fragment()
         f.add_statements(
@@ -92,20 +139,13 @@ class ResetInserterTestCase(unittest.TestCase):
         """)
 
 
-class CEInserterTestCase(unittest.TestCase):
+class CEInserterTestCase(FHDLTestCase):
     def setUp(self):
         self.s1 = Signal()
         self.s2 = Signal()
         self.s3 = Signal()
         self.c1 = Signal()
 
-    def assertRepr(self, obj, repr_str):
-        obj = Statement.wrap(obj)
-        repr_str = re.sub(r"\s+",   " ",  repr_str)
-        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
-        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
-        self.assertEqual(repr(obj), repr_str.strip())
-
     def test_ce_default(self):
         f = Fragment()
         f.add_statements(
diff --git a/nmigen/test/tools.py b/nmigen/test/tools.py
new file mode 100644 (file)
index 0000000..9e3a8f0
--- /dev/null
@@ -0,0 +1,16 @@
+import re
+import unittest
+
+from ..fhdl.ast import *
+
+
+__all__ = ["FHDLTestCase"]
+
+
+class FHDLTestCase(unittest.TestCase):
+    def assertRepr(self, obj, repr_str):
+        obj = Statement.wrap(obj)
+        repr_str = re.sub(r"\s+",   " ",  repr_str)
+        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
+        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
+        self.assertEqual(repr(obj), repr_str.strip())