#include "ir3/ir3_nir.h"
-static nir_function *
+static nir_shader *
tu_spirv_to_nir(struct ir3_compiler *compiler,
const uint32_t *words,
size_t word_count,
assert(entry_point->shader->info.stage == stage);
nir_validate_shader(entry_point->shader, "after spirv_to_nir");
- return entry_point;
+ return entry_point->shader;
}
static void
/* translate SPIR-V to NIR */
assert(module->code_size % 4 == 0);
- nir_function *entry_point = tu_spirv_to_nir(
+ nir_shader *nir = tu_spirv_to_nir(
dev->compiler, (const uint32_t *) module->code, module->code_size / 4,
stage, stage_info->pName, stage_info->pSpecializationInfo);
- if (!entry_point) {
+ if (!nir) {
vk_free2(&dev->alloc, alloc, shader);
return NULL;
}
- nir_shader *nir = entry_point->shader;
-
if (unlikely(dev->physical_device->instance->debug_flags & TU_DEBUG_NIR)) {
fprintf(stderr, "translated nir:\n");
nir_print_shader(nir, stderr);
NIR_PASS_V(nir, nir_lower_frexp);
NIR_PASS_V(nir, nir_lower_io, nir_var_all, ir3_glsl_type_size, 0);
- nir_shader_gather_info(nir, entry_point->impl);
+ nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
shader->ir3_shader.compiler = dev->compiler;
shader->ir3_shader.type = stage;