nir: Support sysval tess levels in SPIR-V to NIR
[mesa.git] / src / compiler / spirv / vtn_subgroup.c
index ce795ec2cb51c8c61641e7621cf4bfe32841bf29..8d17845c1a0c420b83624b399c484c12dd2745f2 100644 (file)
@@ -183,7 +183,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                                val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
       break;
 
-   case SpvOpGroupNonUniformBroadcast: ++w;
+   case SpvOpGroupNonUniformBroadcast:
+   case SpvOpGroupBroadcast: ++w;
    case SpvOpSubgroupReadInvocationKHR:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
                                val->ssa, vtn_ssa_value(b, w[3]),
@@ -193,6 +194,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformAll:
    case SpvOpGroupNonUniformAny:
    case SpvOpGroupNonUniformAllEqual:
+   case SpvOpGroupAll:
+   case SpvOpGroupAny:
    case SpvOpSubgroupAllKHR:
    case SpvOpSubgroupAnyKHR:
    case SpvOpSubgroupAllEqualKHR: {
@@ -201,22 +204,31 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       nir_intrinsic_op op;
       switch (opcode) {
       case SpvOpGroupNonUniformAll:
+      case SpvOpGroupAll:
       case SpvOpSubgroupAllKHR:
          op = nir_intrinsic_vote_all;
          break;
       case SpvOpGroupNonUniformAny:
+      case SpvOpGroupAny:
       case SpvOpSubgroupAnyKHR:
          op = nir_intrinsic_vote_any;
          break;
+      case SpvOpSubgroupAllEqualKHR:
+         op = nir_intrinsic_vote_ieq;
+         break;
       case SpvOpGroupNonUniformAllEqual:
-      case SpvOpSubgroupAllEqualKHR: {
-         switch (glsl_get_base_type(val->type->type)) {
+         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
          case GLSL_TYPE_FLOAT:
+         case GLSL_TYPE_FLOAT16:
          case GLSL_TYPE_DOUBLE:
             op = nir_intrinsic_vote_feq;
             break;
          case GLSL_TYPE_UINT:
          case GLSL_TYPE_INT:
+         case GLSL_TYPE_UINT8:
+         case GLSL_TYPE_INT8:
+         case GLSL_TYPE_UINT16:
+         case GLSL_TYPE_INT16:
          case GLSL_TYPE_UINT64:
          case GLSL_TYPE_INT64:
          case GLSL_TYPE_BOOL:
@@ -226,14 +238,13 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
             unreachable("Unhandled type");
          }
          break;
-      }
       default:
          unreachable("Unhandled opcode");
       }
 
       nir_ssa_def *src0;
-      if (opcode == SpvOpGroupNonUniformAll ||
-          opcode == SpvOpGroupNonUniformAny ||
+      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
+          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
           opcode == SpvOpGroupNonUniformAllEqual) {
          src0 = vtn_ssa_value(b, w[4])->def;
       } else {
@@ -241,7 +252,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       }
       nir_intrinsic_instr *intrin =
          nir_intrinsic_instr_create(b->nb.shader, op);
-      intrin->num_components = src0->num_components;
+      if (nir_intrinsic_infos[op].src_components[0] == 0)
+         intrin->num_components = src0->num_components;
       intrin->src[0] = nir_src_for_ssa(src0);
       nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
                                  val->type->type, NULL);
@@ -319,13 +331,33 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformBitwiseXor:
    case SpvOpGroupNonUniformLogicalAnd:
    case SpvOpGroupNonUniformLogicalOr:
-   case SpvOpGroupNonUniformLogicalXor: {
+   case SpvOpGroupNonUniformLogicalXor:
+   case SpvOpGroupIAdd:
+   case SpvOpGroupFAdd:
+   case SpvOpGroupFMin:
+   case SpvOpGroupUMin:
+   case SpvOpGroupSMin:
+   case SpvOpGroupFMax:
+   case SpvOpGroupUMax:
+   case SpvOpGroupSMax:
+   case SpvOpGroupIAddNonUniformAMD:
+   case SpvOpGroupFAddNonUniformAMD:
+   case SpvOpGroupFMinNonUniformAMD:
+   case SpvOpGroupUMinNonUniformAMD:
+   case SpvOpGroupSMinNonUniformAMD:
+   case SpvOpGroupFMaxNonUniformAMD:
+   case SpvOpGroupUMaxNonUniformAMD:
+   case SpvOpGroupSMaxNonUniformAMD: {
       nir_op reduction_op;
       switch (opcode) {
       case SpvOpGroupNonUniformIAdd:
+      case SpvOpGroupIAdd:
+      case SpvOpGroupIAddNonUniformAMD:
          reduction_op = nir_op_iadd;
          break;
       case SpvOpGroupNonUniformFAdd:
+      case SpvOpGroupFAdd:
+      case SpvOpGroupFAddNonUniformAMD:
          reduction_op = nir_op_fadd;
          break;
       case SpvOpGroupNonUniformIMul:
@@ -335,21 +367,33 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
          reduction_op = nir_op_fmul;
          break;
       case SpvOpGroupNonUniformSMin:
+      case SpvOpGroupSMin:
+      case SpvOpGroupSMinNonUniformAMD:
          reduction_op = nir_op_imin;
          break;
       case SpvOpGroupNonUniformUMin:
+      case SpvOpGroupUMin:
+      case SpvOpGroupUMinNonUniformAMD:
          reduction_op = nir_op_umin;
          break;
       case SpvOpGroupNonUniformFMin:
+      case SpvOpGroupFMin:
+      case SpvOpGroupFMinNonUniformAMD:
          reduction_op = nir_op_fmin;
          break;
       case SpvOpGroupNonUniformSMax:
+      case SpvOpGroupSMax:
+      case SpvOpGroupSMaxNonUniformAMD:
          reduction_op = nir_op_imax;
          break;
       case SpvOpGroupNonUniformUMax:
+      case SpvOpGroupUMax:
+      case SpvOpGroupUMaxNonUniformAMD:
          reduction_op = nir_op_umax;
          break;
       case SpvOpGroupNonUniformFMax:
+      case SpvOpGroupFMax:
+      case SpvOpGroupFMaxNonUniformAMD:
          reduction_op = nir_op_fmax;
          break;
       case SpvOpGroupNonUniformBitwiseAnd: