From 991f235645ef88c3be1bee573e2e01f32a68811c Mon Sep 17 00:00:00 2001 From: whitequark Date: Sat, 15 Dec 2018 09:58:30 +0000 Subject: [PATCH] fhdl.ast, back.pysim: implement shifts. --- nmigen/back/pysim.py | 18 +++++++++++++----- nmigen/back/rtlil.py | 4 ++-- nmigen/fhdl/ast.py | 18 +++++++++--------- nmigen/test/test_fhdl_ast.py | 14 ++++++++++++++ nmigen/test/test_sim.py | 12 ++++++++++++ 5 files changed, 50 insertions(+), 16 deletions(-) diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index 6a15b7d..69ddf12 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -75,15 +75,23 @@ class _RHSValueCompiler(ValueTransformer): elif len(value.operands) == 2: lhs, rhs = map(self, value.operands) if value.op == "+": - return lambda state: normalize(lhs(state) + rhs(state), shape) + return lambda state: normalize(lhs(state) + rhs(state), shape) if value.op == "-": - return lambda state: normalize(lhs(state) - rhs(state), shape) + return lambda state: normalize(lhs(state) - rhs(state), shape) if value.op == "&": - return lambda state: normalize(lhs(state) & rhs(state), shape) + return lambda state: normalize(lhs(state) & rhs(state), shape) if value.op == "|": - return lambda state: normalize(lhs(state) | rhs(state), shape) + return lambda state: normalize(lhs(state) | rhs(state), shape) if value.op == "^": - return lambda state: normalize(lhs(state) ^ rhs(state), shape) + return lambda state: normalize(lhs(state) ^ rhs(state), shape) + if value.op == "<<": + def sshl(lhs, rhs): + return lhs << rhs if rhs >= 0 else lhs >> -rhs + return lambda state: normalize(sshl(lhs(state), rhs(state)), shape) + if value.op == ">>": + def sshr(lhs, rhs): + return lhs >> rhs if rhs >= 0 else lhs << -rhs + return lambda state: normalize(sshr(lhs(state), rhs(state)), shape) if value.op == "==": return lambda state: normalize(lhs(state) == rhs(state), shape) if value.op == "!=": diff --git a/nmigen/back/rtlil.py b/nmigen/back/rtlil.py index a8dd7d3..fc7420f 100644 --- a/nmigen/back/rtlil.py +++ b/nmigen/back/rtlil.py @@ -206,8 +206,8 @@ class _ValueTransformer(xfrm.ValueTransformer): (2, "/"): "$div", (2, "%"): "$mod", (2, "**"): "$pow", - (2, "<<<"): "$sshl", - (2, ">>>"): "$sshr", + (2, "<<"): "$sshl", + (2, ">>"): "$sshr", (2, "&"): "$and", (2, "^"): "$xor", (2, "|"): "$or", diff --git a/nmigen/fhdl/ast.py b/nmigen/fhdl/ast.py index d878968..c8c13ae 100644 --- a/nmigen/fhdl/ast.py +++ b/nmigen/fhdl/ast.py @@ -75,13 +75,13 @@ class Value(metaclass=ABCMeta): def __rdiv__(self, other): return Operator("/", [other, self]) def __lshift__(self, other): - return Operator("<<<", [self, other]) + return Operator("<<", [self, other]) def __rlshift__(self, other): - return Operator("<<<", [other, self]) + return Operator("<<", [other, self]) def __rshift__(self, other): - return Operator(">>>", [self, other]) + return Operator(">>", [self, other]) def __rrshift__(self, other): - return Operator(">>>", [other, self]) + return Operator(">>", [other, self]) def __and__(self, other): return Operator("&", [self, other]) def __rand__(self, other): @@ -306,15 +306,15 @@ class Operator(Value): return 1, False if self.op in ("&", "^", "|"): return self._bitwise_binary_shape(*op_shapes) - if self.op == "<<<": + if self.op == "<<": if b_sign: - extra = 2**(b_bits - 1) - 1 + extra = 2 ** (b_bits - 1) - 1 else: - extra = 2**b_bits - 1 + extra = 2 ** (b_bits) - 1 return a_bits + extra, a_sign - if self.op == ">>>": + if self.op == ">>": if b_sign: - extra = 2**(b_bits - 1) + extra = 2 ** (b_bits - 1) else: extra = 0 return a_bits + extra, a_sign diff --git a/nmigen/test/test_fhdl_ast.py b/nmigen/test/test_fhdl_ast.py index 433b609..fd1d58f 100644 --- a/nmigen/test/test_fhdl_ast.py +++ b/nmigen/test/test_fhdl_ast.py @@ -182,6 +182,20 @@ class OperatorTestCase(FHDLTestCase): v5 = 10 ^ Const(0, 4) self.assertEqual(v5.shape(), (4, False)) + def test_shl(self): + v1 = Const(1, 4) << Const(4) + self.assertEqual(repr(v1), "(<< (const 4'd1) (const 3'd4))") + self.assertEqual(v1.shape(), (11, False)) + v2 = Const(1, 4) << Const(-3) + self.assertEqual(v2.shape(), (7, False)) + + def test_shr(self): + v1 = Const(1, 4) >> Const(4) + self.assertEqual(repr(v1), "(>> (const 4'd1) (const 3'd4))") + self.assertEqual(v1.shape(), (4, False)) + v2 = Const(1, 4) >> Const(-3) + self.assertEqual(v2.shape(), (8, False)) + def test_lt(self): v = Const(0, 4) < Const(0, 6) self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))") diff --git a/nmigen/test/test_sim.py b/nmigen/test/test_sim.py index ec41986..ef7b18f 100644 --- a/nmigen/test/test_sim.py +++ b/nmigen/test/test_sim.py @@ -71,6 +71,18 @@ class SimulatorUnitTestCase(FHDLTestCase): stmt = lambda a, b: a ^ b self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4)) + def test_shl(self): + stmt = lambda a, b: a << b + self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 5)) + self.assertOperator(stmt, [C(0b1001, 4), C(3)], C(0b1001000, 7)) + self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b10, 7)) + + def test_shr(self): + stmt = lambda a, b: a >> b + self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 4)) + self.assertOperator(stmt, [C(0b1001, 4), C(2)], C(0b10, 4)) + self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b100100, 5)) + def test_eq(self): stmt = lambda a, b: a == b self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1)) -- 2.30.2