turnip: fix binning shader compilation
[mesa.git] / src / freedreno / vulkan / tu_shader.c
index 8d6ccecdd9cc9207ebd8e2b8d5e524caceb818fe..f3d81675bb17e692f4407169ed13bb05efc9fb6e 100644 (file)
@@ -28,7 +28,7 @@
 
 #include "ir3/ir3_nir.h"
 
-static nir_function *
+static nir_shader *
 tu_spirv_to_nir(struct ir3_compiler *compiler,
                 const uint32_t *words,
                 size_t word_count,
@@ -38,7 +38,6 @@ tu_spirv_to_nir(struct ir3_compiler *compiler,
 {
    /* TODO these are made-up */
    const struct spirv_to_nir_options spirv_options = {
-      .lower_workgroup_access_to_offsets = true,
       .lower_ubo_ssbo_access_to_offsets = true,
       .caps = { false },
    };
@@ -68,16 +67,16 @@ tu_spirv_to_nir(struct ir3_compiler *compiler,
       num_spec = spec_info->mapEntryCount;
    }
 
-   nir_function *entry_point =
+   nir_shader *nir =
       spirv_to_nir(words, word_count, spec, num_spec, stage, entry_point_name,
                    &spirv_options, nir_options);
 
    free(spec);
 
-   assert(entry_point->shader->info.stage == stage);
-   nir_validate_shader(entry_point->shader, "after spirv_to_nir");
+   assert(nir->info.stage == stage);
+   nir_validate_shader(nir, "after spirv_to_nir");
 
-   return entry_point;
+   return nir;
 }
 
 static void
@@ -128,16 +127,14 @@ tu_shader_create(struct tu_device *dev,
 
    /* translate SPIR-V to NIR */
    assert(module->code_size % 4 == 0);
-   nir_function *entry_point = tu_spirv_to_nir(
+   nir_shader *nir = tu_spirv_to_nir(
       dev->compiler, (const uint32_t *) module->code, module->code_size / 4,
       stage, stage_info->pName, stage_info->pSpecializationInfo);
-   if (!entry_point) {
+   if (!nir) {
       vk_free2(&dev->alloc, alloc, shader);
       return NULL;
    }
 
-   nir_shader *nir = entry_point->shader;
-
    if (unlikely(dev->physical_device->instance->debug_flags & TU_DEBUG_NIR)) {
       fprintf(stderr, "translated nir:\n");
       nir_print_shader(nir, stderr);
@@ -176,7 +173,7 @@ tu_shader_create(struct tu_device *dev,
    NIR_PASS_V(nir, nir_lower_frexp);
    NIR_PASS_V(nir, nir_lower_io, nir_var_all, ir3_glsl_type_size, 0);
 
-   nir_shader_gather_info(nir, entry_point->impl);
+   nir_shader_gather_info(nir, nir_shader_get_entrypoint(nir));
 
    shader->ir3_shader.compiler = dev->compiler;
    shader->ir3_shader.type = stage;
@@ -225,13 +222,14 @@ tu_shader_compile_options_init(
 static uint32_t *
 tu_compile_shader_variant(struct ir3_shader *shader,
                           const struct ir3_shader_key *key,
-                          bool binning_pass,
+                          struct ir3_shader_variant *nonbinning,
                           struct ir3_shader_variant *variant)
 {
    variant->shader = shader;
    variant->type = shader->type;
    variant->key = *key;
-   variant->binning_pass = binning_pass;
+   variant->binning_pass = !!nonbinning;
+   variant->nonbinning = nonbinning;
 
    int ret = ir3_compile_shader_nir(shader->compiler, variant);
    if (ret)
@@ -262,7 +260,7 @@ tu_shader_compile(struct tu_device *dev,
    }
 
    shader->binary = tu_compile_shader_variant(
-      &shader->ir3_shader, &options->key, false, &shader->variants[0]);
+      &shader->ir3_shader, &options->key, NULL, &shader->variants[0]);
    if (!shader->binary)
       return VK_ERROR_OUT_OF_HOST_MEMORY;
 
@@ -270,7 +268,8 @@ tu_shader_compile(struct tu_device *dev,
    if (options->include_binning_pass &&
        shader->ir3_shader.type == MESA_SHADER_VERTEX) {
       shader->binning_binary = tu_compile_shader_variant(
-         &shader->ir3_shader, &options->key, true, &shader->variants[1]);
+         &shader->ir3_shader, &options->key, &shader->variants[0],
+         &shader->variants[1]);
       if (!shader->binning_binary)
          return VK_ERROR_OUT_OF_HOST_MEMORY;