aco: implement 8-bit/16-bit reductions
[mesa.git] / src / amd / compiler / aco_lower_to_hw_instr.cpp
index 765a7f63a983afca282749e890adb3e7eeb10fd6..1d3061d5dd94b539c48d0f08370d6db032c07369 100644 (file)
@@ -43,6 +43,22 @@ struct lower_context {
 
 aco_opcode get_reduce_opcode(chip_class chip, ReduceOp op) {
    switch (op) {
+   case iadd8:
+   case iadd16: return aco_opcode::v_add_u16;
+   case imul8:
+   case imul16: return aco_opcode::v_mul_lo_u16;
+   case fadd16: return aco_opcode::v_add_f16;
+   case fmul16: return aco_opcode::v_mul_f16;
+   case imax8:
+   case imax16: return aco_opcode::v_max_i16;
+   case imin8:
+   case imin16: return aco_opcode::v_min_i16;
+   case umin8:
+   case umin16: return aco_opcode::v_min_u16;
+   case umax8:
+   case umax16: return aco_opcode::v_max_u16;
+   case fmin16: return aco_opcode::v_min_f16;
+   case fmax16: return aco_opcode::v_max_f16;
    case iadd32: return chip >= GFX9 ? aco_opcode::v_add_u32 : aco_opcode::v_add_co_u32;
    case imul32: return aco_opcode::v_mul_lo_u32;
    case fadd32: return aco_opcode::v_add_f32;
@@ -53,8 +69,14 @@ aco_opcode get_reduce_opcode(chip_class chip, ReduceOp op) {
    case umax32: return aco_opcode::v_max_u32;
    case fmin32: return aco_opcode::v_min_f32;
    case fmax32: return aco_opcode::v_max_f32;
+   case iand8:
+   case iand16:
    case iand32: return aco_opcode::v_and_b32;
+   case ixor8:
+   case ixor16:
    case ixor32: return aco_opcode::v_xor_b32;
+   case ior8:
+   case ior16:
    case ior32: return aco_opcode::v_or_b32;
    case iadd64: return aco_opcode::num_opcodes;
    case imul64: return aco_opcode::num_opcodes;
@@ -363,41 +385,71 @@ void emit_dpp_mov(lower_context *ctx, PhysReg dst, PhysReg src0, unsigned size,
 uint32_t get_reduction_identity(ReduceOp op, unsigned idx)
 {
    switch (op) {
+   case iadd8:
+   case iadd16:
    case iadd32:
    case iadd64:
+   case fadd16:
    case fadd32:
    case fadd64:
+   case ior8:
+   case ior16:
    case ior32:
    case ior64:
+   case ixor8:
+   case ixor16:
    case ixor32:
    case ixor64:
+   case umax8:
+   case umax16:
    case umax32:
    case umax64:
       return 0;
+   case imul8:
+   case imul16:
    case imul32:
    case imul64:
       return idx ? 0 : 1;
+   case fmul16:
+      return 0x3c00u; /* 1.0 */
    case fmul32:
       return 0x3f800000u; /* 1.0 */
    case fmul64:
       return idx ? 0x3ff00000u : 0u; /* 1.0 */
+   case imin8:
+      return INT8_MAX;
+   case imin16:
+      return INT16_MAX;
    case imin32:
       return INT32_MAX;
    case imin64:
       return idx ? 0x7fffffffu : 0xffffffffu;
+   case imax8:
+      return INT8_MIN;
+   case imax16:
+      return INT16_MIN;
    case imax32:
       return INT32_MIN;
    case imax64:
       return idx ? 0x80000000u : 0;
+   case umin8:
+   case umin16:
+   case iand8:
+   case iand16:
+      return 0xffffffffu;
    case umin32:
    case umin64:
    case iand32:
    case iand64:
       return 0xffffffffu;
+   case fmin16:
+      return 0x7c00u; /* infinity */
    case fmin32:
       return 0x7f800000u; /* infinity */
    case fmin64:
       return idx ? 0x7ff00000u : 0u; /* infinity */
+   case fmax16:
+      return 0xfc00u; /* negative infinity */
    case fmax32:
       return 0xff800000u; /* negative infinity */
    case fmax64: