fhdl.dsl: add tests for d.comb/d.sync, If/Elif/Else.
authorwhitequark <whitequark@whitequark.org>
Thu, 13 Dec 2018 06:06:51 +0000 (06:06 +0000)
committerwhitequark <whitequark@whitequark.org>
Thu, 13 Dec 2018 06:06:51 +0000 (06:06 +0000)
nmigen/fhdl/ast.py
nmigen/fhdl/dsl.py
nmigen/test/test_fhdl.py [deleted file]
nmigen/test/test_fhdl_dsl.py [new file with mode: 0644]
nmigen/test/test_fhdl_values.py [new file with mode: 0644]

index 2e1942b1afaba0746a0235490a89c45bbc234ac8..604bb9e29d4f3654ae0cd2a94fc36749035ef294 100644 (file)
@@ -606,14 +606,19 @@ class ResetSignal(Value):
         return "(reset {})".format(self.domain)
 
 
+class _StatementList(list):
+    def __repr__(self):
+        return "({})".format(" ".join(map(repr, self)))
+
+
 class Statement:
     @staticmethod
     def wrap(obj):
         if isinstance(obj, Iterable):
-            return sum((Statement.wrap(e) for e in obj), [])
+            return _StatementList(sum((Statement.wrap(e) for e in obj), []))
         else:
             if isinstance(obj, Statement):
-                return [obj]
+                return _StatementList([obj])
             else:
                 raise TypeError("Object {!r} is not a Migen statement".format(obj))
 
index db04892ce98b3eec580bd3b09e2b9c6cecc1d47e..0c0be8cf6786a3c64605c5f13f9b82bc4e9bc441 100644 (file)
@@ -1,11 +1,16 @@
 from collections import OrderedDict
+from contextlib import contextmanager
 
 from .ast import *
 from .ir import *
 from .xfrm import *
 
 
-__all__ = ["Module"]
+__all__ = ["Module", "SyntaxError"]
+
+
+class SyntaxError(Exception):
+    pass
 
 
 class _ModuleBuilderProxy:
@@ -36,9 +41,11 @@ class _ModuleBuilderDomains(_ModuleBuilderProxy):
         return self.__getattr__(name)
 
     def __setattr__(self, name, value):
-        if not isinstance(value, _ModuleBuilderDomain):
-            raise AttributeError("Cannot assign d.{} attribute - use += instead"
-                                 .format(name))
+        if name == "_depth":
+            object.__setattr__(self, name, value)
+        elif not isinstance(value, _ModuleBuilderDomain):
+            raise AttributeError("Cannot assign 'd.{}' attribute; did you mean 'd.{} +='?"
+                                 .format(name, name))
 
     def __setitem__(self, name, value):
         return self.__setattr__(name, value)
@@ -57,59 +64,6 @@ class _ModuleBuilderRoot:
                              .format(type(self).__name__, name))
 
 
-class _ModuleBuilderIf(_ModuleBuilderRoot):
-    def __init__(self, builder, depth, cond):
-        super().__init__(builder, depth)
-        self._cond = cond
-
-    def __enter__(self):
-        self._builder._flush()
-        self._builder._stmt_if_cond.append(self._cond)
-        self._outer_case = self._builder._statements
-        self._builder._statements = []
-        return self
-
-    def __exit__(self, *args):
-        self._builder._stmt_if_bodies.append(self._builder._statements)
-        self._builder._statements = self._outer_case
-
-
-class _ModuleBuilderElif(_ModuleBuilderRoot):
-    def __init__(self, builder, depth, cond):
-        super().__init__(builder, depth)
-        self._cond = cond
-
-    def __enter__(self):
-        if not self._builder._stmt_if_cond:
-            raise ValueError("Elif without preceding If")
-        self._builder._stmt_if_cond.append(self._cond)
-        self._outer_case = self._builder._statements
-        self._builder._statements = []
-        return self
-
-    def __exit__(self, *args):
-        self._builder._stmt_if_bodies.append(self._builder._statements)
-        self._builder._statements = self._outer_case
-
-
-class _ModuleBuilderElse(_ModuleBuilderRoot):
-    def __init__(self, builder, depth):
-        super().__init__(builder, depth)
-
-    def __enter__(self):
-        if not self._builder._stmt_if_cond:
-            raise ValueError("Else without preceding If/Elif")
-        self._builder._stmt_if_cond.append(1)
-        self._outer_case = self._builder._statements
-        self._builder._statements = []
-        return self
-
-    def __exit__(self, *args):
-        self._builder._stmt_if_bodies.append(self._builder._statements)
-        self._builder._statements = self._outer_case
-        self._builder._flush()
-
-
 class _ModuleBuilderCase(_ModuleBuilderRoot):
     def __init__(self, builder, depth, test, value):
         super().__init__(builder, depth)
@@ -120,8 +74,8 @@ class _ModuleBuilderCase(_ModuleBuilderRoot):
         if self._value is None:
             self._value = "-" * len(self._test)
         if isinstance(self._value, str) and len(self._test) != len(self._value):
-            raise ValueError("Case value {} must have the same width as test {}"
-                             .format(self._value, self._test))
+            raise SyntaxError("Case value {} must have the same width as test {}"
+                              .format(self._value, self._test))
         if self._builder._stmt_switch_test != ValueKey(self._test):
             self._builder._flush()
             self._builder._stmt_switch_test = ValueKey(self._test)
@@ -154,21 +108,56 @@ class Module(_ModuleBuilderRoot):
 
         self._submodules        = []
         self._driving           = ValueDict()
