back.pysim: implement LHS for Part, Slice, Cat, ArrayProxy.
authorwhitequark <whitequark@whitequark.org>
Sun, 16 Dec 2018 10:31:42 +0000 (10:31 +0000)
committerwhitequark <whitequark@whitequark.org>
Sun, 16 Dec 2018 10:31:42 +0000 (10:31 +0000)
nmigen/back/pysim.py
nmigen/hdl/ast.py
nmigen/test/test_sim.py

index 103466bf5f706d6d36e9a5e0faf10455adaa2eae..9a2ae8634a33b25199ded0e0fdb30b63e90aedcc 100644 (file)
@@ -45,8 +45,9 @@ normalize = Const.normalize
 
 
 class _RHSValueCompiler(ValueTransformer):
-    def __init__(self, sensitivity=None):
+    def __init__(self, sensitivity=None, mode="rhs"):
         self.sensitivity = sensitivity
+        self.signal_mode = mode
 
     def on_Const(self, value):
         return lambda state: value.value
@@ -54,7 +55,12 @@ class _RHSValueCompiler(ValueTransformer):
     def on_Signal(self, value):
         if self.sensitivity is not None:
             self.sensitivity.add(value)
-        return lambda state: state.curr[value]
+        if self.signal_mode == "rhs":
+            return lambda state: state.curr[value]
+        elif self.signal_mode == "lhs":
+            return lambda state: state.next[value]
+        else:
+            raise ValueError # :nocov:
 
     def on_ClockSignal(self, value):
         raise NotImplementedError # :nocov:
@@ -160,11 +166,17 @@ class _RHSValueCompiler(ValueTransformer):
 
 
 class _LHSValueCompiler(ValueTransformer):
+    def __init__(self, rhs_compiler):
+        self.rhs_compiler = rhs_compiler
+
     def on_Const(self, value):
         raise TypeError # :nocov:
 
     def on_Signal(self, value):
-        return lambda state, arg: state.set(value, arg)
+        shape = value.shape()
+        def eval(state, rhs):
+            state.set(value, normalize(rhs, shape))
+        return eval
 
     def on_ClockSignal(self, value):
         raise NotImplementedError # :nocov:
@@ -176,37 +188,69 @@ class _LHSValueCompiler(ValueTransformer):
         raise TypeError # :nocov:
 
     def on_Slice(self, value):
-        raise NotImplementedError
+        lhs_r = self.rhs_compiler(value.value)
+        lhs_l = self(value.value)
+        shift = value.start
+        mask  = (1 << (value.end - value.start)) - 1
+        def eval(state, rhs):
+            lhs_value  = lhs_r(state)
+            lhs_value &= ~(mask << shift)
+            lhs_value |= (rhs & mask) << shift
+            lhs_l(state, lhs_value)
+        return eval
 
     def on_Part(self, value):
-        raise NotImplementedError
+        lhs_r = self.rhs_compiler(value.value)
+        lhs_l = self(value.value)
+        shift = self.rhs_compiler(value.offset)
+        mask  = (1 << value.width) - 1
+        def eval(state, rhs):
+            lhs_value   = lhs_r(state)
+            shift_value = shift(state)
+            lhs_value  &= ~(mask << shift_value)
+            lhs_value  |= (rhs & mask) << shift_value
+            lhs_l(state, lhs_value)
+        return eval
 
     def on_Cat(self, value):
-        raise NotImplementedError
+        parts  = []
+        offset = 0
+        for opnd in value.operands:
+            parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
+            offset += len(opnd)
+        def eval(state, rhs):
+            for offset, mask, opnd in parts:
+                opnd(state, (rhs >> offset) & mask)
+        return eval
 
     def on_Repl(self, value):
         raise TypeError # :nocov:
 
     def on_ArrayProxy(self, value):
-        raise NotImplementedError
+        elems = list(map(self, value.elems))
+        index = self.rhs_compiler(value.index)
+        def eval(state, rhs):
+            elems[index(state)](state, rhs)
+        return eval
 
 
 class _StatementCompiler(StatementTransformer):
     def __init__(self):
-        self.sensitivity  = ValueSet()
-        self.rhs_compiler = _RHSValueCompiler(self.sensitivity)
-        self.lhs_compiler = _LHSValueCompiler()
+        self.sensitivity   = ValueSet()
+        self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs")
+        self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs")
+        self.lhs_compiler  = _LHSValueCompiler(self.lrhs_compiler)
 
     def on_Assign(self, stmt):
         shape = stmt.lhs.shape()
         lhs   = self.lhs_compiler(stmt.lhs)
