a0 = a[7..0], a1 = a[15..8], ....
b0 = b[7..0], b1 = b[15..8], ....
-then, we compute the following matrix:
+then, we compute the following matrix, with the first column output being the full width (32 bit), the second being only 24 bit, the third only 16 bit and finally the top part (comprising the most significant byte of a and b as input) being only 8 bit
| a0 << b0 | a1 << b0 | a2 << b0 | a3 << b0
| a0 << b1 | a1 << b1 | a2 << b1 | a3 << b1