if self.operator in ("&", "^", "|"):
return _bitwise_binary_shape(*op_shapes)
if self.operator == "<<":
- if b_signed:
- extra = 2 ** (b_width - 1) - 1
- else:
- extra = 2 ** (b_width) - 1
- return Shape(a_width + extra, a_signed)
+ assert not b_signed
+ return Shape(a_width + 2 ** b_width - 1, a_signed)
if self.operator == ">>":
- if b_signed:
- extra = 2 ** (b_width - 1)
- else:
- extra = 0
- return Shape(a_width + extra, a_signed)
+ assert not b_signed
+ return Shape(a_width, a_signed)
elif len(op_shapes) == 3:
if self.operator == "m":
s_shape, a_shape, b_shape = op_shapes