class _RHSValueCompiler(ValueTransformer):
- def __init__(self, sensitivity=None):
+ def __init__(self, sensitivity=None, mode="rhs"):
self.sensitivity = sensitivity
+ self.signal_mode = mode
def on_Const(self, value):
return lambda state: value.value
def on_Signal(self, value):
if self.sensitivity is not None:
self.sensitivity.add(value)
- return lambda state: state.curr[value]
+ if self.signal_mode == "rhs":
+ return lambda state: state.curr[value]
+ elif self.signal_mode == "lhs":
+ return lambda state: state.next[value]
+ else:
+ raise ValueError # :nocov:
def on_ClockSignal(self, value):
raise NotImplementedError # :nocov:
class _LHSValueCompiler(ValueTransformer):
+ def __init__(self, rhs_compiler):
+ self.rhs_compiler = rhs_compiler
+
def on_Const(self, value):
raise TypeError # :nocov:
def on_Signal(self, value):
- return lambda state, arg: state.set(value, arg)
+ shape = value.shape()
+ def eval(state, rhs):
+ state.set(value, normalize(rhs, shape))
+ return eval
def on_ClockSignal(self, value):
raise NotImplementedError # :nocov:
raise TypeError # :nocov:
def on_Slice(self, value):
- raise NotImplementedError
+ lhs_r = self.rhs_compiler(value.value)
+ lhs_l = self(value.value)
+ shift = value.start
+ mask = (1 << (value.end - value.start)) - 1
+ def eval(state, rhs):
+ lhs_value = lhs_r(state)
+ lhs_value &= ~(mask << shift)
+ lhs_value |= (rhs & mask) << shift
+ lhs_l(state, lhs_value)
+ return eval
def on_Part(self, value):
- raise NotImplementedError
+ lhs_r = self.rhs_compiler(value.value)
+ lhs_l = self(value.value)
+ shift = self.rhs_compiler(value.offset)
+ mask = (1 << value.width) - 1
+ def eval(state, rhs):
+ lhs_value = lhs_r(state)
+ shift_value = shift(state)
+ lhs_value &= ~(mask << shift_value)
+ lhs_value |= (rhs & mask) << shift_value
+ lhs_l(state, lhs_value)
+ return eval
def on_Cat(self, value):
- raise NotImplementedError
+ parts = []
+ offset = 0
+ for opnd in value.operands:
+ parts.append((offset, (1 << len(opnd)) - 1, self(opnd)))
+ offset += len(opnd)
+ def eval(state, rhs):
+ for offset, mask, opnd in parts:
+ opnd(state, (rhs >> offset) & mask)
+ return eval
def on_Repl(self, value):
raise TypeError # :nocov:
def on_ArrayProxy(self, value):
- raise NotImplementedError
+ elems = list(map(self, value.elems))
+ index = self.rhs_compiler(value.index)
+ def eval(state, rhs):
+ elems[index(state)](state, rhs)
+ return eval
class _StatementCompiler(StatementTransformer):
def __init__(self):
- self.sensitivity = ValueSet()
- self.rhs_compiler = _RHSValueCompiler(self.sensitivity)
- self.lhs_compiler = _LHSValueCompiler()
+ self.sensitivity = ValueSet()
+ self.rrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="rhs")
+ self.lrhs_compiler = _RHSValueCompiler(self.sensitivity, mode="lhs")
+ self.lhs_compiler = _LHSValueCompiler(self.lrhs_compiler)
def on_Assign(self, stmt):
shape = stmt.lhs.shape()
lhs = self.lhs_compiler(stmt.lhs)
- rhs = self.rhs_compiler(stmt.rhs)
+ rhs = self.rrhs_compiler(stmt.rhs)
def run(state):
lhs(state, normalize(rhs(state), shape))
return run
def on_Switch(self, stmt):
- test = self.rhs_compiler(stmt.test)
+ test = self.rrhs_compiler(stmt.test)
cases = []
for value, stmts in stmt.cases.items():
if "-" in value:
from .tools import *
+from ..tools import flatten, union
from ..hdl.ast import *
from ..hdl.ir import *
from ..back.pysim import *
class SimulatorUnitTestCase(FHDLTestCase):
- def assertStatement(self, stmt, inputs, output):
+ def assertStatement(self, stmt, inputs, output, reset=0):
inputs = [Value.wrap(i) for i in inputs]
output = Value.wrap(output)
isigs = [Signal(i.shape(), name=n) for i, n in zip(inputs, "abcd")]
- osig = Signal(output.shape(), name="y")
+ osig = Signal(output.shape(), name="y", reset=reset)
+ stmt = stmt(osig, *isigs)
frag = Fragment()
- frag.add_statements(stmt(osig, *isigs))
- frag.add_driver(osig)
+ frag.add_statements(stmt)
+ for signal in flatten(s._lhs_signals() for s in Statement.wrap(stmt)):
+ frag.add_driver(signal)
with Simulator(frag,
vcd_file =open("test.vcd", "w"),
stmt2 = lambda y, a: y.eq(a[2:4])
self.assertStatement(stmt2, [C(0b10110100, 8)], C(0b01, 2))
+ def test_slice_lhs(self):
+ stmt1 = lambda y, a: y[2].eq(a)
+ self.assertStatement(stmt1, [C(0b0, 1)], C(0b11111011, 8), reset=0b11111111)
+ stmt2 = lambda y, a: y[2:4].eq(a)
+ self.assertStatement(stmt2, [C(0b01, 2)], C(0b11110111, 8), reset=0b11111011)
+
def test_part(self):
stmt = lambda y, a, b: y.eq(a.part(b, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(0)], C(0b100, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(2)], C(0b101, 3))
self.assertStatement(stmt, [C(0b10110100, 8), C(3)], C(0b110, 3))
+ def test_part_lhs(self):
+ stmt = lambda y, a, b: y.part(a, 3).eq(b)
+ self.assertStatement(stmt, [C(0), C(0b100, 3)], C(0b11111100, 8), reset=0b11111111)
+ self.assertStatement(stmt, [C(2), C(0b101, 3)], C(0b11110111, 8), reset=0b11111111)
+ self.assertStatement(stmt, [C(3), C(0b110, 3)], C(0b11110111, 8), reset=0b11111111)
+
def test_cat(self):
stmt = lambda y, *xs: y.eq(Cat(*xs))
self.assertStatement(stmt, [C(0b10, 2), C(0b01, 2)], C(0b0110, 4))
+ def test_cat_lhs(self):
+ l = Signal(3)
+ m = Signal(3)
+ n = Signal(3)
+ stmt = lambda y, a: [Cat(l, m, n).eq(a), y.eq(Cat(n, m, l))]
+ self.assertStatement(stmt, [C(0b100101110, 9)], C(0b110101100, 9))
+
def test_repl(self):
stmt = lambda y, a: y.eq(Repl(a, 3))
self.assertStatement(stmt, [C(0b10, 2)], C(0b101010, 6))
self.assertStatement(stmt, [C(1)], C(4))
self.assertStatement(stmt, [C(2)], C(10))
+ def test_array_lhs(self):
+ l = Signal(3, reset=1)
+ m = Signal(3, reset=4)
+ n = Signal(3, reset=7)
+ array = Array([l, m, n])
+ stmt = lambda y, a, b: [array[a].eq(b), y.eq(Cat(*array))]
+ self.assertStatement(stmt, [C(0), C(0b000)], C(0b111100000))
+ self.assertStatement(stmt, [C(1), C(0b010)], C(0b111010001))
+ self.assertStatement(stmt, [C(2), C(0b100)], C(0b100100001))
+
def test_array_index(self):
array = Array(Array(x * y for y in range(10)) for x in range(10))
stmt = lambda y, a, b: y.eq(array[a][b])