self.operands = [Value.wrap(o) for o in operands]
@staticmethod
- def _bitwise_binary_shape(a, b):
- if not a[1] and not b[1]:
+ def _bitwise_binary_shape(a_shape, b_shape):
+ a_bits, a_sign = a_shape
+ b_bits, b_sign = b_shape
+ if not a_sign and not b_sign:
# both operands unsigned
- return max(a[0], b[0]), False
- elif a[1] and b[1]:
+ return max(a_bits, b_bits), False
+ elif a_sign and b_sign:
# both operands signed
- return max(a[0], b[0]), True
- elif not a[1] and b[1]:
+ return max(a_bits, b_bits), True
+ elif not a_sign and b_sign:
# first operand unsigned (add sign bit), second operand signed
- return max(a[0] + 1, b[0]), True
+ return max(a_bits + 1, b_bits), True
else:
# first signed, second operand unsigned (add sign bit)
- return max(a[0], b[0] + 1), True
+ return max(a_bits, b_bits + 1), True
def shape(self):
- obs = list(map(lambda x: x.shape(), self.operands))
- if self.op == "+" or self.op == "-":
- if len(obs) == 1:
- if self.op == "-" and not obs[0][1]:
- return obs[0][0] + 1, True
+ op_shapes = list(map(lambda x: x.shape(), self.operands))
+ if len(op_shapes) == 1:
+ (a_bits, a_sign), = op_shapes
+ if self.op in ("+", "~"):
+ return a_bits, a_sign
+ if self.op == "-":
+ if not a_sign:
+ return a_bits + 1, True
else:
- return obs[0]
- n, s = self._bitwise_binary_shape(*obs)
- return n + 1, s
- elif self.op == "*":
- if not obs[0][1] and not obs[1][1]:
- # both operands unsigned
- return obs[0][0] + obs[1][0], False
- elif obs[0][1] and obs[1][1]:
- # both operands signed
- return obs[0][0] + obs[1][0] - 1, True
- else:
+ return a_bits, a_sign
+ if self.op == "b":
+ return 1, False
+ elif len(op_shapes) == 2:
+ (a_bits, a_sign), (b_bits, b_sign) = op_shapes
+ if self.op == "+" or self.op == "-":
+ bits, sign = self._bitwise_binary_shape(*op_shapes)
+ return bits + 1, sign
+ if self.op == "*":
+ if not a_sign and not b_sign:
+ # both operands unsigned
+ return a_bits + b_bits, False
+ if a_sign and b_sign:
+ # both operands signed
+ return a_bits + b_bits - 1, True
# one operand signed, the other unsigned (add sign bit)
- return obs[0][0] + obs[1][0] + 1 - 1, True
- elif self.op == "<<<":
- if obs[1][1]:
- extra = 2**(obs[1][0] - 1) - 1
- else:
- extra = 2**obs[1][0] - 1
- return obs[0][0] + extra, obs[0][1]
- elif self.op == ">>>":
- if obs[1][1]:
- extra = 2**(obs[1][0] - 1)
- else:
- extra = 0
- return obs[0][0] + extra, obs[0][1]
- elif self.op in ("&", "^", "|"):
- return self._bitwise_binary_shape(*obs)
- elif self.op in ("<", "<=", "==", "!=", ">", ">=", "b"):
- return 1, False
- elif self.op == "~":
- return obs[0]
- elif self.op == "m":
- return self._bitwise_binary_shape(obs[1], obs[2])
- raise NotImplementedError("Operator '{}' not implemented".format(self.op)) # :nocov:
+ return a_bits + b_bits + 1 - 1, True
+ if self.op in ("<", "<=", "==", "!=", ">", ">=", "b"):
+ return 1, False
+ if self.op in ("&", "^", "|"):
+ return self._bitwise_binary_shape(*op_shapes)
+ if self.op == "<<<":
+ if b_sign:
+ extra = 2**(b_bits - 1) - 1
+ else:
+ extra = 2**b_bits - 1
+ return a_bits + extra, a_sign
+ if self.op == ">>>":
+ if b_sign:
+ extra = 2**(b_bits - 1)
+ else:
+ extra = 0
+ return a_bits + extra, a_sign
+ elif len(op_shapes) == 3:
+ if self.op == "m":
+ s_shape, a_shape, b_shape = op_shapes
+ return self._bitwise_binary_shape(a_shape, b_shape)
+ raise NotImplementedError("Operator {}/{} not implemented"
+ .format(self.op, len(op_shapes))) # :nocov:
def _rhs_signals(self):
return union(op._rhs_signals() for op in self.operands)