From: whitequark Date: Thu, 13 Dec 2018 08:57:14 +0000 (+0000) Subject: fhdl.xfrm: implement DomainRenamer. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=0a32f0ecdf3a367377fcc88b44042602e7ea52ad;p=nmigen.git fhdl.xfrm: implement DomainRenamer. --- diff --git a/nmigen/fhdl/ast.py b/nmigen/fhdl/ast.py index 604bb9e..9d3af7d 100644 --- a/nmigen/fhdl/ast.py +++ b/nmigen/fhdl/ast.py @@ -573,7 +573,7 @@ class ClockSignal(Value): Parameters ---------- domain : str - Clock domain to obtain a clock signal for. Defaults to `"sync"`. + Clock domain to obtain a clock signal for. Defaults to ``"sync"``. """ def __init__(self, domain="sync"): super().__init__() @@ -588,13 +588,13 @@ class ClockSignal(Value): class ResetSignal(Value): """Reset signal for a given clock domain - `ResetSignal` s for a given clock domain can be retrieved multiple + ``ResetSignal`` s for a given clock domain can be retrieved multiple times. They all ultimately refer to the same signal. Parameters ---------- domain : str - Clock domain to obtain a reset signal for. Defaults to `"sync"`. + Clock domain to obtain a reset signal for. Defaults to ``"sync"``. """ def __init__(self, domain="sync"): super().__init__() @@ -603,7 +603,7 @@ class ResetSignal(Value): self.domain = domain def __repr__(self): - return "(reset {})".format(self.domain) + return "(rst {})".format(self.domain) class _StatementList(list): diff --git a/nmigen/fhdl/xfrm.py b/nmigen/fhdl/xfrm.py index 8a20aab..4277af6 100644 --- a/nmigen/fhdl/xfrm.py +++ b/nmigen/fhdl/xfrm.py @@ -4,7 +4,8 @@ from .ast import * from .ir import * -__all__ = ["ValueTransformer", "StatementTransformer", "ResetInserter", "CEInserter"] +__all__ = ["ValueTransformer", "StatementTransformer", "FragmentTransformer", + "DomainRenamer", "ResetInserter", "CEInserter"] class ValueTransformer: @@ -116,6 +117,30 @@ class FragmentTransformer: return self.on_fragment(value) +class DomainRenamer(FragmentTransformer, ValueTransformer, StatementTransformer): + def __init__(self, domains): + if isinstance(domains, str): + domains = {"sync": domains} + self.domains = OrderedDict(domains) + + def on_ClockSignal(self, value): + if value.domain in self.domains: + return ClockSignal(self.domains[value.domain]) + return value + + def on_ResetSignal(self, value): + if value.domain in self.domains: + return ResetSignal(self.domains[value.domain]) + return value + + def map_drivers(self, fragment, new_fragment): + for cd_name, signals in fragment.iter_domains(): + if cd_name in self.domains: + cd_name = self.domains[cd_name] + for signal in signals: + new_fragment.drive(signal, cd_name) + + class _ControlInserter(FragmentTransformer): def __init__(self, controls): if isinstance(controls, Value): diff --git a/nmigen/test/test_fhdl_dsl.py b/nmigen/test/test_fhdl_dsl.py index 22019b1..7da25cf 100644 --- a/nmigen/test/test_fhdl_dsl.py +++ b/nmigen/test/test_fhdl_dsl.py @@ -1,12 +1,12 @@ -import re import unittest from contextlib import contextmanager -from nmigen.fhdl.ast import * -from nmigen.fhdl.dsl import * +from ..fhdl.ast import * +from ..fhdl.dsl import * +from .tools import * -class DSLTestCase(unittest.TestCase): +class DSLTestCase(FHDLTestCase): def setUp(self): self.s1 = Signal() self.s2 = Signal() @@ -24,13 +24,6 @@ class DSLTestCase(unittest.TestCase): # WTF? unittest.assertRaises is completely broken. self.assertEqual(str(cm.exception), msg) - def assertRepr(self, obj, repr_str): - obj = Statement.wrap(obj) - repr_str = re.sub(r"\s+", " ", repr_str) - repr_str = re.sub(r"\( (?=\()", "(", repr_str) - repr_str = re.sub(r"\) (?=\))", ")", repr_str) - self.assertEqual(repr(obj), repr_str.strip()) - def test_d_comb(self): m = Module() m.d.comb += self.c1.eq(1) diff --git a/nmigen/test/test_fhdl_value.py b/nmigen/test/test_fhdl_value.py new file mode 100644 index 0000000..9c7dbde --- /dev/null +++ b/nmigen/test/test_fhdl_value.py @@ -0,0 +1,358 @@ +import unittest + +from nmigen.fhdl.ast import * + + +class ValueTestCase(unittest.TestCase): + def test_wrap(self): + self.assertIsInstance(Value.wrap(0), Const) + self.assertIsInstance(Value.wrap(True), Const) + c = Const(0) + self.assertIs(Value.wrap(c), c) + with self.assertRaises(TypeError): + Value.wrap("str") + + def test_bool(self): + with self.assertRaises(TypeError): + if Const(0): + pass + + def test_len(self): + self.assertEqual(len(Const(10)), 4) + + def test_getitem_int(self): + s1 = Const(10)[0] + self.assertIsInstance(s1, Slice) + self.assertEqual(s1.start, 0) + self.assertEqual(s1.end, 1) + s2 = Const(10)[-1] + self.assertIsInstance(s2, Slice) + self.assertEqual(s2.start, 3) + self.assertEqual(s2.end, 4) + with self.assertRaises(IndexError): + Const(10)[5] + + def test_getitem_slice(self): + s1 = Const(10)[1:3] + self.assertIsInstance(s1, Slice) + self.assertEqual(s1.start, 1) + self.assertEqual(s1.end, 3) + s2 = Const(10)[1:-2] + self.assertIsInstance(s2, Slice) + self.assertEqual(s2.start, 1) + self.assertEqual(s2.end, 2) + s3 = Const(31)[::2] + self.assertIsInstance(s3, Cat) + self.assertIsInstance(s3.operands[0], Slice) + self.assertEqual(s3.operands[0].start, 0) + self.assertEqual(s3.operands[0].end, 1) + self.assertIsInstance(s3.operands[1], Slice) + self.assertEqual(s3.operands[1].start, 2) + self.assertEqual(s3.operands[1].end, 3) + self.assertIsInstance(s3.operands[2], Slice) + self.assertEqual(s3.operands[2].start, 4) + self.assertEqual(s3.operands[2].end, 5) + + def test_getitem_wrong(self): + with self.assertRaises(TypeError): + Const(31)["str"] + + +class ConstTestCase(unittest.TestCase): + def test_shape(self): + self.assertEqual(Const(0).shape(), (0, False)) + self.assertEqual(Const(1).shape(), (1, False)) + self.assertEqual(Const(10).shape(), (4, False)) + self.assertEqual(Const(-10).shape(), (4, True)) + + self.assertEqual(Const(1, 4).shape(), (4, False)) + self.assertEqual(Const(1, (4, True)).shape(), (4, True)) + + with self.assertRaises(TypeError): + Const(1, -1) + + def test_value(self): + self.assertEqual(Const(10).value, 10) + + def test_repr(self): + self.assertEqual(repr(Const(10)), "(const 4'd10)") + self.assertEqual(repr(Const(-10)), "(const 4'sd-10)") + + def test_hash(self): + with self.assertRaises(TypeError): + hash(Const(0)) + + +class OperatorTestCase(unittest.TestCase): + def test_invert(self): + v = ~Const(0, 4) + self.assertEqual(repr(v), "(~ (const 4'd0))") + self.assertEqual(v.shape(), (4, False)) + + def test_neg(self): + v1 = -Const(0, (4, False)) + self.assertEqual(repr(v1), "(- (const 4'd0))") + self.assertEqual(v1.shape(), (5, True)) + v2 = -Const(0, (4, True)) + self.assertEqual(repr(v2), "(- (const 4'sd0))") + self.assertEqual(v2.shape(), (4, True)) + + def test_add(self): + v1 = Const(0, (4, False)) + Const(0, (6, False)) + self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (7, False)) + v2 = Const(0, (4, True)) + Const(0, (6, True)) + self.assertEqual(v2.shape(), (7, True)) + v3 = Const(0, (4, True)) + Const(0, (4, False)) + self.assertEqual(v3.shape(), (6, True)) + v4 = Const(0, (4, False)) + Const(0, (4, True)) + self.assertEqual(v4.shape(), (6, True)) + v5 = 10 + Const(0, 4) + self.assertEqual(v5.shape(), (5, False)) + + def test_sub(self): + v1 = Const(0, (4, False)) - Const(0, (6, False)) + self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (7, False)) + v2 = Const(0, (4, True)) - Const(0, (6, True)) + self.assertEqual(v2.shape(), (7, True)) + v3 = Const(0, (4, True)) - Const(0, (4, False)) + self.assertEqual(v3.shape(), (6, True)) + v4 = Const(0, (4, False)) - Const(0, (4, True)) + self.assertEqual(v4.shape(), (6, True)) + v5 = 10 - Const(0, 4) + self.assertEqual(v5.shape(), (5, False)) + + def test_mul(self): + v1 = Const(0, (4, False)) * Const(0, (6, False)) + self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (10, False)) + v2 = Const(0, (4, True)) * Const(0, (6, True)) + self.assertEqual(v2.shape(), (9, True)) + v3 = Const(0, (4, True)) * Const(0, (4, False)) + self.assertEqual(v3.shape(), (8, True)) + v5 = 10 * Const(0, 4) + self.assertEqual(v5.shape(), (8, False)) + + def test_and(self): + v1 = Const(0, (4, False)) & Const(0, (6, False)) + self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Const(0, (4, True)) & Const(0, (6, True)) + self.assertEqual(v2.shape(), (6, True)) + v3 = Const(0, (4, True)) & Const(0, (4, False)) + self.assertEqual(v3.shape(), (5, True)) + v4 = Const(0, (4, False)) & Const(0, (4, True)) + self.assertEqual(v4.shape(), (5, True)) + v5 = 10 & Const(0, 4) + self.assertEqual(v5.shape(), (4, False)) + + def test_or(self): + v1 = Const(0, (4, False)) | Const(0, (6, False)) + self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Const(0, (4, True)) | Const(0, (6, True)) + self.assertEqual(v2.shape(), (6, True)) + v3 = Const(0, (4, True)) | Const(0, (4, False)) + self.assertEqual(v3.shape(), (5, True)) + v4 = Const(0, (4, False)) | Const(0, (4, True)) + self.assertEqual(v4.shape(), (5, True)) + v5 = 10 | Const(0, 4) + self.assertEqual(v5.shape(), (4, False)) + + def test_xor(self): + v1 = Const(0, (4, False)) ^ Const(0, (6, False)) + self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Const(0, (4, True)) ^ Const(0, (6, True)) + self.assertEqual(v2.shape(), (6, True)) + v3 = Const(0, (4, True)) ^ Const(0, (4, False)) + self.assertEqual(v3.shape(), (5, True)) + v4 = Const(0, (4, False)) ^ Const(0, (4, True)) + self.assertEqual(v4.shape(), (5, True)) + v5 = 10 ^ Const(0, 4) + self.assertEqual(v5.shape(), (4, False)) + + def test_lt(self): + v = Const(0, 4) < Const(0, 6) + self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_le(self): + v = Const(0, 4) <= Const(0, 6) + self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_gt(self): + v = Const(0, 4) > Const(0, 6) + self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_ge(self): + v = Const(0, 4) >= Const(0, 6) + self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_eq(self): + v = Const(0, 4) == Const(0, 6) + self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_ne(self): + v = Const(0, 4) != Const(0, 6) + self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_mux(self): + s = Const(0) + v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False))) + self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))") + self.assertEqual(v1.shape(), (6, False)) + v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True))) + self.assertEqual(v2.shape(), (6, True)) + v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False))) + self.assertEqual(v3.shape(), (5, True)) + v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True))) + self.assertEqual(v4.shape(), (5, True)) + + def test_bool(self): + v = Const(0).bool() + self.assertEqual(repr(v), "(b (const 0'd0))") + self.assertEqual(v.shape(), (1, False)) + + def test_hash(self): + with self.assertRaises(TypeError): + hash(Const(0) + Const(0)) + + +class SliceTestCase(unittest.TestCase): + def test_shape(self): + s1 = Const(10)[2] + self.assertEqual(s1.shape(), (1, False)) + s2 = Const(-10)[0:2] + self.assertEqual(s2.shape(), (2, False)) + + def test_repr(self): + s1 = Const(10)[2] + self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)") + + +class CatTestCase(unittest.TestCase): + def test_shape(self): + c1 = Cat(Const(10)) + self.assertEqual(c1.shape(), (4, False)) + c2 = Cat(Const(10), Const(1)) + self.assertEqual(c2.shape(), (5, False)) + c3 = Cat(Const(10), Const(1), Const(0)) + self.assertEqual(c3.shape(), (5, False)) + + def test_repr(self): + c1 = Cat(Const(10), Const(1)) + self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))") + + +class ReplTestCase(unittest.TestCase): + def test_shape(self): + r1 = Repl(Const(10), 3) + self.assertEqual(r1.shape(), (12, False)) + + def test_count_wrong(self): + with self.assertRaises(TypeError): + Repl(Const(10), -1) + with self.assertRaises(TypeError): + Repl(Const(10), "str") + + def test_repr(self): + r1 = Repl(Const(10), 3) + self.assertEqual(repr(r1), "(repl (const 4'd10) 3)") + + +class SignalTestCase(unittest.TestCase): + def test_shape(self): + s1 = Signal() + self.assertEqual(s1.shape(), (1, False)) + s2 = Signal(2) + self.assertEqual(s2.shape(), (2, False)) + s3 = Signal((2, False)) + self.assertEqual(s3.shape(), (2, False)) + s4 = Signal((2, True)) + self.assertEqual(s4.shape(), (2, True)) + s5 = Signal(max=16) + self.assertEqual(s5.shape(), (4, False)) + s6 = Signal(min=4, max=16) + self.assertEqual(s6.shape(), (4, False)) + s7 = Signal(min=-4, max=16) + self.assertEqual(s7.shape(), (5, True)) + s8 = Signal(min=-20, max=16) + self.assertEqual(s8.shape(), (6, True)) + + with self.assertRaises(ValueError): + Signal(min=10, max=4) + with self.assertRaises(ValueError): + Signal(2, min=10) + with self.assertRaises(TypeError): + Signal(-10) + + def test_name(self): + s1 = Signal() + self.assertEqual(s1.name, "s1") + s2 = Signal(name="sig") + self.assertEqual(s2.name, "sig") + + def test_reset(self): + s1 = Signal(4, reset=0b111, reset_less=True) + self.assertEqual(s1.reset, 0b111) + self.assertEqual(s1.reset_less, True) + + def test_attrs(self): + s1 = Signal() + self.assertEqual(s1.attrs, {}) + s2 = Signal(attrs={"no_retiming": True}) + self.assertEqual(s2.attrs, {"no_retiming": True}) + + def test_repr(self): + s1 = Signal() + self.assertEqual(repr(s1), "(sig s1)") + + def test_like(self): + s1 = Signal.like(Signal(4)) + self.assertEqual(s1.shape(), (4, False)) + s2 = Signal.like(Signal(min=-15)) + self.assertEqual(s2.shape(), (5, True)) + s3 = Signal.like(Signal(4, reset=0b111, reset_less=True)) + self.assertEqual(s3.reset, 0b111) + self.assertEqual(s3.reset_less, True) + s4 = Signal.like(Signal(attrs={"no_retiming": True})) + self.assertEqual(s4.attrs, {"no_retiming": True}) + s5 = Signal.like(10) + self.assertEqual(s5.shape(), (4, False)) + + +class ClockSignalTestCase(unittest.TestCase): + def test_domain(self): + s1 = ClockSignal() + self.assertEqual(s1.domain, "sync") + s2 = ClockSignal("pix") + self.assertEqual(s2.domain, "pix") + + with self.assertRaises(TypeError): + ClockSignal(1) + + def test_repr(self): + s1 = ClockSignal() + self.assertEqual(repr(s1), "(clk sync)") + + +class ResetSignalTestCase(unittest.TestCase): + def test_domain(self): + s1 = ResetSignal() + self.assertEqual(s1.domain, "sync") + s2 = ResetSignal("pix") + self.assertEqual(s2.domain, "pix") + + with self.assertRaises(TypeError): + ResetSignal(1) + + def test_repr(self): + s1 = ResetSignal() + self.assertEqual(repr(s1), "(rst sync)") diff --git a/nmigen/test/test_fhdl_values.py b/nmigen/test/test_fhdl_values.py deleted file mode 100644 index 16e0970..0000000 --- a/nmigen/test/test_fhdl_values.py +++ /dev/null @@ -1,358 +0,0 @@ -import unittest - -from nmigen.fhdl.ast import * - - -class ValueTestCase(unittest.TestCase): - def test_wrap(self): - self.assertIsInstance(Value.wrap(0), Const) - self.assertIsInstance(Value.wrap(True), Const) - c = Const(0) - self.assertIs(Value.wrap(c), c) - with self.assertRaises(TypeError): - Value.wrap("str") - - def test_bool(self): - with self.assertRaises(TypeError): - if Const(0): - pass - - def test_len(self): - self.assertEqual(len(Const(10)), 4) - - def test_getitem_int(self): - s1 = Const(10)[0] - self.assertIsInstance(s1, Slice) - self.assertEqual(s1.start, 0) - self.assertEqual(s1.end, 1) - s2 = Const(10)[-1] - self.assertIsInstance(s2, Slice) - self.assertEqual(s2.start, 3) - self.assertEqual(s2.end, 4) - with self.assertRaises(IndexError): - Const(10)[5] - - def test_getitem_slice(self): - s1 = Const(10)[1:3] - self.assertIsInstance(s1, Slice) - self.assertEqual(s1.start, 1) - self.assertEqual(s1.end, 3) - s2 = Const(10)[1:-2] - self.assertIsInstance(s2, Slice) - self.assertEqual(s2.start, 1) - self.assertEqual(s2.end, 2) - s3 = Const(31)[::2] - self.assertIsInstance(s3, Cat) - self.assertIsInstance(s3.operands[0], Slice) - self.assertEqual(s3.operands[0].start, 0) - self.assertEqual(s3.operands[0].end, 1) - self.assertIsInstance(s3.operands[1], Slice) - self.assertEqual(s3.operands[1].start, 2) - self.assertEqual(s3.operands[1].end, 3) - self.assertIsInstance(s3.operands[2], Slice) - self.assertEqual(s3.operands[2].start, 4) - self.assertEqual(s3.operands[2].end, 5) - - def test_getitem_wrong(self): - with self.assertRaises(TypeError): - Const(31)["str"] - - -class ConstTestCase(unittest.TestCase): - def test_shape(self): - self.assertEqual(Const(0).shape(), (0, False)) - self.assertEqual(Const(1).shape(), (1, False)) - self.assertEqual(Const(10).shape(), (4, False)) - self.assertEqual(Const(-10).shape(), (4, True)) - - self.assertEqual(Const(1, 4).shape(), (4, False)) - self.assertEqual(Const(1, (4, True)).shape(), (4, True)) - - with self.assertRaises(TypeError): - Const(1, -1) - - def test_value(self): - self.assertEqual(Const(10).value, 10) - - def test_repr(self): - self.assertEqual(repr(Const(10)), "(const 4'd10)") - self.assertEqual(repr(Const(-10)), "(const 4'sd-10)") - - def test_hash(self): - with self.assertRaises(TypeError): - hash(Const(0)) - - -class OperatorTestCase(unittest.TestCase): - def test_invert(self): - v = ~Const(0, 4) - self.assertEqual(repr(v), "(~ (const 4'd0))") - self.assertEqual(v.shape(), (4, False)) - - def test_neg(self): - v1 = -Const(0, (4, False)) - self.assertEqual(repr(v1), "(- (const 4'd0))") - self.assertEqual(v1.shape(), (5, True)) - v2 = -Const(0, (4, True)) - self.assertEqual(repr(v2), "(- (const 4'sd0))") - self.assertEqual(v2.shape(), (4, True)) - - def test_add(self): - v1 = Const(0, (4, False)) + Const(0, (6, False)) - self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (7, False)) - v2 = Const(0, (4, True)) + Const(0, (6, True)) - self.assertEqual(v2.shape(), (7, True)) - v3 = Const(0, (4, True)) + Const(0, (4, False)) - self.assertEqual(v3.shape(), (6, True)) - v4 = Const(0, (4, False)) + Const(0, (4, True)) - self.assertEqual(v4.shape(), (6, True)) - v5 = 10 + Const(0, 4) - self.assertEqual(v5.shape(), (5, False)) - - def test_sub(self): - v1 = Const(0, (4, False)) - Const(0, (6, False)) - self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (7, False)) - v2 = Const(0, (4, True)) - Const(0, (6, True)) - self.assertEqual(v2.shape(), (7, True)) - v3 = Const(0, (4, True)) - Const(0, (4, False)) - self.assertEqual(v3.shape(), (6, True)) - v4 = Const(0, (4, False)) - Const(0, (4, True)) - self.assertEqual(v4.shape(), (6, True)) - v5 = 10 - Const(0, 4) - self.assertEqual(v5.shape(), (5, False)) - - def test_mul(self): - v1 = Const(0, (4, False)) * Const(0, (6, False)) - self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (10, False)) - v2 = Const(0, (4, True)) * Const(0, (6, True)) - self.assertEqual(v2.shape(), (9, True)) - v3 = Const(0, (4, True)) * Const(0, (4, False)) - self.assertEqual(v3.shape(), (8, True)) - v5 = 10 * Const(0, 4) - self.assertEqual(v5.shape(), (8, False)) - - def test_and(self): - v1 = Const(0, (4, False)) & Const(0, (6, False)) - self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Const(0, (4, True)) & Const(0, (6, True)) - self.assertEqual(v2.shape(), (6, True)) - v3 = Const(0, (4, True)) & Const(0, (4, False)) - self.assertEqual(v3.shape(), (5, True)) - v4 = Const(0, (4, False)) & Const(0, (4, True)) - self.assertEqual(v4.shape(), (5, True)) - v5 = 10 & Const(0, 4) - self.assertEqual(v5.shape(), (4, False)) - - def test_or(self): - v1 = Const(0, (4, False)) | Const(0, (6, False)) - self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Const(0, (4, True)) | Const(0, (6, True)) - self.assertEqual(v2.shape(), (6, True)) - v3 = Const(0, (4, True)) | Const(0, (4, False)) - self.assertEqual(v3.shape(), (5, True)) - v4 = Const(0, (4, False)) | Const(0, (4, True)) - self.assertEqual(v4.shape(), (5, True)) - v5 = 10 | Const(0, 4) - self.assertEqual(v5.shape(), (4, False)) - - def test_xor(self): - v1 = Const(0, (4, False)) ^ Const(0, (6, False)) - self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Const(0, (4, True)) ^ Const(0, (6, True)) - self.assertEqual(v2.shape(), (6, True)) - v3 = Const(0, (4, True)) ^ Const(0, (4, False)) - self.assertEqual(v3.shape(), (5, True)) - v4 = Const(0, (4, False)) ^ Const(0, (4, True)) - self.assertEqual(v4.shape(), (5, True)) - v5 = 10 ^ Const(0, 4) - self.assertEqual(v5.shape(), (4, False)) - - def test_lt(self): - v = Const(0, 4) < Const(0, 6) - self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_le(self): - v = Const(0, 4) <= Const(0, 6) - self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_gt(self): - v = Const(0, 4) > Const(0, 6) - self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_ge(self): - v = Const(0, 4) >= Const(0, 6) - self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_eq(self): - v = Const(0, 4) == Const(0, 6) - self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_ne(self): - v = Const(0, 4) != Const(0, 6) - self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_mux(self): - s = Const(0) - v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False))) - self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))") - self.assertEqual(v1.shape(), (6, False)) - v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True))) - self.assertEqual(v2.shape(), (6, True)) - v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False))) - self.assertEqual(v3.shape(), (5, True)) - v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True))) - self.assertEqual(v4.shape(), (5, True)) - - def test_bool(self): - v = Const(0).bool() - self.assertEqual(repr(v), "(b (const 0'd0))") - self.assertEqual(v.shape(), (1, False)) - - def test_hash(self): - with self.assertRaises(TypeError): - hash(Const(0) + Const(0)) - - -class SliceTestCase(unittest.TestCase): - def test_shape(self): - s1 = Const(10)[2] - self.assertEqual(s1.shape(), (1, False)) - s2 = Const(-10)[0:2] - self.assertEqual(s2.shape(), (2, False)) - - def test_repr(self): - s1 = Const(10)[2] - self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)") - - -class CatTestCase(unittest.TestCase): - def test_shape(self): - c1 = Cat(Const(10)) - self.assertEqual(c1.shape(), (4, False)) - c2 = Cat(Const(10), Const(1)) - self.assertEqual(c2.shape(), (5, False)) - c3 = Cat(Const(10), Const(1), Const(0)) - self.assertEqual(c3.shape(), (5, False)) - - def test_repr(self): - c1 = Cat(Const(10), Const(1)) - self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))") - - -class ReplTestCase(unittest.TestCase): - def test_shape(self): - r1 = Repl(Const(10), 3) - self.assertEqual(r1.shape(), (12, False)) - - def test_count_wrong(self): - with self.assertRaises(TypeError): - Repl(Const(10), -1) - with self.assertRaises(TypeError): - Repl(Const(10), "str") - - def test_repr(self): - r1 = Repl(Const(10), 3) - self.assertEqual(repr(r1), "(repl (const 4'd10) 3)") - - -class SignalTestCase(unittest.TestCase): - def test_shape(self): - s1 = Signal() - self.assertEqual(s1.shape(), (1, False)) - s2 = Signal(2) - self.assertEqual(s2.shape(), (2, False)) - s3 = Signal((2, False)) - self.assertEqual(s3.shape(), (2, False)) - s4 = Signal((2, True)) - self.assertEqual(s4.shape(), (2, True)) - s5 = Signal(max=16) - self.assertEqual(s5.shape(), (4, False)) - s6 = Signal(min=4, max=16) - self.assertEqual(s6.shape(), (4, False)) - s7 = Signal(min=-4, max=16) - self.assertEqual(s7.shape(), (5, True)) - s8 = Signal(min=-20, max=16) - self.assertEqual(s8.shape(), (6, True)) - - with self.assertRaises(ValueError): - Signal(min=10, max=4) - with self.assertRaises(ValueError): - Signal(2, min=10) - with self.assertRaises(TypeError): - Signal(-10) - - def test_name(self): - s1 = Signal() - self.assertEqual(s1.name, "s1") - s2 = Signal(name="sig") - self.assertEqual(s2.name, "sig") - - def test_reset(self): - s1 = Signal(4, reset=0b111, reset_less=True) - self.assertEqual(s1.reset, 0b111) - self.assertEqual(s1.reset_less, True) - - def test_attrs(self): - s1 = Signal() - self.assertEqual(s1.attrs, {}) - s2 = Signal(attrs={"no_retiming": True}) - self.assertEqual(s2.attrs, {"no_retiming": True}) - - def test_repr(self): - s1 = Signal() - self.assertEqual(repr(s1), "(sig s1)") - - def test_like(self): - s1 = Signal.like(Signal(4)) - self.assertEqual(s1.shape(), (4, False)) - s2 = Signal.like(Signal(min=-15)) - self.assertEqual(s2.shape(), (5, True)) - s3 = Signal.like(Signal(4, reset=0b111, reset_less=True)) - self.assertEqual(s3.reset, 0b111) - self.assertEqual(s3.reset_less, True) - s4 = Signal.like(Signal(attrs={"no_retiming": True})) - self.assertEqual(s4.attrs, {"no_retiming": True}) - s5 = Signal.like(10) - self.assertEqual(s5.shape(), (4, False)) - - -class ClockSignalTestCase(unittest.TestCase): - def test_domain(self): - s1 = ClockSignal() - self.assertEqual(s1.domain, "sync") - s2 = ClockSignal("pix") - self.assertEqual(s2.domain, "pix") - - with self.assertRaises(TypeError): - ClockSignal(1) - - def test_repr(self): - s1 = ClockSignal() - self.assertEqual(repr(s1), "(clk sync)") - - -class ResetSignalTestCase(unittest.TestCase): - def test_domain(self): - s1 = ResetSignal() - self.assertEqual(s1.domain, "sync") - s2 = ResetSignal("pix") - self.assertEqual(s2.domain, "pix") - - with self.assertRaises(TypeError): - ResetSignal(1) - - def test_repr(self): - s1 = ResetSignal() - self.assertEqual(repr(s1), "(reset sync)") diff --git a/nmigen/test/test_fhdl_xfrm.py b/nmigen/test/test_fhdl_xfrm.py index b569601..aad4525 100644 --- a/nmigen/test/test_fhdl_xfrm.py +++ b/nmigen/test/test_fhdl_xfrm.py @@ -1,25 +1,72 @@ import re import unittest -from nmigen.fhdl.ast import * -from nmigen.fhdl.ir import * -from nmigen.fhdl.xfrm import * +from ..fhdl.ast import * +from ..fhdl.ir import * +from ..fhdl.xfrm import * +from .tools import * -class ResetInserterTestCase(unittest.TestCase): +class DomainRenamerTestCase(FHDLTestCase): + def setUp(self): + self.s1 = Signal() + self.s2 = Signal() + self.s3 = Signal() + self.s4 = Signal() + self.s5 = Signal() + self.c1 = Signal() + + def test_rename_signals(self): + f = Fragment() + f.add_statements( + self.s1.eq(ClockSignal()), + ResetSignal().eq(self.s2), + self.s3.eq(0), + self.s4.eq(ClockSignal("other")), + self.s5.eq(ResetSignal("other")), + ) + f.drive(self.s1, None) + f.drive(self.s2, None) + f.drive(self.s3, "sync") + + f = DomainRenamer("pix")(f) + self.assertRepr(f.statements, """ + ( + (eq (sig s1) (clk pix)) + (eq (rst pix) (sig s2)) + (eq (sig s3) (const 0'd0)) + (eq (sig s4) (clk other)) + (eq (sig s5) (rst other)) + ) + """) + self.assertEqual(f.drivers, { + None: ValueSet((self.s1, self.s2)), + "pix": ValueSet((self.s3,)), + }) + + def test_rename_multi(self): + f = Fragment() + f.add_statements( + self.s1.eq(ClockSignal()), + self.s2.eq(ResetSignal("other")), + ) + + f = DomainRenamer({"sync": "pix", "other": "pix2"})(f) + self.assertRepr(f.statements, """ + ( + (eq (sig s1) (clk pix)) + (eq (sig s2) (rst pix2)) + ) + """) + + +class ResetInserterTestCase(FHDLTestCase): def setUp(self): self.s1 = Signal() self.s2 = Signal(reset=1) self.s3 = Signal(reset=1, reset_less=True) self.c1 = Signal() - def assertRepr(self, obj, repr_str): - obj = Statement.wrap(obj) - repr_str = re.sub(r"\s+", " ", repr_str) - repr_str = re.sub(r"\( (?=\()", "(", repr_str) - repr_str = re.sub(r"\) (?=\))", ")", repr_str) - self.assertEqual(repr(obj), repr_str.strip()) - def test_reset_default(self): f = Fragment() f.add_statements( @@ -92,20 +139,13 @@ class ResetInserterTestCase(unittest.TestCase): """) -class CEInserterTestCase(unittest.TestCase): +class CEInserterTestCase(FHDLTestCase): def setUp(self): self.s1 = Signal() self.s2 = Signal() self.s3 = Signal() self.c1 = Signal() - def assertRepr(self, obj, repr_str): - obj = Statement.wrap(obj) - repr_str = re.sub(r"\s+", " ", repr_str) - repr_str = re.sub(r"\( (?=\()", "(", repr_str) - repr_str = re.sub(r"\) (?=\))", ")", repr_str) - self.assertEqual(repr(obj), repr_str.strip()) - def test_ce_default(self): f = Fragment() f.add_statements( diff --git a/nmigen/test/tools.py b/nmigen/test/tools.py new file mode 100644 index 0000000..9e3a8f0 --- /dev/null +++ b/nmigen/test/tools.py @@ -0,0 +1,16 @@ +import re +import unittest + +from ..fhdl.ast import * + + +__all__ = ["FHDLTestCase"] + + +class FHDLTestCase(unittest.TestCase): + def assertRepr(self, obj, repr_str): + obj = Statement.wrap(obj) + repr_str = re.sub(r"\s+", " ", repr_str) + repr_str = re.sub(r"\( (?=\()", "(", repr_str) + repr_str = re.sub(r"\) (?=\))", ")", repr_str) + self.assertEqual(repr(obj), repr_str.strip())