From d4e8d3e95a9d95a060b89ebb4489ae278c123938 Mon Sep 17 00:00:00 2001 From: whitequark Date: Sun, 16 Dec 2018 10:31:42 +0000 Subject: [PATCH] back.pysim: implement LHS for Part, Slice, Cat, ArrayProxy. --- nmigen/back/pysim.py | 68 +++++++++++++++++++++++++++++++++-------- nmigen/hdl/ast.py | 2 +- nmigen/test/test_sim.py | 40 +++++++++++++++++++++--- 3 files changed, 93 insertions(+), 17 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 103466b..9a2ae86 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -45,8 +45,9 @@ normalize = Const.normalize class _RHSValueCompiler(ValueTransformer): - def __init__(self, sensitivity=None): + def __init__(self, sensitivity=None, mode="rhs"): self.sensitivity = sensitivity + self.signal_mode = mode def on_Const(self, value): return lambda state: value.value @@ -54,7 +55,12 @@ class _RHSValueCompiler(ValueTransformer): def on_Signal(self, value): if self.sensitivity is not None: self.sensitivity.add(value) - return lambda state: state.curr[value] + if self.signal_mode == "rhs": + return lambda state: state.curr[value] + elif self.signal_mode == "lhs": + return lambda state: state.next[value] + else: + raise ValueError # :nocov: def on_ClockSignal(self, value): raise NotImplementedError # :nocov: @@ -160,11 +166,17 @@ class _RHSValueCompiler(ValueTransformer): class _LHSValueCompiler(ValueTransformer): + def __init__(self, rhs_compiler): + self.rhs_compiler = rhs_compiler + def on_Const(self, value): raise TypeError # :nocov: def on_Signal(self, value): - return lambda state, arg: state.set(value, arg) + shape = value.shape() + def eval(state, rhs): + state.set(value, normalize(rhs, shape)) + return eval def on_ClockSignal(self, value): raise NotImplementedError # :nocov: @@ -176,37 +188,69 @@ class _LHSValueCompiler(ValueTransformer): raise TypeError # :nocov: def on_Slice(self, value): - raise NotImplementedError + lhs_r = self.rhs_compiler(value.value) + lhs_l = self(value.value) + shift = value.start + mask = (1 << (value.end - value.start)) - 1 + def eval(state, rhs): + lhs_value = lhs_r(state) + lhs_value &= ~(mask << shift) + lhs_value |= (rhs & mask) << shift + lhs_l(state, lhs_value) + return eval def on_Part(self, value): - raise NotImplementedError + lhs_r = self.rhs_compiler(value.value) + lhs_l = self(value.value) + shift = self.rhs_compiler(value.offset) + mask = (1 << value.width) - 1 + def eval(state, rhs): + lhs_value = lhs_r(state) + shift_value = shift(state) + lhs_value &= ~(mask << shift_value) + lhs_value |= (rhs & mask) << shift_value + lhs_l(state, lhs_value) + return eval def on_Cat(self, value): - raise NotImplementedError + parts = [] + offset = 0 + for opnd in value.operands: + parts.append((offset, (1 << len(opnd)) - 1, self(opnd))) + offset += len(opnd) + def eval(state, rhs): + for offset, mask, opnd in parts: + opnd(state, (rhs >> offset) & mask) + return eval def on_Repl(self, value): raise TypeError # :nocov: def on_ArrayProxy(self, value): - raise NotImplementedError + elems = list(map(self, value.elems)) + index = self.rhs_compiler(value.index) + def eval(state, rhs): + elems[index(state)](state, rhs) + return eval class _StatementCompiler(StatementTransformer): def __init__(self): - self.sensitivity = ValueSet() - self.rhs_compiler = _RHSValueCompiler(self.sensitivity) - self.lhs_compiler = _LHSValueCompiler() + self.sensitivity = ValueSet() + self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs") + self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs") + self.lhs_compiler = _LHSValueCompiler(self.lrhs_compiler) def on_Assign(self, stmt): shape = stmt.lhs.shape() lhs = self.lhs_compiler(stmt.lhs) - rhs = self.rhs_compiler(stmt.rhs) + rhs = self.rrhs_compiler(stmt.rhs) def run(state): lhs(state, normalize(rhs(state), shape)) return run def on_Switch(self, stmt): - test = self.rhs_compiler(stmt.test) + test = self.rrhs_compiler(stmt.test) cases = [] for value, stmts in stmt.cases.items(): if "-" in value: diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index 0ed355b..8ccd552 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -813,7 +813,7 @@ class Assign(Statement): return self.lhs._lhs_signals() def _rhs_signals(self): - return self.rhs._rhs_signals() + return self.lhs._rhs_signals() | self.rhs._rhs_signals() def __repr__(self): return "(eq {!r} {!r})".format(self.lhs, self.rhs) diff --git a/nmigen/test/test_sim.py b/nmigen/test/test_sim.py index dc491bf..607c440 100644 --- a/nmigen/test/test_sim.py +++ b/nmigen/test/test_sim.py @@ -1,20 +1,23 @@ from .tools import * +from ..tools import flatten, union from ..hdl.ast import * from ..hdl.ir import * from ..back.pysim import * class SimulatorUnitTestCase(FHDLTestCase): - def assertStatement(self, stmt, inputs, output): + def assertStatement(self, stmt, inputs, output, reset=0): inputs = [Value.wrap(i) for i in inputs] output = Value.wrap(output) isigs = [Signal(i.shape(), name=n) for i, n in zip(inputs, "abcd")] - osig = Signal(output.shape(), name="y") + osig = Signal(output.shape(), name="y", reset=reset) + stmt = stmt(osig, *isigs) frag = Fragment() - frag.add_statements(stmt(osig, *isigs)) - frag.add_driver(osig) + frag.add_statements(stmt) + for signal in flatten(s._lhs_signals() for s in Statement.wrap(stmt)): + frag.add_driver(signal) with Simulator(frag, vcd_file =open("test.vcd", "w"), @@ -130,16 +133,35 @@ class SimulatorUnitTestCase(FHDLTestCase): stmt2 = lambda y, a: y.eq(a[2:4]) self.assertStatement(stmt2, [C(0b10110100, 8)], C(0b01, 2)) + def test_slice_lhs(self): + stmt1 = lambda y, a: y[2].eq(a) + self.assertStatement(stmt1, [C(0b0, 1)], C(0b11111011, 8), reset=0b11111111) + stmt2 = lambda y, a: y[2:4].eq(a) + self.assertStatement(stmt2, [C(0b01, 2)], C(0b11110111, 8), reset=0b11111011) + def test_part(self): stmt = lambda y, a, b: y.eq(a.part(b, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3)) self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3)) + def test_part_lhs(self): + stmt = lambda y, a, b: y.part(a, 3).eq(b) + self.assertStatement(stmt, [C(0), C(0b100, 3)], C(0b11111100, 8), reset=0b11111111) + self.assertStatement(stmt, [C(2), C(0b101, 3)], C(0b11110111, 8), reset=0b11111111) + self.assertStatement(stmt, [C(3), C(0b110, 3)], C(0b11110111, 8), reset=0b11111111) + def test_cat(self): stmt = lambda y, *xs: y.eq(Cat(*xs)) self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4)) + def test_cat_lhs(self): + l = Signal(3) + m = Signal(3) + n = Signal(3) + stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))] + self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9)) + def test_repl(self): stmt = lambda y, a: y.eq(Repl(a, 3)) self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6)) @@ -151,6 +173,16 @@ class SimulatorUnitTestCase(FHDLTestCase): self.assertStatement(stmt, [C(1)], C(4)) self.assertStatement(stmt, [C(2)], C(10)) + def test_array_lhs(self): + l = Signal(3, reset=1) + m = Signal(3, reset=4) + n = Signal(3, reset=7) + array = Array([l, m, n]) + stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))] + self.assertStatement(stmt, [C(0), C(0b000)], C(0b111100000)) + self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001)) + self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001)) + def test_array_index(self): array = Array(Array(x * y for y in range(10)) for x in range(10)) stmt = lambda y, a, b: y.eq(array[a][b]) -- 2.30.2