nir: Support sysval tess levels in SPIR-V to NIR
[mesa.git] src/compiler/spirv/vtn_subgroup.c
index 09e4e598b26e497013899fba813dc08486639cb6..8d17845c1a0c420b83624b399c484c12dd2745f2 100644
@@ -28,7 +28,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
                          nir_intrinsic_op nir_op,
                          struct vtn_ssa_value *dst,
                          struct vtn_ssa_value *src0,
-                         nir_ssa_def *index)
+                         nir_ssa_def *index,
+                         unsigned const_idx0,
+                         unsigned const_idx1)
 {
    /* Some of the subgroup operations take an index.  SPIR-V allows this to be
     * any integer type.  To make things simpler for drivers, we only support
@@ -41,7 +43,8 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
    if (!glsl_type_is_vector_or_scalar(dst->type)) {
       for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
          vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
-                                  src0->elems[i], index);
+                                  src0->elems[i], index,
+                                  const_idx0, const_idx1);
       }
       return;
    }
@@ -56,6 +59,9 @@ vtn_build_subgroup_instr(struct vtn_builder *b,
    if (index)
       intrin->src[1] = nir_src_for_ssa(index);
 
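+   /* For the reduce/scan intrinsics these carry the reduction nir_op and the
+    * cluster size; the other subgroup intrinsics don't use them.
+    */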
+   intrin->const_index[0] = const_idx0;
+   intrin->const_index[1] = const_idx1;
+
    nir_builder_instr_insert(&b->nb, &intrin->instr);
 
    dst->def = &intrin->dest.ssa;
@@ -75,18 +81,20 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                   "OpGroupNonUniformElect must return a Bool");
       nir_intrinsic_instr *elect =
          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
-      nir_ssa_dest_init(&elect->instr, &elect->dest, 1, 32, NULL);
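+      /* Derive the destination size and bit width from the SPIR-V result
+       * type instead of hard-coding a 32-bit scalar.
+       */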
+      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
+                                 val->type->type, NULL);
       nir_builder_instr_insert(&b->nb, &elect->instr);
       val->ssa->def = &elect->dest.ssa;
       break;
    }
 
-   case SpvOpGroupNonUniformBallot: {
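+   /* The KHR opcode has no Execution scope operand, so bumping w here lets
+    * both opcodes share the operand offsets below; the other KHR aliases in
+    * this switch use the same trick.
+    */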
+   case SpvOpGroupNonUniformBallot: ++w;
+   case SpvOpSubgroupBallotKHR: {
       vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
                   "OpGroupNonUniformBallot must return a uvec4");
       nir_intrinsic_instr *ballot =
          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
-      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def);
       nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
       ballot->num_components = 4;
       nir_builder_instr_insert(&b->nb, &ballot->instr);
@@ -106,7 +114,8 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
       intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
 
-      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                                 val->type->type, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
       val->ssa->def = &intrin->dest.ssa;
@@ -160,45 +169,66 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
       if (src1)
          intrin->src[1] = nir_src_for_ssa(src1);
 
-      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                                 val->type->type, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
       val->ssa->def = &intrin->dest.ssa;
       break;
    }
 
-   case SpvOpGroupNonUniformBroadcastFirst:
+   case SpvOpGroupNonUniformBroadcastFirst: ++w;
+   case SpvOpSubgroupFirstInvocationKHR:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
-                               val->ssa, vtn_ssa_value(b, w[4]), NULL);
+                               val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
       break;
 
    case SpvOpGroupNonUniformBroadcast:
+   case SpvOpGroupBroadcast: ++w;
+   case SpvOpSubgroupReadInvocationKHR:
       vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
-                               val->ssa, vtn_ssa_value(b, w[4]),
-                               vtn_ssa_value(b, w[5])->def);
+                               val->ssa, vtn_ssa_value(b, w[3]),
+                               vtn_ssa_value(b, w[4])->def, 0, 0);
       break;
 
    case SpvOpGroupNonUniformAll:
    case SpvOpGroupNonUniformAny:
-   case SpvOpGroupNonUniformAllEqual: {
+   case SpvOpGroupNonUniformAllEqual:
+   case SpvOpGroupAll:
+   case SpvOpGroupAny:
+   case SpvOpSubgroupAllKHR:
+   case SpvOpSubgroupAnyKHR:
+   case SpvOpSubgroupAllEqualKHR: {
       vtn_fail_if(val->type->type != glsl_bool_type(),
                   "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
       nir_intrinsic_op op;
       switch (opcode) {
       case SpvOpGroupNonUniformAll:
+      case SpvOpGroupAll:
+      case SpvOpSubgroupAllKHR:
          op = nir_intrinsic_vote_all;
          break;
       case SpvOpGroupNonUniformAny:
+      case SpvOpGroupAny:
+      case SpvOpSubgroupAnyKHR:
          op = nir_intrinsic_vote_any;
          break;
-      case SpvOpGroupNonUniformAllEqual: {
-         switch (glsl_get_base_type(val->type->type)) {
+      case SpvOpSubgroupAllEqualKHR:
+         op = nir_intrinsic_vote_ieq;
+         break;
+      case SpvOpGroupNonUniformAllEqual:
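+         /* Choose integer vs. float equality from the value operand's type;
+          * the result type is always bool.
+          */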
+         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
          case GLSL_TYPE_FLOAT:
+         case GLSL_TYPE_FLOAT16:
          case GLSL_TYPE_DOUBLE:
             op = nir_intrinsic_vote_feq;
             break;
          case GLSL_TYPE_UINT:
          case GLSL_TYPE_INT:
+         case GLSL_TYPE_UINT8:
+         case GLSL_TYPE_INT8:
+         case GLSL_TYPE_UINT16:
+         case GLSL_TYPE_INT16:
          case GLSL_TYPE_UINT64:
          case GLSL_TYPE_INT64:
          case GLSL_TYPE_BOOL:
@@ -208,18 +238,25 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
             unreachable("Unhandled type");
          }
          break;
-      }
       default:
          unreachable("Unhandled opcode");
       }
 
-      nir_ssa_def *src0 = vtn_ssa_value(b, w[4])->def;
-
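+      /* The Group and GroupNonUniform forms carry an Execution scope, so
+       * their predicate sits one operand later than in the KHR forms.
+       */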
+      nir_ssa_def *src0;
+      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
+          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
+          opcode == SpvOpGroupNonUniformAllEqual) {
+         src0 = vtn_ssa_value(b, w[4])->def;
+      } else {
+         src0 = vtn_ssa_value(b, w[3])->def;
+      }
       nir_intrinsic_instr *intrin =
          nir_intrinsic_instr_create(b->nb.shader, op);
-      intrin->num_components = src0->num_components;
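+      /* Only copy the source width when the intrinsic's source size isn't
+       * fixed (vote_ieq/vote_feq take a variably sized source).
+       */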
+      if (nir_intrinsic_infos[op].src_components[0] == 0)
+         intrin->num_components = src0->num_components;
       intrin->src[0] = nir_src_for_ssa(src0);
-      nir_ssa_dest_init(&intrin->instr, &intrin->dest, 1, 32, NULL);
+      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                                 val->type->type, NULL);
       nir_builder_instr_insert(&b->nb, &intrin->instr);
 
       val->ssa->def = &intrin->dest.ssa;
@@ -229,7 +266,56 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformShuffle:
    case SpvOpGroupNonUniformShuffleXor:
    case SpvOpGroupNonUniformShuffleUp:
-   case SpvOpGroupNonUniformShuffleDown:
+   case SpvOpGroupNonUniformShuffleDown: {
+      nir_intrinsic_op op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformShuffle:
+         op = nir_intrinsic_shuffle;
+         break;
+      case SpvOpGroupNonUniformShuffleXor:
+         op = nir_intrinsic_shuffle_xor;
+         break;
+      case SpvOpGroupNonUniformShuffleUp:
+         op = nir_intrinsic_shuffle_up;
+         break;
+      case SpvOpGroupNonUniformShuffleDown:
+         op = nir_intrinsic_shuffle_down;
+         break;
+      default:
+         unreachable("Invalid opcode");
+      }
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
+      break;
+   }
+
+   case SpvOpGroupNonUniformQuadBroadcast:
+      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
+                               val->ssa, vtn_ssa_value(b, w[4]),
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
+      break;
+
+   case SpvOpGroupNonUniformQuadSwap: {
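+      /* NIR has no generic quad-swap intrinsic; the constant Direction
+       * operand selects one of three direction-specific intrinsics.
+       */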
+      unsigned direction = vtn_constant_uint(b, w[5]);
+      nir_intrinsic_op op;
+      switch (direction) {
+      case 0:
+         op = nir_intrinsic_quad_swap_horizontal;
+         break;
+      case 1:
+         op = nir_intrinsic_quad_swap_vertical;
+         break;
+      case 2:
+         op = nir_intrinsic_quad_swap_diagonal;
+         break;
+      default:
+         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
+      }
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+                               NULL, 0, 0);
+      break;
+   }
+
    case SpvOpGroupNonUniformIAdd:
    case SpvOpGroupNonUniformFAdd:
    case SpvOpGroupNonUniformIMul:
@@ -246,8 +332,112 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformLogicalAnd:
    case SpvOpGroupNonUniformLogicalOr:
    case SpvOpGroupNonUniformLogicalXor:
-   case SpvOpGroupNonUniformQuadBroadcast:
-   case SpvOpGroupNonUniformQuadSwap:
+   case SpvOpGroupIAdd:
+   case SpvOpGroupFAdd:
+   case SpvOpGroupFMin:
+   case SpvOpGroupUMin:
+   case SpvOpGroupSMin:
+   case SpvOpGroupFMax:
+   case SpvOpGroupUMax:
+   case SpvOpGroupSMax:
+   case SpvOpGroupIAddNonUniformAMD:
+   case SpvOpGroupFAddNonUniformAMD:
+   case SpvOpGroupFMinNonUniformAMD:
+   case SpvOpGroupUMinNonUniformAMD:
+   case SpvOpGroupSMinNonUniformAMD:
+   case SpvOpGroupFMaxNonUniformAMD:
+   case SpvOpGroupUMaxNonUniformAMD:
+   case SpvOpGroupSMaxNonUniformAMD: {
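+      /* All of these lower to a reduce/scan intrinsic; the arithmetic
+       * operation is passed along as a nir_op in const_index[0].
+       */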
+      nir_op reduction_op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformIAdd:
+      case SpvOpGroupIAdd:
+      case SpvOpGroupIAddNonUniformAMD:
+         reduction_op = nir_op_iadd;
+         break;
+      case SpvOpGroupNonUniformFAdd:
+      case SpvOpGroupFAdd:
+      case SpvOpGroupFAddNonUniformAMD:
+         reduction_op = nir_op_fadd;
+         break;
+      case SpvOpGroupNonUniformIMul:
+         reduction_op = nir_op_imul;
+         break;
+      case SpvOpGroupNonUniformFMul:
+         reduction_op = nir_op_fmul;
+         break;
+      case SpvOpGroupNonUniformSMin:
+      case SpvOpGroupSMin:
+      case SpvOpGroupSMinNonUniformAMD:
+         reduction_op = nir_op_imin;
+         break;
+      case SpvOpGroupNonUniformUMin:
+      case SpvOpGroupUMin:
+      case SpvOpGroupUMinNonUniformAMD:
+         reduction_op = nir_op_umin;
+         break;
+      case SpvOpGroupNonUniformFMin:
+      case SpvOpGroupFMin:
+      case SpvOpGroupFMinNonUniformAMD:
+         reduction_op = nir_op_fmin;
+         break;
+      case SpvOpGroupNonUniformSMax:
+      case SpvOpGroupSMax:
+      case SpvOpGroupSMaxNonUniformAMD:
+         reduction_op = nir_op_imax;
+         break;
+      case SpvOpGroupNonUniformUMax:
+      case SpvOpGroupUMax:
+      case SpvOpGroupUMaxNonUniformAMD:
+         reduction_op = nir_op_umax;
+         break;
+      case SpvOpGroupNonUniformFMax:
+      case SpvOpGroupFMax:
+      case SpvOpGroupFMaxNonUniformAMD:
+         reduction_op = nir_op_fmax;
+         break;
+      case SpvOpGroupNonUniformBitwiseAnd:
+      case SpvOpGroupNonUniformLogicalAnd:
+         reduction_op = nir_op_iand;
+         break;
+      case SpvOpGroupNonUniformBitwiseOr:
+      case SpvOpGroupNonUniformLogicalOr:
+         reduction_op = nir_op_ior;
+         break;
+      case SpvOpGroupNonUniformBitwiseXor:
+      case SpvOpGroupNonUniformLogicalXor:
+         reduction_op = nir_op_ixor;
+         break;
+      default:
+         unreachable("Invalid reduction operation");
+      }
+
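+      /* The Group Operation literal in w[4] selects reduce vs. inclusive or
+       * exclusive scan; ClusteredReduce also supplies a cluster size literal
+       * in w[6], passed in const_index[1].
+       */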
+      nir_intrinsic_op op;
+      unsigned cluster_size = 0;
+      switch ((SpvGroupOperation)w[4]) {
+      case SpvGroupOperationReduce:
+         op = nir_intrinsic_reduce;
+         break;
+      case SpvGroupOperationInclusiveScan:
+         op = nir_intrinsic_inclusive_scan;
+         break;
+      case SpvGroupOperationExclusiveScan:
+         op = nir_intrinsic_exclusive_scan;
+         break;
+      case SpvGroupOperationClusteredReduce:
+         op = nir_intrinsic_reduce;
+         assert(count == 7);
+         cluster_size = vtn_constant_uint(b, w[6]);
+         break;
+      default:
+         unreachable("Invalid group operation");
+      }
+
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
+                               NULL, reduction_op, cluster_size);
+      break;
+   }
+
    default:
       unreachable("Invalid SPIR-V opcode");
    }