From 932f1912a2ffe5d890441675a3a38d127d49a5c7 Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 13 Dec 2018 07:11:06 +0000 Subject: [PATCH] fhdl.dsl: use less error-prone Switch/Case two-level syntax. --- examples/pmux.py | 17 ++-- nmigen/fhdl/dsl.py | 170 +++++++++++++++++++++-------------- nmigen/test/test_fhdl_dsl.py | 123 ++++++++++++++++++++++++- 3 files changed, 235 insertions(+), 75 deletions(-) diff --git a/examples/pmux.py b/examples/pmux.py index e69068b..cda5447 100644 --- a/examples/pmux.py +++ b/examples/pmux.py @@ -12,14 +12,15 @@ class ParMux: def get_fragment(self, platform): m = Module() - with m.Case(self.s, "--1"): - m.d.comb += self.o.eq(self.a) - with m.Case(self.s, "-1-"): - m.d.comb += self.o.eq(self.b) - with m.Case(self.s, "1--"): - m.d.comb += self.o.eq(self.c) - with m.Case(self.s): - m.d.comb += self.o.eq(0) + with m.Switch(self.s): + with m.Case("--1"): + m.d.comb += self.o.eq(self.a) + with m.Case("-1-"): + m.d.comb += self.o.eq(self.b) + with m.Case("1--"): + m.d.comb += self.o.eq(self.c) + with m.Case(): + m.d.comb += self.o.eq(0) return m.lower(platform) diff --git a/nmigen/fhdl/dsl.py b/nmigen/fhdl/dsl.py index 0c0be8c..9ef6303 100644 --- a/nmigen/fhdl/dsl.py +++ b/nmigen/fhdl/dsl.py @@ -64,30 +64,6 @@ class _ModuleBuilderRoot: .format(type(self).__name__, name)) -class _ModuleBuilderCase(_ModuleBuilderRoot): - def __init__(self, builder, depth, test, value): - super().__init__(builder, depth) - self._test = test - self._value = value - - def __enter__(self): - if self._value is None: - self._value = "-" * len(self._test) - if isinstance(self._value, str) and len(self._test) != len(self._value): - raise SyntaxError("Case value {} must have the same width as test {}" - .format(self._value, self._test)) - if self._builder._stmt_switch_test != ValueKey(self._test): - self._builder._flush() - self._builder._stmt_switch_test = ValueKey(self._test) - self._outer_case = self._builder._statements - self._builder._statements = [] - return self - - def __exit__(self, *args): - self._builder._stmt_switch_cases[self._value] = self._builder._statements - self._builder._statements = self._outer_case - - class _ModuleBuilderSubmodules: def __init__(self, builder): object.__setattr__(self, "_builder", builder) @@ -106,86 +82,147 @@ class Module(_ModuleBuilderRoot): _ModuleBuilderRoot.__init__(self, self, depth=0) self.submodules = _ModuleBuilderSubmodules(self) - self._submodules = [] - self._driving = ValueDict() - self._statements = Statement.wrap([]) - self._stmt_depth = 0 + self._submodules = [] + self._driving = ValueDict() + self._statements = Statement.wrap([]) + self._ctrl_context = None + self._ctrl_stack = [] self._stmt_if_cond = [] self._stmt_if_bodies = [] self._stmt_switch_test = None self._stmt_switch_cases = OrderedDict() + def _check_context(self, construct, context): + if self._ctrl_context != context: + if self._ctrl_context is None: + raise SyntaxError("{} is not permitted outside of {}" + .format(construct, context)) + else: + raise SyntaxError("{} is not permitted inside of {}" + .format(construct, self._ctrl_context)) + + def _get_ctrl(self, name): + if self._ctrl_stack: + top_name, top_data = self._ctrl_stack[-1] + if top_name == name: + return top_data + + def _flush_ctrl(self): + while len(self._ctrl_stack) > self.domain._depth: + self._pop_ctrl() + + def _set_ctrl(self, name, data): + self._flush_ctrl() + self._ctrl_stack.append((name, data)) + return data + @contextmanager def If(self, cond): - self._flush() + self._check_context("If", context=None) + if_data = self._set_ctrl("If", {"tests": [], "bodies": []}) try: - _outer_case = self._statements - self._statements = [] + _outer_case, self._statements = self._statements, [] self.domain._depth += 1 yield - self._stmt_if_cond.append(cond) - self._stmt_if_bodies.append(self._statements) + self._flush_ctrl() + if_data["tests"].append(cond) + if_data["bodies"].append(self._statements) finally: self.domain._depth -= 1 self._statements = _outer_case @contextmanager def Elif(self, cond): - if not self._stmt_if_cond: + self._check_context("Elif", context=None) + if_data = self._get_ctrl("If") + if if_data is None: raise SyntaxError("Elif without preceding If") try: - _outer_case = self._statements - self._statements = [] + _outer_case, self._statements = self._statements, [] self.domain._depth += 1 yield - self._stmt_if_cond.append(cond) - self._stmt_if_bodies.append(self._statements) + self._flush_ctrl() + if_data["tests"].append(cond) + if_data["bodies"].append(self._statements) finally: self.domain._depth -= 1 self._statements = _outer_case @contextmanager def Else(self): - if not self._stmt_if_cond: + self._check_context("Else", context=None) + if_data = self._get_ctrl("If") + if if_data is None: raise SyntaxError("Else without preceding If/Elif") try: - _outer_case = self._statements - self._statements = [] + _outer_case, self._statements = self._statements, [] self.domain._depth += 1 yield - self._stmt_if_bodies.append(self._statements) + self._flush_ctrl() + if_data["bodies"].append(self._statements) finally: self.domain._depth -= 1 self._statements = _outer_case - self._flush() + self._pop_ctrl() + + @contextmanager + def Switch(self, test): + self._check_context("Switch", context=None) + switch_data = self._set_ctrl("Switch", {"test": test, "cases": OrderedDict()}) + try: + self._ctrl_context = "Switch" + self.domain._depth += 1 + yield + finally: + self.domain._depth -= 1 + self._ctrl_context = None + self._pop_ctrl() - def Case(self, test, value=None): - return _ModuleBuilderCase(self, self._stmt_depth + 1, test, value) + @contextmanager + def Case(self, value=None): + self._check_context("Case", context="Switch") + switch_data = self._get_ctrl("Switch") + if value is None: + value = "-" * len(switch_data["test"]) + if isinstance(value, str) and len(switch_data["test"]) != len(value): + raise SyntaxError("Case value '{}' must have the same width as test (which is {})" + .format(value, len(switch_data["test"]))) + try: + _outer_case, self._statements = self._statements, [] + self._ctrl_context = None + yield + self._flush_ctrl() + switch_data["cases"][value] = self._statements + finally: + self._ctrl_context = "Switch" + self._statements = _outer_case + + def _pop_ctrl(self): + name, data = self._ctrl_stack.pop() + + if name == "If": + if_tests, if_bodies = data["tests"], data["bodies"] - def _flush(self): - if self._stmt_if_cond: tests, cases = [], OrderedDict() - for if_cond, if_case in zip(self._stmt_if_cond + [None], self._stmt_if_bodies): - if if_cond is not None: - if_cond = Value.wrap(if_cond) - if len(if_cond) != 1: - if_cond = if_cond.bool() - tests.append(if_cond) - - if if_cond is not None: - match = ("1" + "-" * (len(tests) - 1)).rjust(len(self._stmt_if_cond), "-") + for if_test, if_case in zip(if_tests + [None], if_bodies): + if if_test is not None: + if_test = Value.wrap(if_test) + if len(if_test) != 1: + if_test = if_test.bool() + tests.append(if_test) + + if if_test is not None: + match = ("1" + "-" * (len(tests) - 1)).rjust(len(if_tests), "-") else: match = "-" * len(tests) cases[match] = if_case + self._statements.append(Switch(Cat(tests), cases)) - if self._stmt_switch_test: - self._statements.append(Switch(self._stmt_switch_test.value, self._stmt_switch_cases)) + if name == "Switch": + switch_test, switch_cases = data["test"], data["cases"] - self._stmt_if_cond = [] - self._stmt_if_bodies = [] - self._stmt_switch_test = None - self._stmt_switch_cases = OrderedDict() + self._statements.append(Switch(switch_test, switch_cases)) def _add_statement(self, assigns, cd_name, depth, compat_mode=False): def cd_human_name(cd_name): @@ -194,9 +231,8 @@ class Module(_ModuleBuilderRoot): else: return cd_name - if depth < self._stmt_depth: - self._flush() - self._stmt_depth = depth + while len(self._ctrl_stack) > self.domain._depth: + self._pop_ctrl() for assign in Statement.wrap(assigns): if not compat_mode and not isinstance(assign, Assign): @@ -222,6 +258,10 @@ class Module(_ModuleBuilderRoot): "a submodule".format(submodule)) self._submodules.append((submodule, name)) + def _flush(self): + while self._ctrl_stack: + self._pop_ctrl() + def lower(self, platform): self._flush() diff --git a/nmigen/test/test_fhdl_dsl.py b/nmigen/test/test_fhdl_dsl.py index 55e494c..ba102dc 100644 --- a/nmigen/test/test_fhdl_dsl.py +++ b/nmigen/test/test_fhdl_dsl.py @@ -21,7 +21,7 @@ class DSLTestCase(unittest.TestCase): def assertRaises(self, exception, msg=None): with super().assertRaises(exception) as cm: yield - if msg: + if msg is not None: # WTF? unittest.assertRaises is completely broken. self.assertEqual(str(cm.exception), msg) @@ -158,6 +158,68 @@ class DSLTestCase(unittest.TestCase): ) """) + def test_If_If(self): + m = Module() + with m.If(self.s1): + m.d.comb += self.c1.eq(1) + with m.If(self.s2): + m.d.comb += self.c2.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (cat (sig s1)) + (case 1 (eq (sig c1) (const 1'd1))) + ) + (switch (cat (sig s2)) + (case 1 (eq (sig c2) (const 1'd1))) + ) + ) + """) + + def test_If_nested_If(self): + m = Module() + with m.If(self.s1): + m.d.comb += self.c1.eq(1) + with m.If(self.s2): + m.d.comb += self.c2.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (cat (sig s1)) + (case 1 (eq (sig c1) (const 1'd1)) + (switch (cat (sig s2)) + (case 1 (eq (sig c2) (const 1'd1))) + ) + ) + ) + ) + """) + + def test_If_dangling_Else(self): + m = Module() + with m.If(self.s1): + m.d.comb += self.c1.eq(1) + with m.If(self.s2): + m.d.comb += self.c2.eq(1) + with m.Else(): + m.d.comb += self.c3.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (cat (sig s1)) + (case 1 + (eq (sig c1) (const 1'd1)) + (switch (cat (sig s2)) + (case 1 (eq (sig c2) (const 1'd1))) + ) + ) + (case - + (eq (sig c3) (const 1'd1)) + ) + ) + ) + """) + def test_Elif_wrong(self): m = Module() with self.assertRaises(SyntaxError, @@ -185,7 +247,64 @@ class DSLTestCase(unittest.TestCase): ) """) - def test_auto_flush(self): + def test_Switch(self): + m = Module() + with m.Switch(self.w1): + with m.Case(3): + m.d.comb += self.c1.eq(1) + with m.Case("11--"): + m.d.comb += self.c2.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (sig w1) + (case 0011 (eq (sig c1) (const 1'd1))) + (case 11-- (eq (sig c2) (const 1'd1))) + ) + ) + """) + + def test_Switch_default(self): + m = Module() + with m.Switch(self.w1): + with m.Case(3): + m.d.comb += self.c1.eq(1) + with m.Case(): + m.d.comb += self.c2.eq(1) + m._flush() + self.assertRepr(m._statements, """ + ( + (switch (sig w1) + (case 0011 (eq (sig c1) (const 1'd1))) + (case ---- (eq (sig c2) (const 1'd1))) + ) + ) + """) + + def test_Case_width_wrong(self): + m = Module() + with m.Switch(self.w1): + with self.assertRaises(SyntaxError, + msg="Case value '--' must have the same width as test (which is 4)"): + with m.Case("--"): + pass + + def test_Case_outside_Switch_wrong(self): + m = Module() + with self.assertRaises(SyntaxError, + msg="Case is not permitted outside of Switch"): + with m.Case(): + pass + + def test_If_inside_Switch_wrong(self): + m = Module() + with m.Switch(self.s1): + with self.assertRaises(SyntaxError, + msg="If is not permitted inside of Switch"): + with m.If(self.s2): + pass + + def test_auto_pop_ctrl(self): m = Module() with m.If(self.w1): m.d.comb += self.c1.eq(1) -- 2.30.2