From: whitequark Date: Sat, 1 Feb 2020 23:04:25 +0000 (+0000) Subject: hdl.ast: prohibit shifts by signed value. X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=a055eb897f1685fdc33aa996250c072de100121c;p=nmigen.git hdl.ast: prohibit shifts by signed value. These are not desirable in a HDL, and currently elaborate to broken RTLIL (after YosysHQ/yosys#1551); prohibit them completely, like we already do for division and modulo. Fixes #302. --- diff --git a/nmigen/back/pysim.py b/nmigen/back/pysim.py index b0449fe..6bf4b7c 100644 --- a/nmigen/back/pysim.py +++ b/nmigen/back/pysim.py @@ -347,8 +347,6 @@ class _ValueCompiler(ValueVisitor, _Compiler): helpers = { "sign": lambda value, sign: value | sign if value & sign else value, "zdiv": lambda lhs, rhs: 0 if rhs == 0 else lhs // rhs, - "sshl": lambda lhs, rhs: lhs << rhs if rhs >= 0 else lhs >> -rhs, - "sshr": lambda lhs, rhs: lhs >> rhs if rhs >= 0 else lhs << -rhs, } def on_ClockSignal(self, value): @@ -438,9 +436,9 @@ class _RHSValueCompiler(_ValueCompiler): if value.operator == "^": return f"({self(lhs)} ^ {self(rhs)})" if value.operator == "<<": - return f"sshl({sign(lhs)}, {sign(rhs)})" + return f"({sign(lhs)} << {sign(rhs)})" if value.operator == ">>": - return f"sshr({sign(lhs)}, {sign(rhs)})" + return f"({sign(lhs)} >> {sign(rhs)})" if value.operator == "==": return f"({sign(lhs)} == {sign(rhs)})" if value.operator == "!=": diff --git a/nmigen/hdl/ast.py b/nmigen/hdl/ast.py index efaeb0f..4bec9b3 100644 --- a/nmigen/hdl/ast.py +++ b/nmigen/hdl/ast.py @@ -172,14 +172,28 @@ class Value(metaclass=ABCMeta): self.__check_divisor() return Operator("//", [other, self]) + def __check_shamt(self): + width, signed = self.shape() + if signed: + # Neither Python nor HDLs implement shifts by negative values; prohibit any shifts + # by a signed value to make sure the shift amount can always be interpreted as + # an unsigned value. + raise NotImplementedError("Shift by a signed value is not supported") def __lshift__(self, other): + other = Value.cast(other) + other.__check_shamt() return Operator("<<", [self, other]) def __rlshift__(self, other): + self.__check_shamt() return Operator("<<", [other, self]) def __rshift__(self, other): + other = Value.cast(other) + other.__check_shamt() return Operator(">>", [self, other]) def __rrshift__(self, other): + self.__check_shamt() return Operator(">>", [other, self]) + def __and__(self, other): return Operator("&", [self, other]) def __rand__(self, other): diff --git a/nmigen/test/test_hdl_ast.py b/nmigen/test/test_hdl_ast.py index 960d0c9..76665cc 100644 --- a/nmigen/test/test_hdl_ast.py +++ b/nmigen/test/test_hdl_ast.py @@ -362,15 +362,27 @@ class OperatorTestCase(FHDLTestCase): v1 = Const(1, 4) << Const(4) self.assertEqual(repr(v1), "(<< (const 4'd1) (const 3'd4))") self.assertEqual(v1.shape(), unsigned(11)) - v2 = Const(1, 4) << Const(-3) - self.assertEqual(v2.shape(), unsigned(7)) + + def test_shl_wrong(self): + with self.assertRaises(NotImplementedError, + msg="Shift by a signed value is not supported"): + 1 << Const(0, signed(6)) + with self.assertRaises(NotImplementedError, + msg="Shift by a signed value is not supported"): + Const(1, unsigned(4)) << -1 def test_shr(self): v1 = Const(1, 4) >> Const(4) self.assertEqual(repr(v1), "(>> (const 4'd1) (const 3'd4))") self.assertEqual(v1.shape(), unsigned(4)) - v2 = Const(1, 4) >> Const(-3) - self.assertEqual(v2.shape(), unsigned(8)) + + def test_shr_wrong(self): + with self.assertRaises(NotImplementedError, + msg="Shift by a signed value is not supported"): + 1 << Const(0, signed(6)) + with self.assertRaises(NotImplementedError, + msg="Shift by a signed value is not supported"): + Const(1, unsigned(4)) << -1 def test_lt(self): v = Const(0, 4) < Const(0, 6) diff --git a/nmigen/test/test_sim.py b/nmigen/test/test_sim.py index 93b76c3..7c00072 100644 --- a/nmigen/test/test_sim.py +++ b/nmigen/test/test_sim.py @@ -116,13 +116,11 @@ class SimulatorUnitTestCase(FHDLTestCase): stmt = lambda y, a, b: y.eq(a << b) self.assertStatement(stmt, [C(0b1001, 4), C(0)], C(0b1001, 5)) self.assertStatement(stmt, [C(0b1001, 4), C(3)], C(0b1001000, 7)) - self.assertStatement(stmt, [C(0b1001, 4), C(-2)], C(0b10, 7)) def test_shr(self): stmt = lambda y, a, b: y.eq(a >> b) self.assertStatement(stmt, [C(0b1001, 4), C(0)], C(0b1001, 4)) self.assertStatement(stmt, [C(0b1001, 4), C(2)], C(0b10, 4)) - self.assertStatement(stmt, [C(0b1001, 4), C(-2)], C(0b100100, 5)) def test_eq(self): stmt = lambda y, a, b: y.eq(a == b)