From dc486ad8b9ef3b2716bbca04d34c94caed45f0e4 Mon Sep 17 00:00:00 2001 From: whitequark Date: Thu, 13 Dec 2018 02:04:44 +0000 Subject: [PATCH] fhdl.ast: add tests for most logic. --- nmigen/__init__.py | 0 nmigen/fhdl/ast.py | 83 ++++----- nmigen/test/__init__.py | 0 nmigen/test/test_fhdl.py | 358 +++++++++++++++++++++++++++++++++++++++ 4 files changed, 391 insertions(+), 50 deletions(-) create mode 100644 nmigen/__init__.py create mode 100644 nmigen/test/__init__.py create mode 100644 nmigen/test/test_fhdl.py diff --git a/nmigen/__init__.py b/nmigen/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nmigen/fhdl/ast.py b/nmigen/fhdl/ast.py index 74aaadf..791598a 100644 --- a/nmigen/fhdl/ast.py +++ b/nmigen/fhdl/ast.py @@ -36,17 +36,6 @@ class Value: .format(repr(obj), type(obj))) def __bool__(self): - # Special case: Consts and Signals are part of a set or used as - # dictionary keys, and Python needs to check for equality. - if isinstance(self, Operator) and self.op == "==": - a, b = self.operands - if isinstance(a, Const) and isinstance(b, Const): - return a.value == b.value - if isinstance(a, Signal) and isinstance(b, Signal): - return a is b - if (isinstance(a, Const) and isinstance(b, Signal) - or isinstance(a, Signal) and isinstance(b, Const)): - return False raise TypeError("Attempted to convert Migen value to boolean") def __invert__(self): @@ -187,13 +176,13 @@ class Value: >>> Value.bits_sign(C(0xaa)) 8, False """ - raise TypeError("Cannot calculate bit length of {!r}".format(self)) + raise NotImplementedError # :nocov: def _lhs_signals(self): raise TypeError("Value {!r} cannot be used in assignments".format(self)) def _rhs_signals(self): - raise NotImplementedError + raise NotImplementedError # :nocov: def __hash__(self): raise TypeError("Unhashable type: {}".format(type(self).__name__)) @@ -232,12 +221,6 @@ class Const(Value): def _rhs_signals(self): return ValueSet() - def __eq__(self, other): - return self.value == other.value - - def __hash__(self): - return hash(self.value) - def __repr__(self): return "(const {}'{}d{})".format(self.nbits, "s" if self.signed else "", self.value) @@ -301,23 +284,20 @@ class Operator(Value): elif self.op == "&" or self.op == "^" or self.op == "|": return self._bitwise_binary_bits_sign(*obs) elif (self.op == "<" or self.op == "<=" or self.op == "==" or self.op == "!=" or - self.op == ">" or self.op == ">="): + self.op == ">" or self.op == ">=" or self.op == "b"): return 1, False elif self.op == "~": return obs[0] elif self.op == "m": - return _bitwise_binary_bits_sign(obs[1], obs[2]) + return self._bitwise_binary_bits_sign(obs[1], obs[2]) else: - raise TypeError + raise TypeError # :nocov: def _rhs_signals(self): return union(op._rhs_signals() for op in self.operands) def __repr__(self): - if len(self.operands) == 1: - return "({} {})".format(self.op, self.operands[0]) - elif len(self.operands) == 2: - return "({} {} {})".format(self.op, self.operands[0], self.operands[1]) + return "({} {})".format(self.op, " ".join(map(repr, self.operands))) def Mux(sel, val1, val0): @@ -470,7 +450,10 @@ class Repl(Value): return len(self.value) * self.count, False def _rhs_signals(self): - return value._rhs_signals() + return self.value._rhs_signals() + + def __repr__(self): + return "(repl {!r} {})".format(self.value, self.count) class Signal(Value, DUID): @@ -538,18 +521,18 @@ class Signal(Value, DUID): self.signed = min < 0 or max < 0 self.nbits = builtins.max(bits_for(min, self.signed), bits_for(max, self.signed)) - elif isinstance(bits_sign, int): - if not (min is None or max is None): - raise ValueError("Only one of bits/signedness or bounds may be specified") - self.nbits, self.signed = bits_sign, False - else: - self.nbits, self.signed = bits_sign + if not (min is None and max is None): + raise ValueError("Only one of bits/signedness or bounds may be specified") + if isinstance(bits_sign, int): + self.nbits, self.signed = bits_sign, False + else: + self.nbits, self.signed = bits_sign if not isinstance(self.nbits, int) or self.nbits < 0: raise TypeError("Width must be a positive integer, not {!r}".format(self.nbits)) - self.reset = reset - self.reset_less = reset_less + self.reset = int(reset) + self.reset_less = bool(reset_less) self.attrs = OrderedDict(() if attrs is None else attrs) @@ -564,7 +547,7 @@ class Signal(Value, DUID): """ kw = dict(bits_sign=cls.wrap(other).bits_sign()) if isinstance(other, cls): - kw.update(reset=other.reset.value, reset_less=other.reset_less, attrs=other.attrs) + kw.update(reset=other.reset, reset_less=other.reset_less, attrs=other.attrs) kw.update(kwargs) return cls(**kw) @@ -589,17 +572,17 @@ class ClockSignal(Value): Parameters ---------- - cd : str - Clock domain to obtain a clock signal for. Defaults to `"sys"`. + domain : str + Clock domain to obtain a clock signal for. Defaults to `"sync"`. """ - def __init__(self, cd="sys"): + def __init__(self, domain="sync"): super().__init__() - if not isinstance(cd, str): - raise TypeError("Clock domain name must be a string, not {!r}".format(cd)) - self.cd = cd + if not isinstance(domain, str): + raise TypeError("Clock domain name must be a string, not {!r}".format(domain)) + self.domain = domain def __repr__(self): - return "(clk {})".format(self.cd) + return "(clk {})".format(self.domain) class ResetSignal(Value): @@ -610,17 +593,17 @@ class ResetSignal(Value): Parameters ---------- - cd : str - Clock domain to obtain a reset signal for. Defaults to `"sys"`. + domain : str + Clock domain to obtain a reset signal for. Defaults to `"sync"`. """ - def __init__(self, cd="sys"): + def __init__(self, domain="sync"): super().__init__() - if not isinstance(cd, str): - raise TypeError("Clock domain name must be a string, not {!r}".format(cd)) - self.cd = cd + if not isinstance(domain, str): + raise TypeError("Clock domain name must be a string, not {!r}".format(domain)) + self.domain = domain def __repr__(self): - return "(rst {})".format(self.cd) + return "(reset {})".format(self.domain) class Statement: diff --git a/nmigen/test/__init__.py b/nmigen/test/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/nmigen/test/test_fhdl.py b/nmigen/test/test_fhdl.py new file mode 100644 index 0000000..29bde49 --- /dev/null +++ b/nmigen/test/test_fhdl.py @@ -0,0 +1,358 @@ +import unittest + +from nmigen.fhdl.ast import * + + +class ValueTestCase(unittest.TestCase): + def test_wrap(self): + self.assertIsInstance(Value.wrap(0), Const) + self.assertIsInstance(Value.wrap(True), Const) + c = Const(0) + self.assertIs(Value.wrap(c), c) + with self.assertRaises(TypeError): + Value.wrap("str") + + def test_bool(self): + with self.assertRaises(TypeError): + if Const(0): + pass + + def test_len(self): + self.assertEqual(len(Const(10)), 4) + + def test_getitem_int(self): + s1 = Const(10)[0] + self.assertIsInstance(s1, Slice) + self.assertEqual(s1.start, 0) + self.assertEqual(s1.end, 1) + s2 = Const(10)[-1] + self.assertIsInstance(s2, Slice) + self.assertEqual(s2.start, 3) + self.assertEqual(s2.end, 4) + with self.assertRaises(IndexError): + Const(10)[5] + + def test_getitem_slice(self): + s1 = Const(10)[1:3] + self.assertIsInstance(s1, Slice) + self.assertEqual(s1.start, 1) + self.assertEqual(s1.end, 3) + s2 = Const(10)[1:-2] + self.assertIsInstance(s2, Slice) + self.assertEqual(s2.start, 1) + self.assertEqual(s2.end, 2) + s3 = Const(31)[::2] + self.assertIsInstance(s3, Cat) + self.assertIsInstance(s3.operands[0], Slice) + self.assertEqual(s3.operands[0].start, 0) + self.assertEqual(s3.operands[0].end, 1) + self.assertIsInstance(s3.operands[1], Slice) + self.assertEqual(s3.operands[1].start, 2) + self.assertEqual(s3.operands[1].end, 3) + self.assertIsInstance(s3.operands[2], Slice) + self.assertEqual(s3.operands[2].start, 4) + self.assertEqual(s3.operands[2].end, 5) + + def test_getitem_wrong(self): + with self.assertRaises(TypeError): + Const(31)["str"] + + +class ConstTestCase(unittest.TestCase): + def test_bits_sign(self): + self.assertEqual(Const(0).bits_sign(), (0, False)) + self.assertEqual(Const(1).bits_sign(), (1, False)) + self.assertEqual(Const(10).bits_sign(), (4, False)) + self.assertEqual(Const(-10).bits_sign(), (4, True)) + + self.assertEqual(Const(1, 4).bits_sign(), (4, False)) + self.assertEqual(Const(1, (4, True)).bits_sign(), (4, True)) + + with self.assertRaises(TypeError): + Const(1, -1) + + def test_value(self): + self.assertEqual(Const(10).value, 10) + + def test_repr(self): + self.assertEqual(repr(Const(10)), "(const 4'd10)") + self.assertEqual(repr(Const(-10)), "(const 4'sd-10)") + + def test_hash(self): + with self.assertRaises(TypeError): + hash(Const(0)) + + +class OperatorTestCase(unittest.TestCase): + def test_invert(self): + v = ~Const(0, 4) + self.assertEqual(repr(v), "(~ (const 4'd0))") + self.assertEqual(v.bits_sign(), (4, False)) + + def test_neg(self): + v1 = -Const(0, (4, False)) + self.assertEqual(repr(v1), "(- (const 4'd0))") + self.assertEqual(v1.bits_sign(), (5, True)) + v2 = -Const(0, (4, True)) + self.assertEqual(repr(v2), "(- (const 4'sd0))") + self.assertEqual(v2.bits_sign(), (4, True)) + + def test_add(self): + v1 = Const(0, (4, False)) + Const(0, (6, False)) + self.assertEqual(repr(v1), "(+ (const 4'd0) (const 6'd0))") + self.assertEqual(v1.bits_sign(), (7, False)) + v2 = Const(0, (4, True)) + Const(0, (6, True)) + self.assertEqual(v2.bits_sign(), (7, True)) + v3 = Const(0, (4, True)) + Const(0, (4, False)) + self.assertEqual(v3.bits_sign(), (6, True)) + v4 = Const(0, (4, False)) + Const(0, (4, True)) + self.assertEqual(v4.bits_sign(), (6, True)) + v5 = 10 + Const(0, 4) + self.assertEqual(v5.bits_sign(), (5, False)) + + def test_sub(self): + v1 = Const(0, (4, False)) - Const(0, (6, False)) + self.assertEqual(repr(v1), "(- (const 4'd0) (const 6'd0))") + self.assertEqual(v1.bits_sign(), (7, False)) + v2 = Const(0, (4, True)) - Const(0, (6, True)) + self.assertEqual(v2.bits_sign(), (7, True)) + v3 = Const(0, (4, True)) - Const(0, (4, False)) + self.assertEqual(v3.bits_sign(), (6, True)) + v4 = Const(0, (4, False)) - Const(0, (4, True)) + self.assertEqual(v4.bits_sign(), (6, True)) + v5 = 10 - Const(0, 4) + self.assertEqual(v5.bits_sign(), (5, False)) + + def test_mul(self): + v1 = Const(0, (4, False)) * Const(0, (6, False)) + self.assertEqual(repr(v1), "(* (const 4'd0) (const 6'd0))") + self.assertEqual(v1.bits_sign(), (10, False)) + v2 = Const(0, (4, True)) * Const(0, (6, True)) + self.assertEqual(v2.bits_sign(), (9, True)) + v3 = Const(0, (4, True)) * Const(0, (4, False)) + self.assertEqual(v3.bits_sign(), (8, True)) + v5 = 10 * Const(0, 4) + self.assertEqual(v5.bits_sign(), (8, False)) + + def test_and(self): + v1 = Const(0, (4, False)) & Const(0, (6, False)) + self.assertEqual(repr(v1), "(& (const 4'd0) (const 6'd0))") + self.assertEqual(v1.bits_sign(), (6, False)) + v2 = Const(0, (4, True)) & Const(0, (6, True)) + self.assertEqual(v2.bits_sign(), (6, True)) + v3 = Const(0, (4, True)) & Const(0, (4, False)) + self.assertEqual(v3.bits_sign(), (5, True)) + v4 = Const(0, (4, False)) & Const(0, (4, True)) + self.assertEqual(v4.bits_sign(), (5, True)) + v5 = 10 & Const(0, 4) + self.assertEqual(v5.bits_sign(), (4, False)) + + def test_or(self): + v1 = Const(0, (4, False)) | Const(0, (6, False)) + self.assertEqual(repr(v1), "(| (const 4'd0) (const 6'd0))") + self.assertEqual(v1.bits_sign(), (6, False)) + v2 = Const(0, (4, True)) | Const(0, (6, True)) + self.assertEqual(v2.bits_sign(), (6, True)) + v3 = Const(0, (4, True)) | Const(0, (4, False)) + self.assertEqual(v3.bits_sign(), (5, True)) + v4 = Const(0, (4, False)) | Const(0, (4, True)) + self.assertEqual(v4.bits_sign(), (5, True)) + v5 = 10 | Const(0, 4) + self.assertEqual(v5.bits_sign(), (4, False)) + + def test_xor(self): + v1 = Const(0, (4, False)) ^ Const(0, (6, False)) + self.assertEqual(repr(v1), "(^ (const 4'd0) (const 6'd0))") + self.assertEqual(v1.bits_sign(), (6, False)) + v2 = Const(0, (4, True)) ^ Const(0, (6, True)) + self.assertEqual(v2.bits_sign(), (6, True)) + v3 = Const(0, (4, True)) ^ Const(0, (4, False)) + self.assertEqual(v3.bits_sign(), (5, True)) + v4 = Const(0, (4, False)) ^ Const(0, (4, True)) + self.assertEqual(v4.bits_sign(), (5, True)) + v5 = 10 ^ Const(0, 4) + self.assertEqual(v5.bits_sign(), (4, False)) + + def test_lt(self): + v = Const(0, 4) < Const(0, 6) + self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))") + self.assertEqual(v.bits_sign(), (1, False)) + + def test_le(self): + v = Const(0, 4) <= Const(0, 6) + self.assertEqual(repr(v), "(<= (const 4'd0) (const 6'd0))") + self.assertEqual(v.bits_sign(), (1, False)) + + def test_gt(self): + v = Const(0, 4) > Const(0, 6) + self.assertEqual(repr(v), "(> (const 4'd0) (const 6'd0))") + self.assertEqual(v.bits_sign(), (1, False)) + + def test_ge(self): + v = Const(0, 4) >= Const(0, 6) + self.assertEqual(repr(v), "(>= (const 4'd0) (const 6'd0))") + self.assertEqual(v.bits_sign(), (1, False)) + + def test_eq(self): + v = Const(0, 4) == Const(0, 6) + self.assertEqual(repr(v), "(== (const 4'd0) (const 6'd0))") + self.assertEqual(v.bits_sign(), (1, False)) + + def test_ne(self): + v = Const(0, 4) != Const(0, 6) + self.assertEqual(repr(v), "(!= (const 4'd0) (const 6'd0))") + self.assertEqual(v.bits_sign(), (1, False)) + + def test_mux(self): + s = Const(0) + v1 = Mux(s, Const(0, (4, False)), Const(0, (6, False))) + self.assertEqual(repr(v1), "(m (const 0'd0) (const 4'd0) (const 6'd0))") + self.assertEqual(v1.bits_sign(), (6, False)) + v2 = Mux(s, Const(0, (4, True)), Const(0, (6, True))) + self.assertEqual(v2.bits_sign(), (6, True)) + v3 = Mux(s, Const(0, (4, True)), Const(0, (4, False))) + self.assertEqual(v3.bits_sign(), (5, True)) + v4 = Mux(s, Const(0, (4, False)), Const(0, (4, True))) + self.assertEqual(v4.bits_sign(), (5, True)) + + def test_bool(self): + v = Const(0).bool() + self.assertEqual(repr(v), "(b (const 0'd0))") + self.assertEqual(v.bits_sign(), (1, False)) + + def test_hash(self): + with self.assertRaises(TypeError): + hash(Const(0) + Const(0)) + + +class SliceTestCase(unittest.TestCase): + def test_bits_sign(self): + s1 = Const(10)[2] + self.assertEqual(s1.bits_sign(), (1, False)) + s2 = Const(-10)[0:2] + self.assertEqual(s2.bits_sign(), (2, False)) + + def test_repr(self): + s1 = Const(10)[2] + self.assertEqual(repr(s1), "(slice (const 4'd10) 2:3)") + + +class CatTestCase(unittest.TestCase): + def test_bits_sign(self): + c1 = Cat(Const(10)) + self.assertEqual(c1.bits_sign(), (4, False)) + c2 = Cat(Const(10), Const(1)) + self.assertEqual(c2.bits_sign(), (5, False)) + c3 = Cat(Const(10), Const(1), Const(0)) + self.assertEqual(c3.bits_sign(), (5, False)) + + def test_repr(self): + c1 = Cat(Const(10), Const(1)) + self.assertEqual(repr(c1), "(cat (const 4'd10) (const 1'd1))") + + +class ReplTestCase(unittest.TestCase): + def test_bits_sign(self): + r1 = Repl(Const(10), 3) + self.assertEqual(r1.bits_sign(), (12, False)) + + def test_count_wrong(self): + with self.assertRaises(TypeError): + Repl(Const(10), -1) + with self.assertRaises(TypeError): + Repl(Const(10), "str") + + def test_repr(self): + r1 = Repl(Const(10), 3) + self.assertEqual(repr(r1), "(repl (const 4'd10) 3)") + + +class SignalTestCase(unittest.TestCase): + def test_bits_sign(self): + s1 = Signal() + self.assertEqual(s1.bits_sign(), (1, False)) + s2 = Signal(2) + self.assertEqual(s2.bits_sign(), (2, False)) + s3 = Signal((2, False)) + self.assertEqual(s3.bits_sign(), (2, False)) + s4 = Signal((2, True)) + self.assertEqual(s4.bits_sign(), (2, True)) + s5 = Signal(max=16) + self.assertEqual(s5.bits_sign(), (4, False)) + s6 = Signal(min=4, max=16) + self.assertEqual(s6.bits_sign(), (4, False)) + s7 = Signal(min=-4, max=16) + self.assertEqual(s7.bits_sign(), (5, True)) + s8 = Signal(min=-20, max=16) + self.assertEqual(s8.bits_sign(), (6, True)) + + with self.assertRaises(ValueError): + Signal(min=10, max=4) + with self.assertRaises(ValueError): + Signal(2, min=10) + with self.assertRaises(TypeError): + Signal(-10) + + def test_name(self): + s1 = Signal() + self.assertEqual(s1.name, "s1") + s2 = Signal(name="sig") + self.assertEqual(s2.name, "sig") + + def test_reset(self): + s1 = Signal(4, reset=0b111, reset_less=True) + self.assertEqual(s1.reset, 0b111) + self.assertEqual(s1.reset_less, True) + + def test_attrs(self): + s1 = Signal() + self.assertEqual(s1.attrs, {}) + s2 = Signal(attrs={"no_retiming": True}) + self.assertEqual(s2.attrs, {"no_retiming": True}) + + def test_repr(self): + s1 = Signal() + self.assertEqual(repr(s1), "(sig s1)") + + def test_like(self): + s1 = Signal.like(Signal(4)) + self.assertEqual(s1.bits_sign(), (4, False)) + s2 = Signal.like(Signal(min=-15)) + self.assertEqual(s2.bits_sign(), (5, True)) + s3 = Signal.like(Signal(4, reset=0b111, reset_less=True)) + self.assertEqual(s3.reset, 0b111) + self.assertEqual(s3.reset_less, True) + s4 = Signal.like(Signal(attrs={"no_retiming": True})) + self.assertEqual(s4.attrs, {"no_retiming": True}) + s5 = Signal.like(10) + self.assertEqual(s5.bits_sign(), (4, False)) + + +class ClockSignalTestCase(unittest.TestCase): + def test_domain(self): + s1 = ClockSignal() + self.assertEqual(s1.domain, "sync") + s2 = ClockSignal("pix") + self.assertEqual(s2.domain, "pix") + + with self.assertRaises(TypeError): + ClockSignal(1) + + def test_repr(self): + s1 = ClockSignal() + self.assertEqual(repr(s1), "(clk sync)") + + +class ResetSignalTestCase(unittest.TestCase): + def test_domain(self): + s1 = ResetSignal() + self.assertEqual(s1.domain, "sync") + s2 = ResetSignal("pix") + self.assertEqual(s2.domain, "pix") + + with self.assertRaises(TypeError): + ResetSignal(1) + + def test_repr(self): + s1 = ResetSignal() + self.assertEqual(repr(s1), "(reset sync)") -- 2.30.2