nir: Support sysval tess levels in SPIR-V to NIR
[mesa.git] / src / compiler / spirv / vtn_subgroup.c
index 033c43e601c1f4dae8eafbdb76b19eadfb057ed1..8d17845c1a0c420b83624b399c484c12dd2745f2 100644 (file)
 
 #include "vtn_private.h"
 
+/* Build and insert a single NIR subgroup intrinsic.
+ *
+ * nir_op     - the subgroup intrinsic to emit (read_invocation, shuffle_*,
+ *              quad_*, reduce/scan, ...).
+ * dst, src0  - destination and first-source SSA values; asserted to have the
+ *              same GLSL type.  Composite (struct/array) types are handled by
+ *              recursing element-wise until a vector or scalar is reached.
+ * index      - optional second source (invocation index, shuffle delta/mask);
+ *              pass NULL for intrinsics with only one source.
+ * const_idx0,
+ * const_idx1 - values stored verbatim into const_index[0]/[1] (e.g. the
+ *              reduction nir_op and cluster size for reduce/scan); 0 if unused.
+ */
+static void
+vtn_build_subgroup_instr(struct vtn_builder *b,
+                         nir_intrinsic_op nir_op,
+                         struct vtn_ssa_value *dst,
+                         struct vtn_ssa_value *src0,
+                         nir_ssa_def *index,
+                         unsigned const_idx0,
+                         unsigned const_idx1)
+{
+   /* Some of the subgroup operations take an index.  SPIR-V allows this to be
+    * any integer type.  To make things simpler for drivers, we only support
+    * 32-bit indices.
+    */
+   if (index && index->bit_size != 32)
+      index = nir_u2u32(&b->nb, index);
+
+   vtn_assert(dst->type == src0->type);
+   /* Composites can't be fed to a single intrinsic: recurse per element and
+    * emit one intrinsic per vector/scalar leaf, sharing index/const_idx.
+    */
+   if (!glsl_type_is_vector_or_scalar(dst->type)) {
+      for (unsigned i = 0; i < glsl_get_length(dst->type); i++) {
+         vtn_build_subgroup_instr(b, nir_op, dst->elems[i],
+                                  src0->elems[i], index,
+                                  const_idx0, const_idx1);
+      }
+      return;
+   }
+
+   nir_intrinsic_instr *intrin =
+      nir_intrinsic_instr_create(b->nb.shader, nir_op);
+   /* The dest is sized from the GLSL type; mirror its component count on the
+    * instruction for intrinsics with variable num_components.
+    */
+   nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                              dst->type, NULL);
+   intrin->num_components = intrin->dest.ssa.num_components;
+
+   intrin->src[0] = nir_src_for_ssa(src0->def);
+   if (index)
+      intrin->src[1] = nir_src_for_ssa(index);
+
+   intrin->const_index[0] = const_idx0;
+   intrin->const_index[1] = const_idx1;
+
+   nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+   dst->def = &intrin->dest.ssa;
+}
+
 void
 vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                     const uint32_t *w, unsigned count)
@@ -37,27 +81,241 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
                   "OpGroupNonUniformElect must return a Bool");
       nir_intrinsic_instr *elect =
          nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_elect);
-      nir_ssa_dest_init(&elect->instr, &elect->dest, 1, 32, NULL);
+      nir_ssa_dest_init_for_type(&elect->instr, &elect->dest,
+                                 val->type->type, NULL);
       nir_builder_instr_insert(&b->nb, &elect->instr);
       val->ssa->def = &elect->dest.ssa;
       break;
    }
 