-        self._statements        = []
+        self._statements        = Statement.wrap([])
         self._stmt_depth        = 0
         self._stmt_if_cond      = []
         self._stmt_if_bodies    = []
         self._stmt_switch_test  = None
         self._stmt_switch_cases = OrderedDict()
 
+    @contextmanager
     def If(self, cond):
-        return _ModuleBuilderIf(self, self._stmt_depth + 1, cond)
-
+        self._flush()
+        try:
+            _outer_case = self._statements
+            self._statements = []
+            self.domain._depth += 1
+            yield
+            self._stmt_if_cond.append(cond)
+            self._stmt_if_bodies.append(self._statements)
+        finally:
+            self.domain._depth -= 1
+            self._statements = _outer_case
+
+    @contextmanager
     def Elif(self, cond):
-        return _ModuleBuilderElif(self, self._stmt_depth + 1, cond)
-
+        if not self._stmt_if_cond:
+            raise SyntaxError("Elif without preceding If")
+        try:
+            _outer_case = self._statements
+            self._statements = []
+            self.domain._depth += 1
+            yield
+            self._stmt_if_cond.append(cond)
+            self._stmt_if_bodies.append(self._statements)
+        finally:
+            self.domain._depth -= 1
+            self._statements = _outer_case
+
+    @contextmanager
     def Else(self):
-        return _ModuleBuilderElse(self, self._stmt_depth + 1)
+        if not self._stmt_if_cond:
+            raise SyntaxError("Else without preceding If/Elif")
+        try:
+            _outer_case = self._statements
+            self._statements = []
+            self.domain._depth += 1
+            yield
+            self._stmt_if_bodies.append(self._statements)
+        finally:
+            self.domain._depth -= 1
+            self._statements = _outer_case
+        self._flush()
 
     def Case(self, test, value=None):
         return _ModuleBuilderCase(self, self._stmt_depth + 1, test, value)
@@ -176,13 +165,17 @@ class Module(_ModuleBuilderRoot):
     def _flush(self):
         if self._stmt_if_cond:
             tests, cases = [], OrderedDict()
-            for if_cond, if_case in zip(self._stmt_if_cond, self._stmt_if_bodies):
-                if_cond = Value.wrap(if_cond)
-                if len(if_cond) != 1:
-                    if_cond = if_cond.bool()
-                tests.append(if_cond)
-
-                match = ("1" + "-" * (len(tests) - 1)).rjust(len(self._stmt_if_cond), "-")
+            for if_cond, if_case in zip(self._stmt_if_cond + [None], self._stmt_if_bodies):
+                if if_cond is not None:
+                    if_cond = Value.wrap(if_cond)
+                    if len(if_cond) != 1:
+                        if_cond = if_cond.bool()
+                    tests.append(if_cond)
+
+                if if_cond is not None:
+                    match = ("1" + "-" * (len(tests) - 1)).rjust(len(self._stmt_if_cond), "-")
+                else:
+                    match = "-" * len(tests)
                 cases[match] = if_case
             self._statements.append(Switch(Cat(tests), cases))
 
@@ -207,24 +200,25 @@ class Module(_ModuleBuilderRoot):
 
         for assign in Statement.wrap(assigns):
             if not compat_mode and not isinstance(assign, Assign):
-                raise TypeError("Only assignments can be appended to {}"
-                                .format(cd_human_name(cd_name)))
+                raise SyntaxError(
+                    "Only assignments may be appended to d.{}"
+                    .format(cd_human_name(cd_name)))
 
             for signal in assign._lhs_signals():
                 if signal not in self._driving:
                     self._driving[signal] = cd_name
                 elif self._driving[signal] != cd_name:
                     cd_curr = self._driving[signal]
-                    raise ValueError("Driver-driver conflict: trying to drive {!r} from d.{}, but "
-                                     "it is already driven from d.{}"
-                                     .format(signal, cd_human_name(cd_name),
-                                             cd_human_name(cd_curr)))
+                    raise SyntaxError(
+                        "Driver-driver conflict: trying to drive {!r} from d.{}, but it is "
+                        "already driven from d.{}"
+                        .format(signal, cd_human_name(cd_name), cd_human_name(cd_curr)))
 
             self._statements.append(assign)
 
     def _add_submodule(self, submodule, name=None):
         if not hasattr(submodule, "get_fragment"):
