elif len(value.operands) == 2:
lhs, rhs = map(self, value.operands)
if value.op == "+":
- return lambda state: normalize(lhs(state) + rhs(state), shape)
+ return lambda state: normalize(lhs(state) + rhs(state), shape)
if value.op == "-":
- return lambda state: normalize(lhs(state) - rhs(state), shape)
+ return lambda state: normalize(lhs(state) - rhs(state), shape)
if value.op == "&":
- return lambda state: normalize(lhs(state) & rhs(state), shape)
+ return lambda state: normalize(lhs(state) & rhs(state), shape)
if value.op == "|":
- return lambda state: normalize(lhs(state) | rhs(state), shape)
+ return lambda state: normalize(lhs(state) | rhs(state), shape)
if value.op == "^":
- return lambda state: normalize(lhs(state) ^ rhs(state), shape)
+ return lambda state: normalize(lhs(state) ^ rhs(state), shape)
+ if value.op == "<<":
+ def sshl(lhs, rhs):
+ return lhs << rhs if rhs >= 0 else lhs >> -rhs
+ return lambda state: normalize(sshl(lhs(state), rhs(state)), shape)
+ if value.op == ">>":
+ def sshr(lhs, rhs):
+ return lhs >> rhs if rhs >= 0 else lhs << -rhs
+ return lambda state: normalize(sshr(lhs(state), rhs(state)), shape)
if value.op == "==":
return lambda state: normalize(lhs(state) == rhs(state), shape)
if value.op == "!=":
(2, "/"): "$div",
(2, "%"): "$mod",
(2, "**"): "$pow",
- (2, "<<<"): "$sshl",
- (2, ">>>"): "$sshr",
+ (2, "<<"): "$sshl",
+ (2, ">>"): "$sshr",
(2, "&"): "$and",
(2, "^"): "$xor",
(2, "|"): "$or",
def __rdiv__(self, other):
return Operator("/", [other, self])
def __lshift__(self, other):
- return Operator("<<<", [self, other])
+ return Operator("<<", [self, other])
def __rlshift__(self, other):
- return Operator("<<<", [other, self])
+ return Operator("<<", [other, self])
def __rshift__(self, other):
- return Operator(">>>", [self, other])
+ return Operator(">>", [self, other])
def __rrshift__(self, other):
- return Operator(">>>", [other, self])
+ return Operator(">>", [other, self])
def __and__(self, other):
return Operator("&", [self, other])
def __rand__(self, other):
return 1, False
if self.op in ("&", "^", "|"):
return self._bitwise_binary_shape(*op_shapes)
- if self.op == "<<<":
+ if self.op == "<<":
if b_sign:
- extra = 2**(b_bits - 1) - 1
+ extra = 2 ** (b_bits - 1) - 1
else:
- extra = 2**b_bits - 1
+ extra = 2 ** (b_bits) - 1
return a_bits + extra, a_sign
- if self.op == ">>>":
+ if self.op == ">>":
if b_sign:
- extra = 2**(b_bits - 1)
+ extra = 2 ** (b_bits - 1)
else:
extra = 0
return a_bits + extra, a_sign
v5 = 10 ^ Const(0, 4)
self.assertEqual(v5.shape(), (4, False))
+ def test_shl(self):
+ v1 = Const(1, 4) << Const(4)
+ self.assertEqual(repr(v1), "(<< (const 4'd1) (const 3'd4))")
+ self.assertEqual(v1.shape(), (11, False))
+ v2 = Const(1, 4) << Const(-3)
+ self.assertEqual(v2.shape(), (7, False))
+
+ def test_shr(self):
+ v1 = Const(1, 4) >> Const(4)
+ self.assertEqual(repr(v1), "(>> (const 4'd1) (const 3'd4))")
+ self.assertEqual(v1.shape(), (4, False))
+ v2 = Const(1, 4) >> Const(-3)
+ self.assertEqual(v2.shape(), (8, False))
+
def test_lt(self):
v = Const(0, 4) < Const(0, 6)
self.assertEqual(repr(v), "(< (const 4'd0) (const 6'd0))")
stmt = lambda a, b: a ^ b
self.assertOperator(stmt, [C(0b1100, 4), C(0b1010, 4)], C(0b0110, 4))
+ def test_shl(self):
+ stmt = lambda a, b: a << b
+ self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 5))
+ self.assertOperator(stmt, [C(0b1001, 4), C(3)], C(0b1001000, 7))
+ self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b10, 7))
+
+ def test_shr(self):
+ stmt = lambda a, b: a >> b
+ self.assertOperator(stmt, [C(0b1001, 4), C(0)], C(0b1001, 4))
+ self.assertOperator(stmt, [C(0b1001, 4), C(2)], C(0b10, 4))
+ self.assertOperator(stmt, [C(0b1001, 4), C(-2)], C(0b100100, 5))
+
def test_eq(self):
stmt = lambda a, b: a == b
self.assertOperator(stmt, [C(0, 4), C(0, 4)], C(1))