spirv/nir: add atomic counter support on vtn_handle_ssbo_or_shared_atomic
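
Rename vtn_handle_ssbo_or_shared_atomic() to vtn_handle_atomics() and
add a path for pointers in vtn_variable_mode_uniform, which at this
point can only be atomic counter uniforms. SPIR-V atomic opcodes on
those pointers are mapped to the nir_intrinsic_atomic_counter_*_deref
intrinsics. AtomicStore, AtomicSMin and AtomicSMax are left out: there
are no NIR intrinsics for them, and GLSL atomic counters (the only
users through ARB_gl_spirv) are uints and don't allow direct storage.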
diff --git a/src/compiler/spirv/spirv_to_nir.c b/src/compiler/spirv/spirv_to_nir.c
index 8e20044ac4dff4e566114d18262ae281dbd225ca..270f263d0471c5017d9cb75ca330668d1b3c6f72 100644
--- a/src/compiler/spirv/spirv_to_nir.c
+++ b/src/compiler/spirv/spirv_to_nir.c
@@ -1795,42 +1795,54 @@ vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
    vtn_callee->referenced = true;
 
    nir_call_instr *call = nir_call_instr_create(b->nb.shader, callee);
-   for (unsigned i = 0; i < call->num_params; i++) {
-      unsigned arg_id = w[4 + i];
-      struct vtn_value *arg = vtn_untyped_value(b, arg_id);
-      if (arg->value_type == vtn_value_type_pointer &&
-          arg->pointer->ptr_type->type == NULL) {
-         nir_deref_var *d = vtn_pointer_to_deref_var(b, arg->pointer);
-         call->params[i] = nir_deref_var_clone(d, call);
-      } else {
-         struct vtn_ssa_value *arg_ssa = vtn_ssa_value(b, arg_id);
 
-         /* Make a temporary to store the argument in */
-         nir_variable *tmp =
-            nir_local_variable_create(b->nb.impl, arg_ssa->type, "arg_tmp");
-         call->params[i] = nir_deref_var_create(call, tmp);
+   unsigned param_idx = 0;
 
-         vtn_local_store(b, arg_ssa,
-                         nir_build_deref_for_chain(&b->nb, call->params[i]));
-      }
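+   /* If the callee returns a value, it is returned through a temporary
+    * passed by deref as the first parameter; we load from it after the
+    * call.
+    */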
+   nir_deref_instr *ret_deref = NULL;
+   struct vtn_type *ret_type = vtn_callee->type->return_type;
+   if (ret_type->base_type != vtn_base_type_void) {
+      nir_variable *ret_tmp =
+         nir_local_variable_create(b->nb.impl, ret_type->type, "return_tmp");
+      ret_deref = nir_build_deref_var(&b->nb, ret_tmp);
+      call->params[param_idx++] = nir_src_for_ssa(&ret_deref->dest.ssa);
    }
 
-   nir_variable *out_tmp = NULL;
-   vtn_assert(res_type->type == callee->return_type);
-   if (!glsl_type_is_void(callee->return_type)) {
-      out_tmp = nir_local_variable_create(b->nb.impl, callee->return_type,
-                                          "out_tmp");
-      call->return_deref = nir_deref_var_create(call, out_tmp);
+   for (unsigned i = 0; i < vtn_callee->type->length; i++) {
+      struct vtn_type *arg_type = vtn_callee->type->params[i];
+      unsigned arg_id = w[4 + i];
+
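+      /* Sampled images are split into their image and sampler halves,
+       * each passed as a separate deref parameter.
+       */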
+      if (arg_type->base_type == vtn_base_type_sampled_image) {
+         struct vtn_sampled_image *sampled_image =
+            vtn_value(b, arg_id, vtn_value_type_sampled_image)->sampled_image;
+
+         call->params[param_idx++] =
+            nir_src_for_ssa(&sampled_image->image->deref->dest.ssa);
+         call->params[param_idx++] =
+            nir_src_for_ssa(&sampled_image->sampler->deref->dest.ssa);
+      } else if (arg_type->base_type == vtn_base_type_pointer ||
+                 arg_type->base_type == vtn_base_type_image ||
+                 arg_type->base_type == vtn_base_type_sampler) {
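+         /* Pointers (and images and samplers, which are pointers here) are
+          * passed directly as pointer SSA values.
+          */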
+         struct vtn_pointer *pointer =
+            vtn_value(b, arg_id, vtn_value_type_pointer)->pointer;
+         call->params[param_idx++] =
+            nir_src_for_ssa(vtn_pointer_to_ssa(b, pointer));
+      } else {
+         /* This is a regular SSA value and we need a temporary */
+         nir_variable *tmp =
+            nir_local_variable_create(b->nb.impl, arg_type->type, "arg_tmp");
+         nir_deref_instr *tmp_deref = nir_build_deref_var(&b->nb, tmp);
+         vtn_local_store(b, vtn_ssa_value(b, arg_id), tmp_deref);
+         call->params[param_idx++] = nir_src_for_ssa(&tmp_deref->dest.ssa);
+      }
    }
+   assert(param_idx == call->num_params);
 
    nir_builder_instr_insert(&b->nb, &call->instr);
 
-   if (glsl_type_is_void(callee->return_type)) {
+   if (ret_type->base_type == vtn_base_type_void) {
       vtn_push_value(b, w[2], vtn_value_type_undef);
    } else {
-      nir_deref_instr *return_deref =
-         nir_build_deref_for_chain(&b->nb, call->return_deref);
-      vtn_push_ssa(b, w[2], res_type, vtn_local_load(b, return_deref));
+      vtn_push_ssa(b, w[2], res_type, vtn_local_load(b, ret_deref));
    }
 }
 
@@ -2498,6 +2510,35 @@ get_ssbo_nir_atomic_op(struct vtn_builder *b, SpvOp opcode)
    }
 }
 
+static nir_intrinsic_op
+get_uniform_nir_atomic_op(struct vtn_builder *b, SpvOp opcode)
+{
+   switch (opcode) {
+#define OP(S, N) case SpvOp##S: return nir_intrinsic_atomic_counter_ ##N;
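+   /* For example, OP(AtomicIAdd, add_deref) expands to:
+    *    case SpvOpAtomicIAdd: return nir_intrinsic_atomic_counter_add_deref;
+    */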
+   OP(AtomicLoad,             read_deref)
+   OP(AtomicExchange,         exchange)
+   OP(AtomicCompareExchange,  comp_swap)
+   OP(AtomicIIncrement,       inc_deref)
+   OP(AtomicIDecrement,       post_dec_deref)
+   OP(AtomicIAdd,             add_deref)
+   OP(AtomicISub,             add_deref)
+   OP(AtomicUMin,             min_deref)
+   OP(AtomicUMax,             max_deref)
+   OP(AtomicAnd,              and_deref)
+   OP(AtomicOr,               or_deref)
+   OP(AtomicXor,              xor_deref)
+#undef OP
+   default:
+      /* We left the following out: AtomicStore, AtomicSMin and
+       * AtomicSMax. Right now there are no NIR intrinsics for them. At the
+       * moment atomic counter support is only needed for ARB_gl_spirv, so
+       * we only need to support GLSL atomic counters, which are uints and
+       * don't allow direct storage.
+       */
+      unreachable("Invalid uniform atomic");
+   }
+}
+
 static nir_intrinsic_op
 get_shared_nir_atomic_op(struct vtn_builder *b, SpvOp opcode)
 {
@@ -2550,9 +2591,12 @@ get_deref_nir_atomic_op(struct vtn_builder *b, SpvOp opcode)
    }
 }
 
+/*
+ * Handles shared atomics, SSBO atomics and atomic counters.
+ */
 static void
-vtn_handle_ssbo_or_shared_atomic(struct vtn_builder *b, SpvOp opcode,
-                                 const uint32_t *w, unsigned count)
+vtn_handle_atomics(struct vtn_builder *b, SpvOp opcode,
+                   const uint32_t *w, unsigned count)
 {
    struct vtn_pointer *ptr;
    nir_intrinsic_instr *atomic;
@@ -2589,8 +2633,54 @@ vtn_handle_ssbo_or_shared_atomic(struct vtn_builder *b, SpvOp opcode,
    SpvMemorySemanticsMask semantics = w[5];
    */
 
-   if (ptr->mode == vtn_variable_mode_workgroup &&
-       !b->options->lower_workgroup_access_to_offsets) {
+   /* A pointer in uniform mode is, at this point, an atomic counter uniform */
+   if (ptr->mode == vtn_variable_mode_uniform) {
+      nir_deref_instr *deref = vtn_pointer_to_deref(b, ptr);
+      const struct glsl_type *deref_type = deref->type;
+      nir_intrinsic_op op = get_uniform_nir_atomic_op(b, opcode);
+      atomic = nir_intrinsic_instr_create(b->nb.shader, op);
+      atomic->src[0] = nir_src_for_ssa(&deref->dest.ssa);
+
+      /* The SSBO path needs to initialize index/offset. Here we don't need
+       * to, as that info is already stored in the ptr->var->var nir_variable
+       * (see vtn_create_variable).
+       */
+
+      switch (opcode) {
+      case SpvOpAtomicLoad:
+         atomic->num_components = glsl_get_vector_elements(deref_type);
+         break;
+
+      case SpvOpAtomicStore:
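+         /* Note that this case is currently unreachable, as
+          * get_uniform_nir_atomic_op() doesn't accept AtomicStore.
+          */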
+         atomic->num_components = glsl_get_vector_elements(deref_type);
+         nir_intrinsic_set_write_mask(atomic, (1 << atomic->num_components) - 1);
+         break;
+
+      case SpvOpAtomicExchange:
+      case SpvOpAtomicCompareExchange:
+      case SpvOpAtomicCompareExchangeWeak:
+      case SpvOpAtomicIIncrement:
+      case SpvOpAtomicIDecrement:
+      case SpvOpAtomicIAdd:
+      case SpvOpAtomicISub:
+      case SpvOpAtomicSMin:
+      case SpvOpAtomicUMin:
+      case SpvOpAtomicSMax:
+      case SpvOpAtomicUMax:
+      case SpvOpAtomicAnd:
+      case SpvOpAtomicOr:
+      case SpvOpAtomicXor:
+         /* Nothing to do: we don't need to call fill_common_atomic_sources
+          * here, as atomic counter uniforms don't have sources.
+          */
+         break;
+
+      default:
+         unreachable("Invalid SPIR-V atomic");
+
+      }
+   } else if (ptr->mode == vtn_variable_mode_workgroup &&
+              !b->options->lower_workgroup_access_to_offsets) {
       nir_deref_instr *deref = vtn_pointer_to_deref(b, ptr);
       const struct glsl_type *deref_type = deref->type;
       nir_intrinsic_op op = get_deref_nir_atomic_op(b, opcode);
@@ -3778,7 +3868,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
          vtn_handle_image(b, opcode, w, count);
       } else {
          vtn_assert(pointer->value_type == vtn_value_type_pointer);
-         vtn_handle_ssbo_or_shared_atomic(b, opcode, w, count);
+         vtn_handle_atomics(b, opcode, w, count);
       }
       break;
    }
@@ -3789,7 +3879,7 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
          vtn_handle_image(b, opcode, w, count);
       } else {
          vtn_assert(pointer->value_type == vtn_value_type_pointer);
-         vtn_handle_ssbo_or_shared_atomic(b, opcode, w, count);
+         vtn_handle_atomics(b, opcode, w, count);
       }
       break;
    }