-   case SpvOpGroupNonUniformAll:
-   case SpvOpGroupNonUniformAny:
-   case SpvOpGroupNonUniformAllEqual:
-   case SpvOpGroupNonUniformBroadcast:
-   case SpvOpGroupNonUniformBroadcastFirst:
-   case SpvOpGroupNonUniformBallot:
-   case SpvOpGroupNonUniformInverseBallot:
+   case SpvOpGroupNonUniformBallot: ++w;
+   case SpvOpSubgroupBallotKHR: {
+      vtn_fail_if(val->type->type != glsl_vector_type(GLSL_TYPE_UINT, 4),
+                  "OpGroupNonUniformBallot must return a uvec4");
+      nir_intrinsic_instr *ballot =
+         nir_intrinsic_instr_create(b->nb.shader, nir_intrinsic_ballot);
+      ballot->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[3])->def);
+      nir_ssa_dest_init(&ballot->instr, &ballot->dest, 4, 32, NULL);
+      ballot->num_components = 4;
+      nir_builder_instr_insert(&b->nb, &ballot->instr);
+      val->ssa->def = &ballot->dest.ssa;
+      break;
+   }
+
+   case SpvOpGroupNonUniformInverseBallot: {
+      /* This one is just a BallotBitfieldExtract with subgroup invocation.
+       * We could add a NIR intrinsic but it's easier to just lower it on the
+       * spot.
+       */
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader,
+                                    nir_intrinsic_ballot_bitfield_extract);
+
+      intrin->src[0] = nir_src_for_ssa(vtn_ssa_value(b, w[4])->def);
+      intrin->src[1] = nir_src_for_ssa(nir_load_subgroup_invocation(&b->nb));
+
+      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                                 val->type->type, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
    case SpvOpGroupNonUniformBallotBitExtract:
    case SpvOpGroupNonUniformBallotBitCount:
    case SpvOpGroupNonUniformBallotFindLSB:
-   case SpvOpGroupNonUniformBallotFindMSB:
+   case SpvOpGroupNonUniformBallotFindMSB: {
+      nir_ssa_def *src0, *src1 = NULL;
+      nir_intrinsic_op op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformBallotBitExtract:
+         op = nir_intrinsic_ballot_bitfield_extract;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         src1 = vtn_ssa_value(b, w[5])->def;
+         break;
+      case SpvOpGroupNonUniformBallotBitCount:
+         switch ((SpvGroupOperation)w[4]) {
+         case SpvGroupOperationReduce:
+            op = nir_intrinsic_ballot_bit_count_reduce;
+            break;
+         case SpvGroupOperationInclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_inclusive;
+            break;
+         case SpvGroupOperationExclusiveScan:
+            op = nir_intrinsic_ballot_bit_count_exclusive;
+            break;
+         default:
+            unreachable("Invalid group operation");
+         }
+         src0 = vtn_ssa_value(b, w[5])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindLSB:
+         op = nir_intrinsic_ballot_find_lsb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      case SpvOpGroupNonUniformBallotFindMSB:
+         op = nir_intrinsic_ballot_find_msb;
+         src0 = vtn_ssa_value(b, w[4])->def;
+         break;
+      default:
+         unreachable("Unhandled opcode");
+      }
+
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader, op);
+
+      intrin->src[0] = nir_src_for_ssa(src0);
+      if (src1)
+         intrin->src[1] = nir_src_for_ssa(src1);
+
+      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                                 val->type->type, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
+   case SpvOpGroupNonUniformBroadcastFirst: ++w;
+   case SpvOpSubgroupFirstInvocationKHR:
+      vtn_build_subgroup_instr(b, nir_intrinsic_read_first_invocation,
+                               val->ssa, vtn_ssa_value(b, w[3]), NULL, 0, 0);
+      break;
+
+   case SpvOpGroupNonUniformBroadcast:
+   case SpvOpGroupBroadcast: ++w;
+   case SpvOpSubgroupReadInvocationKHR:
+      vtn_build_subgroup_instr(b, nir_intrinsic_read_invocation,
+                               val->ssa, vtn_ssa_value(b, w[3]),
+                               vtn_ssa_value(b, w[4])->def, 0, 0);
+      break;
+
+   case SpvOpGroupNonUniformAll:
+   case SpvOpGroupNonUniformAny:
+   case SpvOpGroupNonUniformAllEqual:
+   case SpvOpGroupAll:
+   case SpvOpGroupAny:
+   case SpvOpSubgroupAllKHR:
+   case SpvOpSubgroupAnyKHR:
+   case SpvOpSubgroupAllEqualKHR: {
+      vtn_fail_if(val->type->type != glsl_bool_type(),
+                  "OpGroupNonUniform(All|Any|AllEqual) must return a bool");
+      nir_intrinsic_op op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformAll:
+      case SpvOpGroupAll:
+      case SpvOpSubgroupAllKHR:
+         op = nir_intrinsic_vote_all;
+         break;
+      case SpvOpGroupNonUniformAny:
+      case SpvOpGroupAny:
+      case SpvOpSubgroupAnyKHR:
+         op = nir_intrinsic_vote_any;
+         break;
+      case SpvOpSubgroupAllEqualKHR:
+         op = nir_intrinsic_vote_ieq;
+         break;
+      case SpvOpGroupNonUniformAllEqual:
+         switch (glsl_get_base_type(vtn_ssa_value(b, w[4])->type)) {
+         case GLSL_TYPE_FLOAT:
+         case GLSL_TYPE_FLOAT16:
+         case GLSL_TYPE_DOUBLE:
+            op = nir_intrinsic_vote_feq;
+            break;
+         case GLSL_TYPE_UINT:
+         case GLSL_TYPE_INT:
+         case GLSL_TYPE_UINT8:
+         case GLSL_TYPE_INT8:
+         case GLSL_TYPE_UINT16:
+         case GLSL_TYPE_INT16:
+         case GLSL_TYPE_UINT64:
+         case GLSL_TYPE_INT64:
+         case GLSL_TYPE_BOOL:
+            op = nir_intrinsic_vote_ieq;
+            break;
+         default:
+            unreachable("Unhandled type");
+         }
+         break;
+      default:
+         unreachable("Unhandled opcode");
+      }
+
+      nir_ssa_def *src0;
+      if (opcode == SpvOpGroupNonUniformAll || opcode == SpvOpGroupAll ||
+          opcode == SpvOpGroupNonUniformAny || opcode == SpvOpGroupAny ||
+          opcode == SpvOpGroupNonUniformAllEqual) {
+         src0 = vtn_ssa_value(b, w[4])->def;
+      } else {
+         src0 = vtn_ssa_value(b, w[3])->def;
+      }
+      nir_intrinsic_instr *intrin =
+         nir_intrinsic_instr_create(b->nb.shader, op);
+      if (nir_intrinsic_infos[op].src_components[0] == 0)
+         intrin->num_components = src0->num_components;
+      intrin->src[0] = nir_src_for_ssa(src0);
+      nir_ssa_dest_init_for_type(&intrin->instr, &intrin->dest,
+                                 val->type->type, NULL);
+      nir_builder_instr_insert(&b->nb, &intrin->instr);
+
+      val->ssa->def = &intrin->dest.ssa;
+      break;
+   }
+
    case SpvOpGroupNonUniformShuffle:
    case SpvOpGroupNonUniformShuffleXor:
    case SpvOpGroupNonUniformShuffleUp:
-   case SpvOpGroupNonUniformShuffleDown:
+   case SpvOpGroupNonUniformShuffleDown: {
+      nir_intrinsic_op op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformShuffle:
+         op = nir_intrinsic_shuffle;
+         break;
+      case SpvOpGroupNonUniformShuffleXor:
+         op = nir_intrinsic_shuffle_xor;
+         break;
+      case SpvOpGroupNonUniformShuffleUp:
+         op = nir_intrinsic_shuffle_up;
+         break;
+      case SpvOpGroupNonUniformShuffleDown:
+         op = nir_intrinsic_shuffle_down;
+         break;
+      default:
+         unreachable("Invalid opcode");
+      }
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
+      break;
+   }
+
+   case SpvOpGroupNonUniformQuadBroadcast:
+      vtn_build_subgroup_instr(b, nir_intrinsic_quad_broadcast,
+                               val->ssa, vtn_ssa_value(b, w[4]),
+                               vtn_ssa_value(b, w[5])->def, 0, 0);
+      break;
+
+   case SpvOpGroupNonUniformQuadSwap: {
+      unsigned direction = vtn_constant_uint(b, w[5]);
+      nir_intrinsic_op op;
+      switch (direction) {
+      case 0:
+         op = nir_intrinsic_quad_swap_horizontal;
+         break;
+      case 1:
+         op = nir_intrinsic_quad_swap_vertical;
+         break;
+      case 2:
+         op = nir_intrinsic_quad_swap_diagonal;
+         break;
+      default:
+         vtn_fail("Invalid constant value in OpGroupNonUniformQuadSwap");
+      }
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[4]),
+                               NULL, 0, 0);
+      break;
+   }
+
    case SpvOpGroupNonUniformIAdd:
    case SpvOpGroupNonUniformFAdd:
    case SpvOpGroupNonUniformIMul:
@@ -74,8 +332,112 @@ vtn_handle_subgroup(struct vtn_builder *b, SpvOp opcode,
    case SpvOpGroupNonUniformLogicalAnd:
    case SpvOpGroupNonUniformLogicalOr:
    case SpvOpGroupNonUniformLogicalXor:
-   case SpvOpGroupNonUniformQuadBroadcast:
-   case SpvOpGroupNonUniformQuadSwap:
+   case SpvOpGroupIAdd:
+   case SpvOpGroupFAdd:
+   case SpvOpGroupFMin:
+   case SpvOpGroupUMin:
+   case SpvOpGroupSMin:
+   case SpvOpGroupFMax:
+   case SpvOpGroupUMax:
+   case SpvOpGroupSMax:
+   case SpvOpGroupIAddNonUniformAMD:
+   case SpvOpGroupFAddNonUniformAMD:
+   case SpvOpGroupFMinNonUniformAMD:
+   case SpvOpGroupUMinNonUniformAMD:
+   case SpvOpGroupSMinNonUniformAMD:
+   case SpvOpGroupFMaxNonUniformAMD:
+   case SpvOpGroupUMaxNonUniformAMD:
+   case SpvOpGroupSMaxNonUniformAMD: {
+      nir_op reduction_op;
+      switch (opcode) {
+      case SpvOpGroupNonUniformIAdd:
+      case SpvOpGroupIAdd:
+      case SpvOpGroupIAddNonUniformAMD:
+         reduction_op = nir_op_iadd;
+         break;
+      case SpvOpGroupNonUniformFAdd:
+      case SpvOpGroupFAdd:
+      case SpvOpGroupFAddNonUniformAMD:
+         reduction_op = nir_op_fadd;
+         break;
+      case SpvOpGroupNonUniformIMul:
+         reduction_op = nir_op_imul;
+         break;
+      case SpvOpGroupNonUniformFMul:
+         reduction_op = nir_op_fmul;
+         break;
+      case SpvOpGroupNonUniformSMin:
+      case SpvOpGroupSMin:
+      case SpvOpGroupSMinNonUniformAMD:
+         reduction_op = nir_op_imin;
+         break;
+      case SpvOpGroupNonUniformUMin:
+      case SpvOpGroupUMin:
+      case SpvOpGroupUMinNonUniformAMD:
+         reduction_op = nir_op_umin;
+         break;
+      case SpvOpGroupNonUniformFMin:
+      case SpvOpGroupFMin:
+      case SpvOpGroupFMinNonUniformAMD:
+         reduction_op = nir_op_fmin;
+         break;
+      case SpvOpGroupNonUniformSMax:
+      case SpvOpGroupSMax:
+      case SpvOpGroupSMaxNonUniformAMD:
+         reduction_op = nir_op_imax;
+         break;
+      case SpvOpGroupNonUniformUMax:
+      case SpvOpGroupUMax:
+      case SpvOpGroupUMaxNonUniformAMD:
+         reduction_op = nir_op_umax;
+         break;
+      case SpvOpGroupNonUniformFMax:
+      case SpvOpGroupFMax:
+      case SpvOpGroupFMaxNonUniformAMD:
+         reduction_op = nir_op_fmax;
+         break;
+      case SpvOpGroupNonUniformBitwiseAnd:
+      case SpvOpGroupNonUniformLogicalAnd:
+         reduction_op = nir_op_iand;
+         break;
+      case SpvOpGroupNonUniformBitwiseOr:
+      case SpvOpGroupNonUniformLogicalOr:
+         reduction_op = nir_op_ior;
+         break;
+      case SpvOpGroupNonUniformBitwiseXor:
+      case SpvOpGroupNonUniformLogicalXor:
+         reduction_op = nir_op_ixor;
+         break;
+      default:
+         unreachable("Invalid reduction operation");
+      }
+
+      nir_intrinsic_op op;
+      unsigned cluster_size = 0;
+      switch ((SpvGroupOperation)w[4]) {
+      case SpvGroupOperationReduce:
+         op = nir_intrinsic_reduce;
+         break;
+      case SpvGroupOperationInclusiveScan:
+         op = nir_intrinsic_inclusive_scan;
+         break;
+      case SpvGroupOperationExclusiveScan:
+         op = nir_intrinsic_exclusive_scan;
+         break;
+      case SpvGroupOperationClusteredReduce:
+         op = nir_intrinsic_reduce;
+         assert(count == 7);
+         cluster_size = vtn_constant_uint(b, w[6]);
+         break;
+      default:
+         unreachable("Invalid group operation");
+      }
+
+      vtn_build_subgroup_instr(b, op, val->ssa, vtn_ssa_value(b, w[5]),
+                               NULL, reduction_op, cluster_size);
+      break;
+   }
+
    default:
       unreachable("Invalid SPIR-V opcode");
    }