return NULL;
}
+static nir_function *
+vtn_emit_kernel_entry_point_wrapper(struct vtn_builder *b,
+ const nir_function *entry_point)
+{
+ vtn_assert(entry_point == b->entry_point->func->impl->function);
+ vtn_fail_if(!entry_point->name, "entry points are required to have a name");
+ const char *func_name =
+ ralloc_asprintf(b->shader, "__wrapped_%s", entry_point->name);
+
+ /* we shouldn't have any inputs yet */
+ vtn_assert(!entry_point->shader->num_inputs);
+ vtn_assert(b->shader->info.stage == MESA_SHADER_KERNEL);
+
+ nir_function *main_entry_point = nir_function_create(b->shader, func_name);
+ main_entry_point->impl = nir_function_impl_create(main_entry_point);
+ nir_builder_init(&b->nb, main_entry_point->impl);
+ b->nb.cursor = nir_after_cf_list(&main_entry_point->impl->body);
+ b->func_param_idx = 0;
+
+ nir_call_instr *call = nir_call_instr_create(b->nb.shader, entry_point);
+
+ for (unsigned i = 0; i < entry_point->num_params; ++i) {
+ struct vtn_type *param_type = b->entry_point->func->type->params[i];
+
+ /* consider all pointers to function memory to be parameters passed
+ * by value
+ */
+ bool is_by_val = param_type->base_type == vtn_base_type_pointer &&
+ param_type->storage_class == SpvStorageClassFunction;
+
+ /* input variable */
+ nir_variable *in_var = rzalloc(b->nb.shader, nir_variable);
+ in_var->data.mode = nir_var_shader_in;
+ in_var->data.read_only = true;
+ in_var->data.location = i;
+
+ if (is_by_val)
+ in_var->type = param_type->deref->type;
+ else
+ in_var->type = param_type->type;
+
+ nir_shader_add_variable(b->nb.shader, in_var);
+ b->nb.shader->num_inputs++;
+
+ /* we have to copy the entire variable into function memory */
+ if (is_by_val) {
+ nir_variable *copy_var =
+ nir_local_variable_create(main_entry_point->impl, in_var->type,
+ "copy_in");
+ nir_copy_var(&b->nb, copy_var, in_var);
+ call->params[i] =
+ nir_src_for_ssa(&nir_build_deref_var(&b->nb, copy_var)->dest.ssa);
+ } else {
+ call->params[i] = nir_src_for_ssa(nir_load_var(&b->nb, in_var));
+ }
+ }
+
+ nir_builder_instr_insert(&b->nb, &call->instr);
+
+ return main_entry_point;
+}
+
nir_function *
spirv_to_nir(const uint32_t *words, size_t word_count,
struct nir_spirv_specialization *spec, unsigned num_spec,
nir_function *entry_point = b->entry_point->func->impl->function;
vtn_assert(entry_point);
+ /* post process entry_points with input params */
+ if (entry_point->num_params && b->shader->info.stage == MESA_SHADER_KERNEL)
+ entry_point = vtn_emit_kernel_entry_point_wrapper(b, entry_point);
+
entry_point->is_entrypoint = true;
/* When multiple shader stages exist in the same SPIR-V module, we