-        rhs   = self.rhs_compiler(stmt.rhs)
+        rhs   = self.rrhs_compiler(stmt.rhs)
         def run(state):
             lhs(state, normalize(rhs(state), shape))
         return run
 
     def on_Switch(self, stmt):
-        test  = self.rhs_compiler(stmt.test)
+        test  = self.rrhs_compiler(stmt.test)
         cases = []
         for value, stmts in stmt.cases.items():
             if "-" in value:
index 0ed355b8af3a70f909bf2351ae5619c336098157..8ccd5522462945669dd4970495c1e485fcd4b53d 100644 (file)
@@ -813,7 +813,7 @@ class Assign(Statement):
         return self.lhs._lhs_signals()
 
     def _rhs_signals(self):
-        return self.rhs._rhs_signals()
+        return self.lhs._rhs_signals() | self.rhs._rhs_signals()
 
     def __repr__(self):
         return "(eq {!r} {!r})".format(self.lhs, self.rhs)
index dc491bf6e64153c6377cf241e55d3cea75ef1522..607c4401ed9a023c0762941d6b75112260c960b3 100644 (file)
@@ -1,20 +1,23 @@
 from .tools import *
+from ..tools import flatten, union
 from ..hdl.ast import *
 from ..hdl.ir import *
 from ..back.pysim import *
 
 
 class SimulatorUnitTestCase(FHDLTestCase):
-    def assertStatement(self, stmt, inputs, output):
+    def assertStatement(self, stmt, inputs, output, reset=0):
         inputs = [Value.wrap(i) for i in inputs]
         output = Value.wrap(output)
 
         isigs = [Signal(i.shape(), name=n) for i, n in zip(inputs, "abcd")]
-        osig  = Signal(output.shape(), name="y")
+        osig  = Signal(output.shape(), name="y", reset=reset)
 
+        stmt = stmt(osig, *isigs)
         frag = Fragment()
-        frag.add_statements(stmt(osig, *isigs))
-        frag.add_driver(osig)
+        frag.add_statements(stmt)
+        for signal in flatten(s._lhs_signals() for s in Statement.wrap(stmt)):
+            frag.add_driver(signal)
 
         with Simulator(frag,
                 vcd_file =open("test.vcd",  "w"),
@@ -130,16 +133,35 @@ class SimulatorUnitTestCase(FHDLTestCase):
         stmt2 = lambda y, a: y.eq(a[2:4])
         self.assertStatement(stmt2, [C(0b10110100, 8)], C(0b01, 2))
 
+    def test_slice_lhs(self):
+        stmt1 = lambda y, a: y[2].eq(a)
+        self.assertStatement(stmt1, [C(0b0,  1)], C(0b11111011, 8), reset=0b11111111)
+        stmt2 = lambda y, a: y[2:4].eq(a)
+        self.assertStatement(stmt2, [C(0b01, 2)], C(0b11110111, 8), reset=0b11111011)
+
     def test_part(self):
         stmt = lambda y, a, b: y.eq(a.part(b, 3))
         self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3))
         self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3))
         self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3))
 
+    def test_part_lhs(self):
+        stmt = lambda y, a, b: y.part(a, 3).eq(b)
+        self.assertStatement(stmt, [C(0), C(0b100, 3)], C(0b11111100, 8), reset=0b11111111)
+        self.assertStatement(stmt, [C(2), C(0b101, 3)], C(0b11110111, 8), reset=0b11111111)
+        self.assertStatement(stmt, [C(3), C(0b110, 3)], C(0b11110111, 8), reset=0b11111111)
+
     def test_cat(self):
         stmt = lambda y, *xs: y.eq(Cat(*xs))
         self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4))
 
+    def test_cat_lhs(self):
+        l = Signal(3)
+        m = Signal(3)
+        n = Signal(3)
+        stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))]
+        self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9))
+
     def test_repl(self):
         stmt = lambda y, a: y.eq(Repl(a, 3))
         self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6))
@@ -151,6 +173,16 @@ class SimulatorUnitTestCase(FHDLTestCase):
         self.assertStatement(stmt, [C(1)], C(4))
         self.assertStatement(stmt, [C(2)], C(10))
 
+    def test_array_lhs(self):
+        l = Signal(3, reset=1)
+        m = Signal(3, reset=4)
+        n = Signal(3, reset=7)
+        array = Array([l, m, n])
+        stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))]
+        self.assertStatement(stmt, [C(0), C(0b000)], C(0b111100000))
+        self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001))
+        self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001))
+
     def test_array_index(self):
         array = Array(Array(x * y for y in range(10)) for x in range(10))
         stmt = lambda y, a, b: y.eq(array[a][b])