@staticmethod
def simd_mul(a: int, b: int, lanes: List[SIMDMulLane]) -> Tuple[int, int]:
    """Multiply two packed integers lane by lane.

    ``a`` and ``b`` are treated as words made of consecutive bit fields.
    The first entry of ``lanes`` covers the least-significant
    ``bit_width`` bits, the next entry the bits directly above, and so on.

    Each lane object is read for ``bit_width``, ``a_signed``, ``b_signed``
    and ``high_half``.

    Returns a ``(output, intermediate_output)`` tuple:

    * ``output`` packs one ``bit_width``-wide result per lane at the same
      bit offset as the lane's inputs: the upper half of the product when
      ``high_half`` is true, otherwise the lower half.
    * ``intermediate_output`` packs every lane's full
      ``2 * bit_width``-bit product at twice the lane's input offset.
    """
    packed_result = 0
    packed_products = 0
    offset = 0
    for lane in lanes:
        width = lane.bit_width
        field_mask = (1 << width) - 1
        top_bit = 1 << (width - 1)

        # An operand is sign-extended when its lane flag asks for it, and
        # always when only the low half of the product is requested.
        factors = []
        for word, flagged_signed in ((a, lane.a_signed), (b, lane.b_signed)):
            field = (word >> offset) & field_mask
            if (flagged_signed or not lane.high_half) and field & top_bit:
                field -= 1 << width
            factors.append(field)

        # Wrap the (possibly negative) product to exactly 2 * width bits.
        product = (factors[0] * factors[1]) & ((1 << (2 * width)) - 1)
        packed_products |= product << (2 * offset)

        chosen_half = product >> width if lane.high_half else product
        packed_result |= (chosen_half & field_mask) << offset
        offset += width
    return packed_result, packed_products