-            raise TypeError("Trying to add {!r}, which does not have .get_fragment(), as "
+            raise TypeError("Trying to add {!r}, which does not implement .get_fragment(), as "
                             "a submodule".format(submodule))
         self._submodules.append((submodule, name))
 
@@ -236,8 +230,7 @@ class Module(_ModuleBuilderRoot):
             fragment.add_subfragment(submodule.get_fragment(platform), name)
         fragment.add_statements(self._statements)
         for signal, cd_name in self._driving.items():
-            for lhs_signal in signal._lhs_signals():
-                fragment.drive(lhs_signal, cd_name)
+            fragment.drive(signal, cd_name)
         return fragment
 
     get_fragment = lower
diff --git a/nmigen/test/test_fhdl.py b/nmigen/test/test_fhdl.py
deleted file mode 100644 (file)
index 16e0970..0000000
+++ /dev/null
@@ -1,358 +0,0 @@
-import unittest
-
-from nmigen.fhdl.ast import *
-
-
-class ValueTestCase(unittest.TestCase):
-    def test_wrap(self):
-        self.assertIsInstance(Value.wrap(0), Const)
-        self.assertIsInstance(Value.wrap(True), Const)
-        c = Const(0)
-        self.assertIs(Value.wrap(c), c)
-        with self.assertRaises(TypeError):
-            Value.wrap("str")
-
-    def test_bool(self):
-        with self.assertRaises(TypeError):
-            if Const(0):
-                pass
-
-    def test_len(self):
-        self.assertEqual(len(Const(10)), 4)
-
-    def test_getitem_int(self):
-        s1 = Const(10)[0]
-        self.assertIsInstance(s1, Slice)
-        self.assertEqual(s1.start, 0)
-        self.assertEqual(s1.end, 1)
-        s2 = Const(10)[-1]
-        self.assertIsInstance(s2, Slice)
-        self.assertEqual(s2.start, 3)
-        self.assertEqual(s2.end, 4)
-        with self.assertRaises(IndexError):
-            Const(10)[5]
-
-    def test_getitem_slice(self):
-        s1 = Const(10)[1:3]
-        self.assertIsInstance(s1, Slice)
-        self.assertEqual(s1.start, 1)
-        self.assertEqual(s1.end, 3)
-        s2 = Const(10)[1:-2]
-        self.assertIsInstance(s2, Slice)
-        self.assertEqual(s2.start, 1)
-        self.assertEqual(s2.end, 2)
-        s3 = Const(31)[::2]
-        self.assertIsInstance(s3, Cat)
-        self.assertIsInstance(s3.operands[0], Slice)
-        self.assertEqual(s3.operands[0].start, 0)
-        self.assertEqual(s3.operands[0].end, 1)
-        self.assertIsInstance(s3.operands[1], Slice)
-        self.assertEqual(s3.operands[1].start, 2)
-        self.assertEqual(s3.operands[1].end, 3)
-        self.assertIsInstance(s3.operands[2], Slice)
-        self.assertEqual(s3.operands[2].start, 4)
-        self.assertEqual(s3.operands[2].end, 5)
-
-    def test_getitem_wrong(self):
-        with self.assertRaises(TypeError):
-            Const(31)["str"]
-
-
-class ConstTestCase(unittest.TestCase):
-    def test_shape(self):
-        self.assertEqual(Const(0).shape(),   (0, False))
-        self.assertEqual(Const(1).shape(),   (1, False))
-        self.assertEqual(Const(10).shape(),  (4, False))
-        self.assertEqual(Const(-10).shape(), (4, True))
-
-        self.assertEqual(Const(1, 4).shape(),         (4, False))
-        self.assertEqual(Const(1, (4, True)).shape(), (4, True))
-
-        with self.assertRaises(TypeError):
-            Const(1, -1)
-
-    def test_value(self):
-        self.assertEqual(Const(10).value, 10)
-
-    def test_repr(self):
-        self.assertEqual(repr(Const(10)),  "(const 4'd10)")
-        self.assertEqual(repr(Const(-10)), "(const 4'sd-10)")
-
-    def test_hash(self):
-        with self.assertRaises(TypeError):
-            hash(Const(0))
-
-
-class OperatorTestCase(unittest.TestCase):
-    def test_invert(self):
-        v = ~Const(0, 4)
-        self.assertEqual(repr(v), "(~ (const 4'd0))")
-        self.assertEqual(v.shape(), (4, False))
-
-    def test_neg(self):
-        v1 = -Const(0, (4, False))
-        self.assertEqual(repr(v1), "(- (const 4'd0))")
-        self.assertEqual(v1.shape(), (5, True))
-        v2 = -Const(0, (4, True))
-        self.assertEqual(repr(v2), "(- (const 4'sd0))")
-        self.assertEqual(v2.shape(), (4, True))
-
-    def test_add(self):
-        v1 = Const(0, (4, False)) + Const(0, (6, False))
-        self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (7, False))
-        v2 = Const(0, (4, True)) + Const(0, (6, True))
-        self.assertEqual(v2.shape(), (7, True))
-        v3 = Const(0, (4, True)) + Const(0, (4, False))
-        self.assertEqual(v3.shape(), (6, True))
-        v4 = Const(0, (4, False)) + Const(0, (4, True))
-        self.assertEqual(v4.shape(), (6, True))
-        v5 = 10 + Const(0, 4)
-        self.assertEqual(v5.shape(), (5, False))
-
-    def test_sub(self):
-        v1 = Const(0, (4, False)) - Const(0, (6, False))
-        self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (7, False))
-        v2 = Const(0, (4, True)) - Const(0, (6, True))
-        self.assertEqual(v2.shape(), (7, True))
-        v3 = Const(0, (4, True)) - Const(0, (4, False))
-        self.assertEqual(v3.shape(), (6, True))
-        v4 = Const(0, (4, False)) - Const(0, (4, True))
-        self.assertEqual(v4.shape(), (6, True))
-        v5 = 10 - Const(0, 4)
-        self.assertEqual(v5.shape(), (5, False))
-
-    def test_mul(self):
-        v1 = Const(0, (4, False)) * Const(0, (6, False))
-        self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (10, False))
-        v2 = Const(0, (4, True)) * Const(0, (6, True))
-        self.assertEqual(v2.shape(), (9, True))
-        v3 = Const(0, (4, True)) * Const(0, (4, False))
-        self.assertEqual(v3.shape(), (8, True))
-        v5 = 10 * Const(0, 4)
-        self.assertEqual(v5.shape(), (8, False))
-
-    def test_and(self):
-        v1 = Const(0, (4, False)) & Const(0, (6, False))
-        self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Const(0, (4, True)) & Const(0, (6, True))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Const(0, (4, True)) & Const(0, (4, False))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Const(0, (4, False)) & Const(0, (4, True))
-        self.assertEqual(v4.shape(), (5, True))
-        v5 = 10 & Const(0, 4)
-        self.assertEqual(v5.shape(), (4, False))
-
-    def test_or(self):
-        v1 = Const(0, (4, False)) | Const(0, (6, False))
-        self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Const(0, (4, True)) | Const(0, (6, True))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Const(0, (4, True)) | Const(0, (4, False))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Const(0, (4, False)) | Const(0, (4, True))
-        self.assertEqual(v4.shape(), (5, True))
-        v5 = 10 | Const(0, 4)
-        self.assertEqual(v5.shape(), (4, False))
-
-    def test_xor(self):
-        v1 = Const(0, (4, False)) ^ Const(0, (6, False))
-        self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Const(0, (4, True)) ^ Const(0, (6, True))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Const(0, (4, True)) ^ Const(0, (4, False))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Const(0, (4, False)) ^ Const(0, (4, True))
-        self.assertEqual(v4.shape(), (5, True))
-        v5 = 10 ^ Const(0, 4)
-        self.assertEqual(v5.shape(), (4, False))
-
-    def test_lt(self):
-        v = Const(0, 4) < Const(0, 6)
-        self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_le(self):
-        v = Const(0, 4) <= Const(0, 6)
-        self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_gt(self):
-        v = Const(0, 4) > Const(0, 6)
-        self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_ge(self):
-        v = Const(0, 4) >= Const(0, 6)
-        self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_eq(self):
-        v = Const(0, 4) == Const(0, 6)
-        self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_ne(self):
-        v = Const(0, 4) != Const(0, 6)
-        self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_mux(self):
-        s  = Const(0)
-        v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False)))
-        self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))")
-        self.assertEqual(v1.shape(), (6, False))
-        v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True)))
-        self.assertEqual(v2.shape(), (6, True))
-        v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False)))
-        self.assertEqual(v3.shape(), (5, True))
-        v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True)))
-        self.assertEqual(v4.shape(), (5, True))
-
-    def test_bool(self):
-        v = Const(0).bool()
-        self.assertEqual(repr(v), "(b (const 0'd0))")
-        self.assertEqual(v.shape(), (1, False))
-
-    def test_hash(self):
-        with self.assertRaises(TypeError):
-            hash(Const(0) + Const(0))
-
-
-class SliceTestCase(unittest.TestCase):
-    def test_shape(self):
-        s1 = Const(10)[2]
-        self.assertEqual(s1.shape(), (1, False))
-        s2 = Const(-10)[0:2]
-        self.assertEqual(s2.shape(), (2, False))
-
-    def test_repr(self):
-        s1 = Const(10)[2]
-        self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)")
-
-
-class CatTestCase(unittest.TestCase):
-    def test_shape(self):
-        c1 = Cat(Const(10))
-        self.assertEqual(c1.shape(), (4, False))
-        c2 = Cat(Const(10), Const(1))
-        self.assertEqual(c2.shape(), (5, False))
-        c3 = Cat(Const(10), Const(1), Const(0))
-        self.assertEqual(c3.shape(), (5, False))
-
-    def test_repr(self):
-        c1 = Cat(Const(10), Const(1))
-        self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))")
-
-
-class ReplTestCase(unittest.TestCase):
-    def test_shape(self):
-        r1 = Repl(Const(10), 3)
-        self.assertEqual(r1.shape(), (12, False))
-
-    def test_count_wrong(self):
-        with self.assertRaises(TypeError):
-            Repl(Const(10), -1)
-        with self.assertRaises(TypeError):
-            Repl(Const(10), "str")
-
-    def test_repr(self):
-        r1 = Repl(Const(10), 3)
-        self.assertEqual(repr(r1), "(repl (const 4'd10) 3)")
-
-
-class SignalTestCase(unittest.TestCase):
-    def test_shape(self):
-        s1 = Signal()
-        self.assertEqual(s1.shape(), (1, False))
-        s2 = Signal(2)
-        self.assertEqual(s2.shape(), (2, False))
-        s3 = Signal((2, False))
-        self.assertEqual(s3.shape(), (2, False))
-        s4 = Signal((2, True))
-        self.assertEqual(s4.shape(), (2, True))
-        s5 = Signal(max=16)
-        self.assertEqual(s5.shape(), (4, False))
-        s6 = Signal(min=4, max=16)
-        self.assertEqual(s6.shape(), (4, False))
-        s7 = Signal(min=-4, max=16)
-        self.assertEqual(s7.shape(), (5, True))
-        s8 = Signal(min=-20, max=16)
-        self.assertEqual(s8.shape(), (6, True))
-
-        with self.assertRaises(ValueError):
-            Signal(min=10, max=4)
-        with self.assertRaises(ValueError):
-            Signal(2, min=10)
-        with self.assertRaises(TypeError):
-            Signal(-10)
-
-    def test_name(self):
-        s1 = Signal()
-        self.assertEqual(s1.name, "s1")
-        s2 = Signal(name="sig")
-        self.assertEqual(s2.name, "sig")
-
-    def test_reset(self):
-        s1 = Signal(4, reset=0b111, reset_less=True)
-        self.assertEqual(s1.reset, 0b111)
-        self.assertEqual(s1.reset_less, True)
-
-    def test_attrs(self):
-        s1 = Signal()
-        self.assertEqual(s1.attrs, {})
-        s2 = Signal(attrs={"no_retiming": True})
-        self.assertEqual(s2.attrs, {"no_retiming": True})
-
-    def test_repr(self):
-        s1 = Signal()
-        self.assertEqual(repr(s1), "(sig s1)")
-
-    def test_like(self):
-        s1 = Signal.like(Signal(4))
-        self.assertEqual(s1.shape(), (4, False))
-        s2 = Signal.like(Signal(min=-15))
-        self.assertEqual(s2.shape(), (5, True))
-        s3 = Signal.like(Signal(4, reset=0b111, reset_less=True))
-        self.assertEqual(s3.reset, 0b111)
-        self.assertEqual(s3.reset_less, True)
-        s4 = Signal.like(Signal(attrs={"no_retiming": True}))
-        self.assertEqual(s4.attrs, {"no_retiming": True})
-        s5 = Signal.like(10)
-        self.assertEqual(s5.shape(), (4, False))
-
-
-class ClockSignalTestCase(unittest.TestCase):
-    def test_domain(self):
-        s1 = ClockSignal()
-        self.assertEqual(s1.domain, "sync")
-        s2 = ClockSignal("pix")
-        self.assertEqual(s2.domain, "pix")
-
-        with self.assertRaises(TypeError):
-            ClockSignal(1)
-
-    def test_repr(self):
-        s1 = ClockSignal()
-        self.assertEqual(repr(s1), "(clk sync)")
-
-
-class ResetSignalTestCase(unittest.TestCase):
-    def test_domain(self):
-        s1 = ResetSignal()
-        self.assertEqual(s1.domain, "sync")
-        s2 = ResetSignal("pix")
-        self.assertEqual(s2.domain, "pix")
-
-        with self.assertRaises(TypeError):
-            ResetSignal(1)
-
-    def test_repr(self):
-        s1 = ResetSignal()
-        self.assertEqual(repr(s1), "(reset sync)")
diff --git a/nmigen/test/test_fhdl_dsl.py b/nmigen/test/test_fhdl_dsl.py
new file mode 100644 (file)
index 0000000..55e494c
--- /dev/null
@@ -0,0 +1,200 @@
+import re
+import unittest
+from contextlib import contextmanager
+
+from nmigen.fhdl.ast import *
+from nmigen.fhdl.dsl import *
+
+
+class DSLTestCase(unittest.TestCase):
+    def setUp(self):
+        self.s1 = Signal()
+        self.s2 = Signal()
+        self.s3 = Signal()
+        self.s4 = Signal()
+        self.c1 = Signal()
+        self.c2 = Signal()
+        self.c3 = Signal()
+        self.w1 = Signal(4)
+
+    @contextmanager
+    def assertRaises(self, exception, msg=None):
+        with super().assertRaises(exception) as cm:
+            yield
+        if msg:
+            # WTF? unittest.assertRaises is completely broken.
+            self.assertEqual(str(cm.exception), msg)
+
+    def assertRepr(self, obj, repr_str):
+        repr_str = re.sub(r"\s+",   " ",  repr_str)
+        repr_str = re.sub(r"\( (?=\()", "(", repr_str)
+        repr_str = re.sub(r"\) (?=\))", ")", repr_str)
+        self.assertEqual(repr(obj), repr_str.strip())
+
+    def test_d_comb(self):
+        m = Module()
+        m.d.comb += self.c1.eq(1)
+        m._flush()
+        self.assertEqual(m._driving[self.c1], None)
+        self.assertRepr(m._statements, """(
+            (eq (sig c1) (const 1'd1))
+        )""")
+
+    def test_d_sync(self):
+        m = Module()
+        m.d.sync += self.c1.eq(1)
+        m._flush()
+        self.assertEqual(m._driving[self.c1], "sync")
+        self.assertRepr(m._statements, """(
+            (eq (sig c1) (const 1'd1))
+        )""")
+
+    def test_d_pix(self):
+        m = Module()
+        m.d.pix += self.c1.eq(1)
+        m._flush()
+        self.assertEqual(m._driving[self.c1], "pix")
+        self.assertRepr(m._statements, """(
+            (eq (sig c1) (const 1'd1))
+        )""")
+
+    def test_d_index(self):
+        m = Module()
+        m.d["pix"] += self.c1.eq(1)
+        m._flush()
+        self.assertEqual(m._driving[self.c1], "pix")
+        self.assertRepr(m._statements, """(
+            (eq (sig c1) (const 1'd1))
+        )""")
+
+    def test_d_no_conflict(self):
+        m = Module()
+        m.d.comb += self.w1[0].eq(1)
+        m.d.comb += self.w1[1].eq(1)
+
+    def test_d_conflict(self):
+        m = Module()
+        with self.assertRaises(SyntaxError,
+                msg="Driver-driver conflict: trying to drive (sig c1) from d.sync, but it "
+                    "is already driven from d.comb"):
+            m.d.comb += self.c1.eq(1)
+            m.d.sync += self.c1.eq(1)
+
+    def test_d_wrong(self):
+        m = Module()
+        with self.assertRaises(AttributeError,
+                msg="Cannot assign 'd.pix' attribute; did you mean 'd.pix +='?"):
+            m.d.pix = None
+
+    def test_d_asgn_wrong(self):
+        m = Module()
+        with self.assertRaises(SyntaxError,
+                msg="Only assignments may be appended to d.sync"):
+            m.d.sync += Switch(self.s1, {})
+
+    def test_comb_wrong(self):
+        m = Module()
+        with self.assertRaises(AttributeError,
+                msg="'Module' object has no attribute 'comb'; did you mean 'd.comb'?"):
+            m.comb += self.c1.eq(1)
+
+    def test_sync_wrong(self):
+        m = Module()
+        with self.assertRaises(AttributeError,
+                msg="'Module' object has no attribute 'sync'; did you mean 'd.sync'?"):
+            m.sync += self.c1.eq(1)
+
+    def test_attr_wrong(self):
+        m = Module()
+        with self.assertRaises(AttributeError,
+                msg="'Module' object has no attribute 'nonexistentattr'"):
+            m.nonexistentattr
+
+    def test_If(self):
+        m = Module()
+        with m.If(self.s1):
+            m.d.comb += self.c1.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (sig s1))
+                (case 1 (eq (sig c1) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_If_Elif(self):
+        m = Module()
+        with m.If(self.s1):
+            m.d.comb += self.c1.eq(1)
+        with m.Elif(self.s2):
+            m.d.sync += self.c2.eq(0)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (sig s1) (sig s2))
+                (case -1 (eq (sig c1) (const 1'd1)))
+                (case 1- (eq (sig c2) (const 0'd0)))
+            )
+        )
+        """)
+
+    def test_If_Elif_Else(self):
+        m = Module()
+        with m.If(self.s1):
+            m.d.comb += self.c1.eq(1)
+        with m.Elif(self.s2):
+            m.d.sync += self.c2.eq(0)
+        with m.Else():
+            m.d.comb += self.c3.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (sig s1) (sig s2))
+                (case -1 (eq (sig c1) (const 1'd1)))
+                (case 1- (eq (sig c2) (const 0'd0)))
+                (case -- (eq (sig c3) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_Elif_wrong(self):
+        m = Module()
+        with self.assertRaises(SyntaxError,
+                msg="Elif without preceding If"):
+            with m.Elif(self.s2):
+                pass
+
+    def test_Else_wrong(self):
+        m = Module()
+        with self.assertRaises(SyntaxError,
+                msg="Else without preceding If/Elif"):
+            with m.Else():
+                pass
+
+    def test_If_wide(self):
+        m = Module()
+        with m.If(self.w1):
+            m.d.comb += self.c1.eq(1)
+        m._flush()
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (b (sig w1)))
+                (case 1 (eq (sig c1) (const 1'd1)))
+            )
+        )
+        """)
+
+    def test_auto_flush(self):
+        m = Module()
+        with m.If(self.w1):
+            m.d.comb += self.c1.eq(1)
+        m.d.comb += self.c2.eq(1)
+        self.assertRepr(m._statements, """
+        (
+            (switch (cat (b (sig w1)))
+                (case 1 (eq (sig c1) (const 1'd1)))
+            )
+            (eq (sig c2) (const 1'd1))
+        )
+        """)
diff --git a/nmigen/test/test_fhdl_values.py b/nmigen/test/test_fhdl_values.py
new file mode 100644 (file)
index 0000000..16e0970
--- /dev/null
@@ -0,0 +1,358 @@
+import unittest
+
+from nmigen.fhdl.ast import *
+
+
+class ValueTestCase(unittest.TestCase):
+    def test_wrap(self):
+        self.assertIsInstance(Value.wrap(0), Const)
+        self.assertIsInstance(Value.wrap(True), Const)
+        c = Const(0)
+        self.assertIs(Value.wrap(c), c)
+        with self.assertRaises(TypeError):
+            Value.wrap("str")
+
+    def test_bool(self):
+        with self.assertRaises(TypeError):
+            if Const(0):
+                pass
+
+    def test_len(self):
+        self.assertEqual(len(Const(10)), 4)
+
+    def test_getitem_int(self):
+        s1 = Const(10)[0]
+        self.assertIsInstance(s1, Slice)
+        self.assertEqual(s1.start, 0)
+        self.assertEqual(s1.end, 1)
+        s2 = Const(10)[-1]
+        self.assertIsInstance(s2, Slice)
+        self.assertEqual(s2.start, 3)
+        self.assertEqual(s2.end, 4)
+        with self.assertRaises(IndexError):
+            Const(10)[5]
+
+    def test_getitem_slice(self):
+        s1 = Const(10)[1:3]
+        self.assertIsInstance(s1, Slice)
+        self.assertEqual(s1.start, 1)
+        self.assertEqual(s1.end, 3)
+        s2 = Const(10)[1:-2]
+        self.assertIsInstance(s2, Slice)
+        self.assertEqual(s2.start, 1)
+        self.assertEqual(s2.end, 2)
+        s3 = Const(31)[::2]
+        self.assertIsInstance(s3, Cat)
+        self.assertIsInstance(s3.operands[0], Slice)
+        self.assertEqual(s3.operands[0].start, 0)
+        self.assertEqual(s3.operands[0].end, 1)
+        self.assertIsInstance(s3.operands[1], Slice)
+        self.assertEqual(s3.operands[1].start, 2)
+        self.assertEqual(s3.operands[1].end, 3)
+        self.assertIsInstance(s3.operands[2], Slice)
+        self.assertEqual(s3.operands[2].start, 4)
+        self.assertEqual(s3.operands[2].end, 5)
+
+    def test_getitem_wrong(self):
+        with self.assertRaises(TypeError):
+            Const(31)["str"]
+
+
+class ConstTestCase(unittest.TestCase):
+    def test_shape(self):
+        self.assertEqual(Const(0).shape(),   (0, False))
+        self.assertEqual(Const(1).shape(),   (1, False))
+        self.assertEqual(Const(10).shape(),  (4, False))
+        self.assertEqual(Const(-10).shape(), (4, True))
+
+        self.assertEqual(Const(1, 4).shape(),         (4, False))
+        self.assertEqual(Const(1, (4, True)).shape(), (4, True))
+
+        with self.assertRaises(TypeError):
+            Const(1, -1)
+
+    def test_value(self):
+        self.assertEqual(Const(10).value, 10)
+
+    def test_repr(self):
+        self.assertEqual(repr(Const(10)),  "(const 4'd10)")
+        self.assertEqual(repr(Const(-10)), "(const 4'sd-10)")
+
+    def test_hash(self):
+        with self.assertRaises(TypeError):
+            hash(Const(0))
+
+
+class OperatorTestCase(unittest.TestCase):
+    def test_invert(self):
+        v = ~Const(0, 4)
+        self.assertEqual(repr(v), "(~ (const 4'd0))")
+        self.assertEqual(v.shape(), (4, False))
+
+    def test_neg(self):
+        v1 = -Const(0, (4, False))
+        self.assertEqual(repr(v1), "(- (const 4'd0))")
+        self.assertEqual(v1.shape(), (5, True))
+        v2 = -Const(0, (4, True))
+        self.assertEqual(repr(v2), "(- (const 4'sd0))")
+        self.assertEqual(v2.shape(), (4, True))
+
+    def test_add(self):
+        v1 = Const(0, (4, False)) + Const(0, (6, False))
+        self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (7, False))
+        v2 = Const(0, (4, True)) + Const(0, (6, True))
+        self.assertEqual(v2.shape(), (7, True))
+        v3 = Const(0, (4, True)) + Const(0, (4, False))
+        self.assertEqual(v3.shape(), (6, True))
+        v4 = Const(0, (4, False)) + Const(0, (4, True))
+        self.assertEqual(v4.shape(), (6, True))
+        v5 = 10 + Const(0, 4)
+        self.assertEqual(v5.shape(), (5, False))
+
+    def test_sub(self):
+        v1 = Const(0, (4, False)) - Const(0, (6, False))
+        self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (7, False))
+        v2 = Const(0, (4, True)) - Const(0, (6, True))
+        self.assertEqual(v2.shape(), (7, True))
+        v3 = Const(0, (4, True)) - Const(0, (4, False))
+        self.assertEqual(v3.shape(), (6, True))
+        v4 = Const(0, (4, False)) - Const(0, (4, True))
+        self.assertEqual(v4.shape(), (6, True))
+        v5 = 10 - Const(0, 4)
+        self.assertEqual(v5.shape(), (5, False))
+
+    def test_mul(self):
+        v1 = Const(0, (4, False)) * Const(0, (6, False))
+        self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (10, False))
+        v2 = Const(0, (4, True)) * Const(0, (6, True))
+        self.assertEqual(v2.shape(), (9, True))
+        v3 = Const(0, (4, True)) * Const(0, (4, False))
+        self.assertEqual(v3.shape(), (8, True))
+        v5 = 10 * Const(0, 4)
+        self.assertEqual(v5.shape(), (8, False))
+
+    def test_and(self):
+        v1 = Const(0, (4, False)) & Const(0, (6, False))
+        self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Const(0, (4, True)) & Const(0, (6, True))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Const(0, (4, True)) & Const(0, (4, False))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Const(0, (4, False)) & Const(0, (4, True))
+        self.assertEqual(v4.shape(), (5, True))
+        v5 = 10 & Const(0, 4)
+        self.assertEqual(v5.shape(), (4, False))
+
+    def test_or(self):
+        v1 = Const(0, (4, False)) | Const(0, (6, False))
+        self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Const(0, (4, True)) | Const(0, (6, True))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Const(0, (4, True)) | Const(0, (4, False))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Const(0, (4, False)) | Const(0, (4, True))
+        self.assertEqual(v4.shape(), (5, True))
+        v5 = 10 | Const(0, 4)
+        self.assertEqual(v5.shape(), (4, False))
+
+    def test_xor(self):
+        v1 = Const(0, (4, False)) ^ Const(0, (6, False))
+        self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Const(0, (4, True)) ^ Const(0, (6, True))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Const(0, (4, True)) ^ Const(0, (4, False))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Const(0, (4, False)) ^ Const(0, (4, True))
+        self.assertEqual(v4.shape(), (5, True))
+        v5 = 10 ^ Const(0, 4)
+        self.assertEqual(v5.shape(), (4, False))
+
+    def test_lt(self):
+        v = Const(0, 4) < Const(0, 6)
+        self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_le(self):
+        v = Const(0, 4) <= Const(0, 6)
+        self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_gt(self):
+        v = Const(0, 4) > Const(0, 6)
+        self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_ge(self):
+        v = Const(0, 4) >= Const(0, 6)
+        self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_eq(self):
+        v = Const(0, 4) == Const(0, 6)
+        self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_ne(self):
+        v = Const(0, 4) != Const(0, 6)
+        self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_mux(self):
+        s  = Const(0)
+        v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False)))
+        self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))")
+        self.assertEqual(v1.shape(), (6, False))
+        v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True)))
+        self.assertEqual(v2.shape(), (6, True))
+        v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False)))
+        self.assertEqual(v3.shape(), (5, True))
+        v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True)))
+        self.assertEqual(v4.shape(), (5, True))
+
+    def test_bool(self):
+        v = Const(0).bool()
+        self.assertEqual(repr(v), "(b (const 0'd0))")
+        self.assertEqual(v.shape(), (1, False))
+
+    def test_hash(self):
+        with self.assertRaises(TypeError):
+            hash(Const(0) + Const(0))
+
+
+class SliceTestCase(unittest.TestCase):
+    def test_shape(self):
+        s1 = Const(10)[2]
+        self.assertEqual(s1.shape(), (1, False))
+        s2 = Const(-10)[0:2]
+        self.assertEqual(s2.shape(), (2, False))
+
+    def test_repr(self):
+        s1 = Const(10)[2]
+        self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)")
+
+
+class CatTestCase(unittest.TestCase):
+    def test_shape(self):
+        c1 = Cat(Const(10))
+        self.assertEqual(c1.shape(), (4, False))
+        c2 = Cat(Const(10), Const(1))
+        self.assertEqual(c2.shape(), (5, False))
+        c3 = Cat(Const(10), Const(1), Const(0))
+        self.assertEqual(c3.shape(), (5, False))
+
+    def test_repr(self):
+        c1 = Cat(Const(10), Const(1))
+        self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))")
+
+
+class ReplTestCase(unittest.TestCase):
+    def test_shape(self):
+        r1 = Repl(Const(10), 3)
+        self.assertEqual(r1.shape(), (12, False))
+
+    def test_count_wrong(self):
+        with self.assertRaises(TypeError):
+            Repl(Const(10), -1)
+        with self.assertRaises(TypeError):
+            Repl(Const(10), "str")
+
+    def test_repr(self):
+        r1 = Repl(Const(10), 3)
+        self.assertEqual(repr(r1), "(repl (const 4'd10) 3)")
+
+
+class SignalTestCase(unittest.TestCase):
+    def test_shape(self):
+        s1 = Signal()
+        self.assertEqual(s1.shape(), (1, False))
+        s2 = Signal(2)
+        self.assertEqual(s2.shape(), (2, False))
+        s3 = Signal((2, False))
+        self.assertEqual(s3.shape(), (2, False))
+        s4 = Signal((2, True))
+        self.assertEqual(s4.shape(), (2, True))
+        s5 = Signal(max=16)
+        self.assertEqual(s5.shape(), (4, False))
+        s6 = Signal(min=4, max=16)
+        self.assertEqual(s6.shape(), (4, False))
+        s7 = Signal(min=-4, max=16)
+        self.assertEqual(s7.shape(), (5, True))
+        s8 = Signal(min=-20, max=16)
+        self.assertEqual(s8.shape(), (6, True))
+
+        with self.assertRaises(ValueError):
+            Signal(min=10, max=4)
+        with self.assertRaises(ValueError):
+            Signal(2, min=10)
+        with self.assertRaises(TypeError):
+            Signal(-10)
+
+    def test_name(self):
+        s1 = Signal()
+        self.assertEqual(s1.name, "s1")
+        s2 = Signal(name="sig")
+        self.assertEqual(s2.name, "sig")
+
+    def test_reset(self):
+        s1 = Signal(4, reset=0b111, reset_less=True)
+        self.assertEqual(s1.reset, 0b111)
+        self.assertEqual(s1.reset_less, True)
+
+    def test_attrs(self):
+        s1 = Signal()
+        self.assertEqual(s1.attrs, {})
+        s2 = Signal(attrs={"no_retiming": True})
+        self.assertEqual(s2.attrs, {"no_retiming": True})
+
+    def test_repr(self):
+        s1 = Signal()
+        self.assertEqual(repr(s1), "(sig s1)")
+
+    def test_like(self):
+        s1 = Signal.like(Signal(4))
+        self.assertEqual(s1.shape(), (4, False))
+        s2 = Signal.like(Signal(min=-15))
+        self.assertEqual(s2.shape(), (5, True))
+        s3 = Signal.like(Signal(4, reset=0b111, reset_less=True))
+        self.assertEqual(s3.reset, 0b111)
+        self.assertEqual(s3.reset_less, True)
+        s4 = Signal.like(Signal(attrs={"no_retiming": True}))
+        self.assertEqual(s4.attrs, {"no_retiming": True})
+        s5 = Signal.like(10)
+        self.assertEqual(s5.shape(), (4, False))
+
+
+class ClockSignalTestCase(unittest.TestCase):
+    def test_domain(self):
+        s1 = ClockSignal()
+        self.assertEqual(s1.domain, "sync")
+        s2 = ClockSignal("pix")
+        self.assertEqual(s2.domain, "pix")
+
+        with self.assertRaises(TypeError):
+            ClockSignal(1)
+
+    def test_repr(self):
+        s1 = ClockSignal()
+        self.assertEqual(repr(s1), "(clk sync)")
+
+
+class ResetSignalTestCase(unittest.TestCase):
+    def test_domain(self):
+        s1 = ResetSignal()
+        self.assertEqual(s1.domain, "sync")
+        s2 = ResetSignal("pix")
+        self.assertEqual(s2.domain, "pix")
+
+        with self.assertRaises(TypeError):
+            ResetSignal(1)
+
+    def test_repr(self):
+        s1 = ResetSignal()
+        self.assertEqual(repr(s1), "(reset sync)")