From: whitequark Date: Thu, 13 Dec 2018 06:06:51 +0000 (+0000) Subject: fhdl.dsl: add tests for d.comb/d.sync, If/Elif/Else. X-Git-Tag: working~314 X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=f70ae3bac5c7c8202566efc03620e371c82f3016;p=nmigen.git fhdl.dsl: add tests for d.comb/d.sync, If/Elif/Else. --- diff --git a/nmigen/fhdl/ast.py b/nmigen/fhdl/ast.py index 2e1942b..604bb9e 100644 --- a/nmigen/fhdl/ast.py +++ b/nmigen/fhdl/ast.py @@ -606,14 +606,19 @@ class ResetSignal(Value): return "(reset {})".format(self.domain) +class _StatementList(list): + def __repr__(self): + return "({})".format(" ".join(map(repr, self))) + + class Statement: @staticmethod def wrap(obj): if isinstance(obj, Iterable): - return sum((Statement.wrap(e) for e in obj), []) + return _StatementList(sum((Statement.wrap(e) for e in obj), [])) else: if isinstance(obj, Statement): - return [obj] + return _StatementList([obj]) else: raise TypeError("Object {!r} is not a Migen statement".format(obj)) diff --git a/nmigen/fhdl/dsl.py b/nmigen/fhdl/dsl.py index db04892..0c0be8c 100644 --- a/nmigen/fhdl/dsl.py +++ b/nmigen/fhdl/dsl.py @@ -1,11 +1,16 @@ from collections import OrderedDict +from contextlib import contextmanager from .ast import * from .ir import * from .xfrm import * -__all__ = ["Module"] +__all__ = ["Module", "SyntaxError"] + + +class SyntaxError(Exception): + pass class _ModuleBuilderProxy: @@ -36,9 +41,11 @@ class _ModuleBuilderDomains(_ModuleBuilderProxy): return self.__getattr__(name) def __setattr__(self, name, value): - if not isinstance(value, _ModuleBuilderDomain): - raise AttributeError("Cannot assign d.{} attribute - use += instead" - .format(name)) + if name == "_depth": + object.__setattr__(self, name, value) + elif not isinstance(value, _ModuleBuilderDomain): + raise AttributeError("Cannot assign 'd.{}' attribute; did you mean 'd.{} +='?" + .format(name, name)) def __setitem__(self, name, value): return self.__setattr__(name, value) @@ -57,59 +64,6 @@ class _ModuleBuilderRoot: .format(type(self).__name__, name)) -class _ModuleBuilderIf(_ModuleBuilderRoot): - def __init__(self, builder, depth, cond): - super().__init__(builder, depth) - self._cond = cond - - def __enter__(self): - self._builder._flush() - self._builder._stmt_if_cond.append(self._cond) - self._outer_case = self._builder._statements - self._builder._statements = [] - return self - - def __exit__(self, *args): - self._builder._stmt_if_bodies.append(self._builder._statements) - self._builder._statements = self._outer_case - - -class _ModuleBuilderElif(_ModuleBuilderRoot): - def __init__(self, builder, depth, cond): - super().__init__(builder, depth) - self._cond = cond - - def __enter__(self): - if not self._builder._stmt_if_cond: - raise ValueError("Elif without preceding If") - self._builder._stmt_if_cond.append(self._cond) - self._outer_case = self._builder._statements - self._builder._statements = [] - return self - - def __exit__(self, *args): - self._builder._stmt_if_bodies.append(self._builder._statements) - self._builder._statements = self._outer_case - - -class _ModuleBuilderElse(_ModuleBuilderRoot): - def __init__(self, builder, depth): - super().__init__(builder, depth) - - def __enter__(self): - if not self._builder._stmt_if_cond: - raise ValueError("Else without preceding If/Elif") - self._builder._stmt_if_cond.append(1) - self._outer_case = self._builder._statements - self._builder._statements = [] - return self - - def __exit__(self, *args): - self._builder._stmt_if_bodies.append(self._builder._statements) - self._builder._statements = self._outer_case - self._builder._flush() - - class _ModuleBuilderCase(_ModuleBuilderRoot): def __init__(self, builder, depth, test, value): super().__init__(builder, depth) @@ -120,8 +74,8 @@ class _ModuleBuilderCase(_ModuleBuilderRoot): if self._value is None: self._value = "-" * len(self._test) if isinstance(self._value, str) and len(self._test) != len(self._value): - raise ValueError("Case value {} must have the same width as test {}" - .format(self._value, self._test)) + raise SyntaxError("Case value {} must have the same width as test {}" + .format(self._value, self._test)) if self._builder._stmt_switch_test != ValueKey(self._test): self._builder._flush() self._builder._stmt_switch_test = ValueKey(self._test) @@ -154,21 +108,56 @@ class Module(_ModuleBuilderRoot): self._submodules = [] self._driving = ValueDict() - self._statements = [] + self._statements = Statement.wrap([]) self._stmt_depth = 0 self._stmt_if_cond = [] self._stmt_if_bodies = [] self._stmt_switch_test = None self._stmt_switch_cases = OrderedDict() + @contextmanager def If(self, cond): - return _ModuleBuilderIf(self, self._stmt_depth + 1, cond) - + self._flush() + try: + _outer_case = self._statements + self._statements = [] + self.domain._depth += 1 + yield + self._stmt_if_cond.append(cond) + self._stmt_if_bodies.append(self._statements) + finally: + self.domain._depth -= 1 + self._statements = _outer_case + + @contextmanager def Elif(self, cond): - return _ModuleBuilderElif(self, self._stmt_depth + 1, cond) - + if not self._stmt_if_cond: + raise SyntaxError("Elif without preceding If") + try: + _outer_case = self._statements + self._statements = [] + self.domain._depth += 1 + yield + self._stmt_if_cond.append(cond) + self._stmt_if_bodies.append(self._statements) + finally: + self.domain._depth -= 1 + self._statements = _outer_case + + @contextmanager def Else(self): - return _ModuleBuilderElse(self, self._stmt_depth + 1) + if not self._stmt_if_cond: + raise SyntaxError("Else without preceding If/Elif") + try: + _outer_case = self._statements + self._statements = [] + self.domain._depth += 1 + yield + self._stmt_if_bodies.append(self._statements) + finally: + self.domain._depth -= 1 + self._statements = _outer_case + self._flush() def Case(self, test, value=None): return _ModuleBuilderCase(self, self._stmt_depth + 1, test, value) @@ -176,13 +165,17 @@ class Module(_ModuleBuilderRoot): def _flush(self): if self._stmt_if_cond: tests, cases = [], OrderedDict() - for if_cond, if_case in zip(self._stmt_if_cond, self._stmt_if_bodies): - if_cond = Value.wrap(if_cond) - if len(if_cond) != 1: - if_cond = if_cond.bool() - tests.append(if_cond) - - match = ("1" + "-" * (len(tests) - 1)).rjust(len(self._stmt_if_cond), "-") + for if_cond, if_case in zip(self._stmt_if_cond + [None], self._stmt_if_bodies): + if if_cond is not None: + if_cond = Value.wrap(if_cond) + if len(if_cond) != 1: + if_cond = if_cond.bool() + tests.append(if_cond) + + if if_cond is not None: + match = ("1" + "-" * (len(tests) - 1)).rjust(len(self._stmt_if_cond), "-") + else: + match = "-" * len(tests) cases[match] = if_case self._statements.append(Switch(Cat(tests), cases)) @@ -207,24 +200,25 @@ class Module(_ModuleBuilderRoot): for assign in Statement.wrap(assigns): if not compat_mode and not isinstance(assign, Assign): - raise TypeError("Only assignments can be appended to {}" - .format(cd_human_name(cd_name))) + raise SyntaxError( + "Only assignments may be appended to d.{}" + .format(cd_human_name(cd_name))) for signal in assign._lhs_signals(): if signal not in self._driving: self._driving[signal] = cd_name elif self._driving[signal] != cd_name: cd_curr = self._driving[signal] - raise ValueError("Driver-driver conflict: trying to drive {!r} from d.{}, but " - "it is already driven from d.{}" - .format(signal, cd_human_name(cd_name), - cd_human_name(cd_curr))) + raise SyntaxError( + "Driver-driver conflict: trying to drive {!r} from d.{}, but it is " + "already driven from d.{}" + .format(signal, cd_human_name(cd_name), cd_human_name(cd_curr))) self._statements.append(assign) def _add_submodule(self, submodule, name=None): if not hasattr(submodule, "get_fragment"): - raise TypeError("Trying to add {!r}, which does not have .get_fragment(), as " + raise TypeError("Trying to add {!r}, which does not implement .get_fragment(), as " "a submodule".format(submodule)) self._submodules.append((submodule, name)) @@ -236,8 +230,7 @@ class Module(_ModuleBuilderRoot): fragment.add_subfragment(submodule.get_fragment(platform), name) fragment.add_statements(self._statements) for signal, cd_name in self._driving.items(): - for lhs_signal in signal._lhs_signals(): - fragment.drive(lhs_signal, cd_name) + fragment.drive(signal, cd_name) return fragment get_fragment = lower diff --git a/nmigen/test/test_fhdl.py b/nmigen/test/test_fhdl.py deleted file mode 100644 index 16e0970..0000000 --- a/nmigen/test/test_fhdl.py +++ /dev/null @@ -1,358 +0,0 @@ -import unittest - -from nmigen.fhdl.ast import * - - -class ValueTestCase(unittest.TestCase): - def test_wrap(self): - self.assertIsInstance(Value.wrap(0), Const) - self.assertIsInstance(Value.wrap(True), Const) - c = Const(0) - self.assertIs(Value.wrap(c), c) - with self.assertRaises(TypeError): - Value.wrap("str") - - def test_bool(self): - with self.assertRaises(TypeError): - if Const(0): - pass - - def test_len(self): - self.assertEqual(len(Const(10)), 4) - - def test_getitem_int(self): - s1 = Const(10)[0] - self.assertIsInstance(s1, Slice) - self.assertEqual(s1.start, 0) - self.assertEqual(s1.end, 1) - s2 = Const(10)[-1] - self.assertIsInstance(s2, Slice) - self.assertEqual(s2.start, 3) - self.assertEqual(s2.end, 4) - with self.assertRaises(IndexError): - Const(10)[5] - - def test_getitem_slice(self): - s1 = Const(10)[1:3] - self.assertIsInstance(s1, Slice) - self.assertEqual(s1.start, 1) - self.assertEqual(s1.end, 3) - s2 = Const(10)[1:-2] - self.assertIsInstance(s2, Slice) - self.assertEqual(s2.start, 1) - self.assertEqual(s2.end, 2) - s3 = Const(31)[::2] - self.assertIsInstance(s3, Cat) - self.assertIsInstance(s3.operands[0], Slice) - self.assertEqual(s3.operands[0].start, 0) - self.assertEqual(s3.operands[0].end, 1) - self.assertIsInstance(s3.operands[1], Slice) - self.assertEqual(s3.operands[1].start, 2) - self.assertEqual(s3.operands[1].end, 3) - self.assertIsInstance(s3.operands[2], Slice) - self.assertEqual(s3.operands[2].start, 4) - self.assertEqual(s3.operands[2].end, 5) - - def test_getitem_wrong(self): - with self.assertRaises(TypeError): - Const(31)["str"] - - -class ConstTestCase(unittest.TestCase): - def test_shape(self): - self.assertEqual(Const(0).shape(), (0, False)) - self.assertEqual(Const(1).shape(), (1, False)) - self.assertEqual(Const(10).shape(), (4, False)) - self.assertEqual(Const(-10).shape(), (4, True)) - - self.assertEqual(Const(1, 4).shape(), (4, False)) - self.assertEqual(Const(1, (4, True)).shape(), (4, True)) - - with self.assertRaises(TypeError): - Const(1, -1) - - def test_value(self): - self.assertEqual(Const(10).value, 10) - - def test_repr(self): - self.assertEqual(repr(Const(10)), "(const 4'd10)") - self.assertEqual(repr(Const(-10)), "(const 4'sd-10)") - - def test_hash(self): - with self.assertRaises(TypeError): - hash(Const(0)) - - -class OperatorTestCase(unittest.TestCase): - def test_invert(self): - v = ~Const(0, 4) - self.assertEqual(repr(v), "(~ (const 4'd0))") - self.assertEqual(v.shape(), (4, False)) - - def test_neg(self): - v1 = -Const(0, (4, False)) - self.assertEqual(repr(v1), "(- (const 4'd0))") - self.assertEqual(v1.shape(), (5, True)) - v2 = -Const(0, (4, True)) - self.assertEqual(repr(v2), "(- (const 4'sd0))") - self.assertEqual(v2.shape(), (4, True)) - - def test_add(self): - v1 = Const(0, (4, False)) + Const(0, (6, False)) - self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (7, False)) - v2 = Const(0, (4, True)) + Const(0, (6, True)) - self.assertEqual(v2.shape(), (7, True)) - v3 = Const(0, (4, True)) + Const(0, (4, False)) - self.assertEqual(v3.shape(), (6, True)) - v4 = Const(0, (4, False)) + Const(0, (4, True)) - self.assertEqual(v4.shape(), (6, True)) - v5 = 10 + Const(0, 4) - self.assertEqual(v5.shape(), (5, False)) - - def test_sub(self): - v1 = Const(0, (4, False)) - Const(0, (6, False)) - self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (7, False)) - v2 = Const(0, (4, True)) - Const(0, (6, True)) - self.assertEqual(v2.shape(), (7, True)) - v3 = Const(0, (4, True)) - Const(0, (4, False)) - self.assertEqual(v3.shape(), (6, True)) - v4 = Const(0, (4, False)) - Const(0, (4, True)) - self.assertEqual(v4.shape(), (6, True)) - v5 = 10 - Const(0, 4) - self.assertEqual(v5.shape(), (5, False)) - - def test_mul(self): - v1 = Const(0, (4, False)) * Const(0, (6, False)) - self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (10, False)) - v2 = Const(0, (4, True)) * Const(0, (6, True)) - self.assertEqual(v2.shape(), (9, True)) - v3 = Const(0, (4, True)) * Const(0, (4, False)) - self.assertEqual(v3.shape(), (8, True)) - v5 = 10 * Const(0, 4) - self.assertEqual(v5.shape(), (8, False)) - - def test_and(self): - v1 = Const(0, (4, False)) & Const(0, (6, False)) - self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Const(0, (4, True)) & Const(0, (6, True)) - self.assertEqual(v2.shape(), (6, True)) - v3 = Const(0, (4, True)) & Const(0, (4, False)) - self.assertEqual(v3.shape(), (5, True)) - v4 = Const(0, (4, False)) & Const(0, (4, True)) - self.assertEqual(v4.shape(), (5, True)) - v5 = 10 & Const(0, 4) - self.assertEqual(v5.shape(), (4, False)) - - def test_or(self): - v1 = Const(0, (4, False)) | Const(0, (6, False)) - self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Const(0, (4, True)) | Const(0, (6, True)) - self.assertEqual(v2.shape(), (6, True)) - v3 = Const(0, (4, True)) | Const(0, (4, False)) - self.assertEqual(v3.shape(), (5, True)) - v4 = Const(0, (4, False)) | Const(0, (4, True)) - self.assertEqual(v4.shape(), (5, True)) - v5 = 10 | Const(0, 4) - self.assertEqual(v5.shape(), (4, False)) - - def test_xor(self): - v1 = Const(0, (4, False)) ^ Const(0, (6, False)) - self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Const(0, (4, True)) ^ Const(0, (6, True)) - self.assertEqual(v2.shape(), (6, True)) - v3 = Const(0, (4, True)) ^ Const(0, (4, False)) - self.assertEqual(v3.shape(), (5, True)) - v4 = Const(0, (4, False)) ^ Const(0, (4, True)) - self.assertEqual(v4.shape(), (5, True)) - v5 = 10 ^ Const(0, 4) - self.assertEqual(v5.shape(), (4, False)) - - def test_lt(self): - v = Const(0, 4) < Const(0, 6) - self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_le(self): - v = Const(0, 4) <= Const(0, 6) - self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_gt(self): - v = Const(0, 4) > Const(0, 6) - self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_ge(self): - v = Const(0, 4) >= Const(0, 6) - self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_eq(self): - v = Const(0, 4) == Const(0, 6) - self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_ne(self): - v = Const(0, 4) != Const(0, 6) - self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_mux(self): - s = Const(0) - v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False))) - self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True))) - self.assertEqual(v2.shape(), (6, True)) - v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False))) - self.assertEqual(v3.shape(), (5, True)) - v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True))) - self.assertEqual(v4.shape(), (5, True)) - - def test_bool(self): - v = Const(0).bool() - self.assertEqual(repr(v), "(b (const 0'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_hash(self): - with self.assertRaises(TypeError): - hash(Const(0) + Const(0)) - - -class SliceTestCase(unittest.TestCase): - def test_shape(self): - s1 = Const(10)[2] - self.assertEqual(s1.shape(), (1, False)) - s2 = Const(-10)[0:2] - self.assertEqual(s2.shape(), (2, False)) - - def test_repr(self): - s1 = Const(10)[2] - self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)") - - -class CatTestCase(unittest.TestCase): - def test_shape(self): - c1 = Cat(Const(10)) - self.assertEqual(c1.shape(), (4, False)) - c2 = Cat(Const(10), Const(1)) - self.assertEqual(c2.shape(), (5, False)) - c3 = Cat(Const(10), Const(1), Const(0)) - self.assertEqual(c3.shape(), (5, False)) - - def test_repr(self): - c1 = Cat(Const(10), Const(1)) - self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))") - - -class ReplTestCase(unittest.TestCase): - def test_shape(self): - r1 = Repl(Const(10), 3) - self.assertEqual(r1.shape(), (12, False)) - - def test_count_wrong(self): - with self.assertRaises(TypeError): - Repl(Const(10), -1) - with self.assertRaises(TypeError): - Repl(Const(10), "str") - - def test_repr(self): - r1 = Repl(Const(10), 3) - self.assertEqual(repr(r1), "(repl (const 4'd10) 3)") - - -class SignalTestCase(unittest.TestCase): - def test_shape(self): - s1 = Signal() - self.assertEqual(s1.shape(), (1, False)) - s2 = Signal(2) - self.assertEqual(s2.shape(), (2, False)) - s3 = Signal((2, False)) - self.assertEqual(s3.shape(), (2, False)) - s4 = Signal((2, True)) - self.assertEqual(s4.shape(), (2, True)) - s5 = Signal(max=16) - self.assertEqual(s5.shape(), (4, False)) - s6 = Signal(min=4, max=16) - self.assertEqual(s6.shape(), (4, False)) - s7 = Signal(min=-4, max=16) - self.assertEqual(s7.shape(), (5, True)) - s8 = Signal(min=-20, max=16) - self.assertEqual(s8.shape(), (6, True)) - - with self.assertRaises(ValueError): - Signal(min=10, max=4) - with self.assertRaises(ValueError): - Signal(2, min=10) - with self.assertRaises(TypeError): - Signal(-10) - - def test_name(self): - s1 = Signal() - self.assertEqual(s1.name, "s1") - s2 = Signal(name="sig") - self.assertEqual(s2.name, "sig") - - def test_reset(self): - s1 = Signal(4, reset=0b111, reset_less=True) - self.assertEqual(s1.reset, 0b111) - self.assertEqual(s1.reset_less, True) - - def test_attrs(self): - s1 = Signal() - self.assertEqual(s1.attrs, {}) - s2 = Signal(attrs={"no_retiming": True}) - self.assertEqual(s2.attrs, {"no_retiming": True}) - - def test_repr(self): - s1 = Signal() - self.assertEqual(repr(s1), "(sig s1)") - - def test_like(self): - s1 = Signal.like(Signal(4)) - self.assertEqual(s1.shape(), (4, False)) - s2 = Signal.like(Signal(min=-15)) - self.assertEqual(s2.shape(), (5, True)) - s3 = Signal.like(Signal(4, reset=0b111, reset_less=True)) - self.assertEqual(s3.reset, 0b111) - self.assertEqual(s3.reset_less, True) - s4 = Signal.like(Signal(attrs={"no_retiming": True})) - self.assertEqual(s4.attrs, {"no_retiming": True}) - s5 = Signal.like(10) - self.assertEqual(s5.shape(), (4, False)) - - -class ClockSignalTestCase(unittest.TestCase): - def test_domain(self): - s1 = ClockSignal() - self.assertEqual(s1.domain, "sync") - s2 = ClockSignal("pix") - self.assertEqual(s2.domain, "pix") - - with self.assertRaises(TypeError): - ClockSignal(1) - - def test_repr(self): - s1 = ClockSignal() - self.assertEqual(repr(s1), "(clk sync)") - - -class ResetSignalTestCase(unittest.TestCase): - def test_domain(self): - s1 = ResetSignal() - self.assertEqual(s1.domain, "sync") - s2 = ResetSignal("pix") - self.assertEqual(s2.domain, "pix") - - with self.assertRaises(TypeError): - ResetSignal(1) - - def test_repr(self): - s1 = ResetSignal() - self.assertEqual(repr(s1), "(reset sync)") diff --git a/nmigen/test/test_fhdl_dsl.py b/nmigen/test/test_fhdl_dsl.py new file mode 100644 index 0000000..55e494c --- /dev/null +++ b/nmigen/test/test_fhdl_dsl.py @@ -0,0 +1,200 @@ +import re +import unittest +from contextlib import contextmanager + +from nmigen.fhdl.ast import * +from nmigen.fhdl.dsl import * + + +class DSLTestCase(unittest.TestCase): + def setUp(self): + self.s1 = Signal() + self.s2 = Signal() + self.s3 = Signal() + self.s4 = Signal() + self.c1 = Signal() + self.c2 = Signal() + self.c3 = Signal() + self.w1 = Signal(4) + + @contextmanager + def assertRaises(self, exception, msg=None): + with super().assertRaises(exception) as cm: + yield + if msg: + # WTF? unittest.assertRaises is completely broken. + self.assertEqual(str(cm.exception), msg) + + def assertRepr(self, obj, repr_str): + repr_str = re.sub(r"\s+", " ", repr_str) + repr_str = re.sub(r"\( (?=\()", "(", repr_str) + repr_str = re.sub(r"\) (?=\))", ")", repr_str) + self.assertEqual(repr(obj), repr_str.strip()) + + def test_d_comb(self): + m = Module() + m.d.comb += self.c1.eq(1) + m._flush() + self.assertEqual(m._driving[self.c1], None) + self.assertRepr(m._statements, """( + (eq (sig c1) (const 1'd1)) + )""") + + def test_d_sync(self): + m = Module() + m.d.sync += self.c1.eq(1) + m._flush() + self.assertEqual(m._driving[self.c1], "sync") + self.assertRepr(m._statements, """( + (eq (sig c1) (const 1'd1)) + )""") + + def test_d_pix(self): + m = Module() + m.d.pix += self.c1.eq(1) + m._flush() + self.assertEqual(m._driving[self.c1], "pix") + self.assertRepr(m._statements, """( + (eq (sig c1) (const 1'd1)) + )""") + + def test_d_index(self): + m = Module() + m.d["pix"] += self.c1.eq(1) + m._flush() + self.assertEqual(m._driving[self.c1], "pix") + self.assertRepr(m._statements, """( + (eq (sig c1) (const 1'd1)) + )""") + + def test_d_no_conflict(self): + m = Module() + m.d.comb += self.w1[0].eq(1) + m.d.comb += self.w1[1].eq(1) + + def test_d_conflict(self): + m = Module() + with self.assertRaises(SyntaxError, + msg="Driver-driver conflict: trying to drive (sig c1) from d.sync, but it " + "is already driven from d.comb"): + m.d.comb += self.c1.eq(1) + m.d.sync += self.c1.eq(1) + + def test_d_wrong(self): + m = Module() + with self.assertRaises(AttributeError, + msg="Cannot assign 'd.pix' attribute; did you mean 'd.pix +='?"): + m.d.pix = None + + def test_d_asgn_wrong(self): + m = Module() + with self.assertRaises(SyntaxError, + msg="Only assignments may be appended to d.sync"): + m.d.sync += Switch(self.s1, {}) + + def test_comb_wrong(self): + m = Module() + with self.assertRaises(AttributeError, + msg="'Module' object has no attribute 'comb'; did you mean 'd.comb'?"): + m.comb += self.c1.eq(1) + + def test_sync_wrong(self): + m = Module() + with self.assertRaises(AttributeError, + msg="'Module' object has no attribute 'sync'; did you mean 'd.sync'?"): + m.sync += self.c1.eq(1) + + def test_attr_wrong(self): + m = Module() + with self.assertRaises(AttributeError, + msg="'Module' object has no attribute 'nonexistentattr'"): + m.nonexistentattr + + def test_If(self): + m = Module() + with m.If(self.s1): + m.d.comb += self.c1.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (cat (sig s1)) + (case 1 (eq (sig c1) (const 1'd1))) + ) + ) + """) + + def test_If_Elif(self): + m = Module() + with m.If(self.s1): + m.d.comb += self.c1.eq(1) + with m.Elif(self.s2): + m.d.sync += self.c2.eq(0) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (cat (sig s1) (sig s2)) + (case -1 (eq (sig c1) (const 1'd1))) + (case 1- (eq (sig c2) (const 0'd0))) + ) + ) + """) + + def test_If_Elif_Else(self): + m = Module() + with m.If(self.s1): + m.d.comb += self.c1.eq(1) + with m.Elif(self.s2): + m.d.sync += self.c2.eq(0) + with m.Else(): + m.d.comb += self.c3.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (cat (sig s1) (sig s2)) + (case -1 (eq (sig c1) (const 1'd1))) + (case 1- (eq (sig c2) (const 0'd0))) + (case -- (eq (sig c3) (const 1'd1))) + ) + ) + """) + + def test_Elif_wrong(self): + m = Module() + with self.assertRaises(SyntaxError, + msg="Elif without preceding If"): + with m.Elif(self.s2): + pass + + def test_Else_wrong(self): + m = Module() + with self.assertRaises(SyntaxError, + msg="Else without preceding If/Elif"): + with m.Else(): + pass + + def test_If_wide(self): + m = Module() + with m.If(self.w1): + m.d.comb += self.c1.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (cat (b (sig w1))) + (case 1 (eq (sig c1) (const 1'd1))) + ) + ) + """) + + def test_auto_flush(self): + m = Module() + with m.If(self.w1): + m.d.comb += self.c1.eq(1) + m.d.comb += self.c2.eq(1) + self.assertRepr(m._statements, """ + ( + (switch (cat (b (sig w1))) + (case 1 (eq (sig c1) (const 1'd1))) + ) + (eq (sig c2) (const 1'd1)) + ) + """) diff --git a/nmigen/test/test_fhdl_values.py b/nmigen/test/test_fhdl_values.py new file mode 100644 index 0000000..16e0970 --- /dev/null +++ b/nmigen/test/test_fhdl_values.py @@ -0,0 +1,358 @@ +import unittest + +from nmigen.fhdl.ast import * + + +class ValueTestCase(unittest.TestCase): + def test_wrap(self): + self.assertIsInstance(Value.wrap(0), Const) + self.assertIsInstance(Value.wrap(True), Const) + c = Const(0) + self.assertIs(Value.wrap(c), c) + with self.assertRaises(TypeError): + Value.wrap("str") + + def test_bool(self): + with self.assertRaises(TypeError): + if Const(0): + pass + + def test_len(self): + self.assertEqual(len(Const(10)), 4) + + def test_getitem_int(self): + s1 = Const(10)[0] + self.assertIsInstance(s1, Slice) + self.assertEqual(s1.start, 0) + self.assertEqual(s1.end, 1) + s2 = Const(10)[-1] + self.assertIsInstance(s2, Slice) + self.assertEqual(s2.start, 3) + self.assertEqual(s2.end, 4) + with self.assertRaises(IndexError): + Const(10)[5] + + def test_getitem_slice(self): + s1 = Const(10)[1:3] + self.assertIsInstance(s1, Slice) + self.assertEqual(s1.start, 1) + self.assertEqual(s1.end, 3) + s2 = Const(10)[1:-2] + self.assertIsInstance(s2, Slice) + self.assertEqual(s2.start, 1) + self.assertEqual(s2.end, 2) + s3 = Const(31)[::2] + self.assertIsInstance(s3, Cat) + self.assertIsInstance(s3.operands[0], Slice) + self.assertEqual(s3.operands[0].start, 0) + self.assertEqual(s3.operands[0].end, 1) + self.assertIsInstance(s3.operands[1], Slice) + self.assertEqual(s3.operands[1].start, 2) + self.assertEqual(s3.operands[1].end, 3) + self.assertIsInstance(s3.operands[2], Slice) + self.assertEqual(s3.operands[2].start, 4) + self.assertEqual(s3.operands[2].end, 5) + + def test_getitem_wrong(self): + with self.assertRaises(TypeError): + Const(31)["str"] + + +class ConstTestCase(unittest.TestCase): + def test_shape(self): + self.assertEqual(Const(0).shape(), (0, False)) + self.assertEqual(Const(1).shape(), (1, False)) + self.assertEqual(Const(10).shape(), (4, False)) + self.assertEqual(Const(-10).shape(), (4, True)) + + self.assertEqual(Const(1, 4).shape(), (4, False)) + self.assertEqual(Const(1, (4, True)).shape(), (4, True)) + + with self.assertRaises(TypeError): + Const(1, -1) + + def test_value(self): + self.assertEqual(Const(10).value, 10) + + def test_repr(self): + self.assertEqual(repr(Const(10)), "(const 4'd10)") + self.assertEqual(repr(Const(-10)), "(const 4'sd-10)") + + def test_hash(self): + with self.assertRaises(TypeError): + hash(Const(0)) + + +class OperatorTestCase(unittest.TestCase): + def test_invert(self): + v = ~Const(0, 4) + self.assertEqual(repr(v), "(~ (const 4'd0))") + self.assertEqual(v.shape(), (4, False)) + + def test_neg(self): + v1 = -Const(0, (4, False)) + self.assertEqual(repr(v1), "(- (const 4'd0))") + self.assertEqual(v1.shape(), (5, True)) + v2 = -Const(0, (4, True)) + self.assertEqual(repr(v2), "(- (const 4'sd0))") + self.assertEqual(v2.shape(), (4, True)) + + def test_add(self): + v1 = Const(0, (4, False)) + Const(0, (6, False)) + self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (7, False)) + v2 = Const(0, (4, True)) + Const(0, (6, True)) + self.assertEqual(v2.shape(), (7, True)) + v3 = Const(0, (4, True)) + Const(0, (4, False)) + self.assertEqual(v3.shape(), (6, True)) + v4 = Const(0, (4, False)) + Const(0, (4, True)) + self.assertEqual(v4.shape(), (6, True)) + v5 = 10 + Const(0, 4) + self.assertEqual(v5.shape(), (5, False)) + + def test_sub(self): + v1 = Const(0, (4, False)) - Const(0, (6, False)) + self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (7, False)) + v2 = Const(0, (4, True)) - Const(0, (6, True)) + self.assertEqual(v2.shape(), (7, True)) + v3 = Const(0, (4, True)) - Const(0, (4, False)) + self.assertEqual(v3.shape(), (6, True)) + v4 = Const(0, (4, False)) - Const(0, (4, True)) + self.assertEqual(v4.shape(), (6, True)) + v5 = 10 - Const(0, 4) + self.assertEqual(v5.shape(), (5, False)) + + def test_mul(self): + v1 = Const(0, (4, False)) * Const(0, (6, False)) + self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (10, False)) + v2 = Const(0, (4, True)) * Const(0, (6, True)) + self.assertEqual(v2.shape(), (9, True)) + v3 = Const(0, (4, True)) * Const(0, (4, False)) + self.assertEqual(v3.shape(), (8, True)) + v5 = 10 * Const(0, 4) + self.assertEqual(v5.shape(), (8, False)) + + def test_and(self): + v1 = Const(0, (4, False)) & Const(0, (6, False)) + self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Const(0, (4, True)) & Const(0, (6, True)) + self.assertEqual(v2.shape(), (6, True)) + v3 = Const(0, (4, True)) & Const(0, (4, False)) + self.assertEqual(v3.shape(), (5, True)) + v4 = Const(0, (4, False)) & Const(0, (4, True)) + self.assertEqual(v4.shape(), (5, True)) + v5 = 10 & Const(0, 4) + self.assertEqual(v5.shape(), (4, False)) + + def test_or(self): + v1 = Const(0, (4, False)) | Const(0, (6, False)) + self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Const(0, (4, True)) | Const(0, (6, True)) + self.assertEqual(v2.shape(), (6, True)) + v3 = Const(0, (4, True)) | Const(0, (4, False)) + self.assertEqual(v3.shape(), (5, True)) + v4 = Const(0, (4, False)) | Const(0, (4, True)) + self.assertEqual(v4.shape(), (5, True)) + v5 = 10 | Const(0, 4) + self.assertEqual(v5.shape(), (4, False)) + + def test_xor(self): + v1 = Const(0, (4, False)) ^ Const(0, (6, False)) + self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Const(0, (4, True)) ^ Const(0, (6, True)) + self.assertEqual(v2.shape(), (6, True)) + v3 = Const(0, (4, True)) ^ Const(0, (4, False)) + self.assertEqual(v3.shape(), (5, True)) + v4 = Const(0, (4, False)) ^ Const(0, (4, True)) + self.assertEqual(v4.shape(), (5, True)) + v5 = 10 ^ Const(0, 4) + self.assertEqual(v5.shape(), (4, False)) + + def test_lt(self): + v = Const(0, 4) < Const(0, 6) + self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_le(self): + v = Const(0, 4) <= Const(0, 6) + self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_gt(self): + v = Const(0, 4) > Const(0, 6) + self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_ge(self): + v = Const(0, 4) >= Const(0, 6) + self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_eq(self): + v = Const(0, 4) == Const(0, 6) + self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_ne(self): + v = Const(0, 4) != Const(0, 6) + self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_mux(self): + s = Const(0) + v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False))) + self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True))) + self.assertEqual(v2.shape(), (6, True)) + v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False))) + self.assertEqual(v3.shape(), (5, True)) + v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True))) + self.assertEqual(v4.shape(), (5, True)) + + def test_bool(self): + v = Const(0).bool() + self.assertEqual(repr(v), "(b (const 0'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_hash(self): + with self.assertRaises(TypeError): + hash(Const(0) + Const(0)) + + +class SliceTestCase(unittest.TestCase): + def test_shape(self): + s1 = Const(10)[2] + self.assertEqual(s1.shape(), (1, False)) + s2 = Const(-10)[0:2] + self.assertEqual(s2.shape(), (2, False)) + + def test_repr(self): + s1 = Const(10)[2] + self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)") + + +class CatTestCase(unittest.TestCase): + def test_shape(self): + c1 = Cat(Const(10)) + self.assertEqual(c1.shape(), (4, False)) + c2 = Cat(Const(10), Const(1)) + self.assertEqual(c2.shape(), (5, False)) + c3 = Cat(Const(10), Const(1), Const(0)) + self.assertEqual(c3.shape(), (5, False)) + + def test_repr(self): + c1 = Cat(Const(10), Const(1)) + self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))") + + +class ReplTestCase(unittest.TestCase): + def test_shape(self): + r1 = Repl(Const(10), 3) + self.assertEqual(r1.shape(), (12, False)) + + def test_count_wrong(self): + with self.assertRaises(TypeError): + Repl(Const(10), -1) + with self.assertRaises(TypeError): + Repl(Const(10), "str") + + def test_repr(self): + r1 = Repl(Const(10), 3) + self.assertEqual(repr(r1), "(repl (const 4'd10) 3)") + + +class SignalTestCase(unittest.TestCase): + def test_shape(self): + s1 = Signal() + self.assertEqual(s1.shape(), (1, False)) + s2 = Signal(2) + self.assertEqual(s2.shape(), (2, False)) + s3 = Signal((2, False)) + self.assertEqual(s3.shape(), (2, False)) + s4 = Signal((2, True)) + self.assertEqual(s4.shape(), (2, True)) + s5 = Signal(max=16) + self.assertEqual(s5.shape(), (4, False)) + s6 = Signal(min=4, max=16) + self.assertEqual(s6.shape(), (4, False)) + s7 = Signal(min=-4, max=16) + self.assertEqual(s7.shape(), (5, True)) + s8 = Signal(min=-20, max=16) + self.assertEqual(s8.shape(), (6, True)) + + with self.assertRaises(ValueError): + Signal(min=10, max=4) + with self.assertRaises(ValueError): + Signal(2, min=10) + with self.assertRaises(TypeError): + Signal(-10) + + def test_name(self): + s1 = Signal() + self.assertEqual(s1.name, "s1") + s2 = Signal(name="sig") + self.assertEqual(s2.name, "sig") + + def test_reset(self): + s1 = Signal(4, reset=0b111, reset_less=True) + self.assertEqual(s1.reset, 0b111) + self.assertEqual(s1.reset_less, True) + + def test_attrs(self): + s1 = Signal() + self.assertEqual(s1.attrs, {}) + s2 = Signal(attrs={"no_retiming": True}) + self.assertEqual(s2.attrs, {"no_retiming": True}) + + def test_repr(self): + s1 = Signal() + self.assertEqual(repr(s1), "(sig s1)") + + def test_like(self): + s1 = Signal.like(Signal(4)) + self.assertEqual(s1.shape(), (4, False)) + s2 = Signal.like(Signal(min=-15)) + self.assertEqual(s2.shape(), (5, True)) + s3 = Signal.like(Signal(4, reset=0b111, reset_less=True)) + self.assertEqual(s3.reset, 0b111) + self.assertEqual(s3.reset_less, True) + s4 = Signal.like(Signal(attrs={"no_retiming": True})) + self.assertEqual(s4.attrs, {"no_retiming": True}) + s5 = Signal.like(10) + self.assertEqual(s5.shape(), (4, False)) + + +class ClockSignalTestCase(unittest.TestCase): + def test_domain(self): + s1 = ClockSignal() + self.assertEqual(s1.domain, "sync") + s2 = ClockSignal("pix") + self.assertEqual(s2.domain, "pix") + + with self.assertRaises(TypeError): + ClockSignal(1) + + def test_repr(self): + s1 = ClockSignal() + self.assertEqual(repr(s1), "(clk sync)") + + +class ResetSignalTestCase(unittest.TestCase): + def test_domain(self): + s1 = ResetSignal() + self.assertEqual(s1.domain, "sync") + s2 = ResetSignal("pix") + self.assertEqual(s2.domain, "pix") + + with self.assertRaises(TypeError): + ResetSignal(1) + + def test_repr(self): + s1 = ResetSignal() + self.assertEqual(repr(s1), "(reset sync)")