zink: start using per-stage flags for new shaders, refcount shader modules
[mesa.git] / src / gallium / drivers / zink / zink_program.c
index ee735a810d28d0dd526790d5a7781248bde21283..f6b3acdc84d20cab17b5a68e6bb64f5c69d929c8 100644 (file)
@@ -45,6 +45,12 @@ debug_describe_zink_gfx_program(char *buf, const struct zink_gfx_program *ptr)
    sprintf(buf, "zink_gfx_program");
 }
 
+static void
+debug_describe_zink_shader_module(char *buf, const struct zink_shader_module *ptr)
+{
+   sprintf(buf, "zink_shader_module");
+}
+
 static VkDescriptorSetLayout
 create_desc_set_layout(VkDevice dev,
                        struct zink_shader *stages[ZINK_SHADER_COUNT],
@@ -107,15 +113,47 @@ create_pipeline_layout(VkDevice dev, VkDescriptorSetLayout dsl)
    return layout;
 }
 
+static void
+zink_destroy_shader_module(struct zink_screen *screen, struct zink_shader_module *zm)
+{
+   vkDestroyShaderModule(screen->dev, zm->shader, NULL);
+   free(zm);
+}
+
+static inline void
+zink_shader_module_reference(struct zink_screen *screen,
+                           struct zink_shader_module **dst,
+                           struct zink_shader_module *src)
+{
+   struct zink_shader_module *old_dst = dst ? *dst : NULL;
+
+   if (pipe_reference_described(old_dst ? &old_dst->reference : NULL, &src->reference,
+                                (debug_reference_descriptor)debug_describe_zink_shader_module))
+      zink_destroy_shader_module(screen, old_dst);
+   if (dst) *dst = src;
+}
+
 static void
 update_shader_modules(struct zink_context *ctx, struct zink_shader *stages[ZINK_SHADER_COUNT], struct zink_gfx_program *prog)
 {
+   struct zink_shader *dirty[ZINK_SHADER_COUNT] = {NULL};
+
+   unsigned dirty_shader_stages = ctx->dirty_shader_stages;
+   while (dirty_shader_stages) {
+      unsigned type = u_bit_scan(&dirty_shader_stages);
+      dirty[type] = stages[type];
+   }
    for (int i = 0; i < ZINK_SHADER_COUNT; ++i) {
-      if (stages[i]) {
-         prog->stages[i] = zink_shader_compile(zink_screen(ctx->base.screen), stages[i]);
-         prog->shaders[i] = stages[i];
-      }
+      if (dirty[i]) {
+         prog->stages[i] = CALLOC_STRUCT(zink_shader_module);
+         assert(prog->stages[i]);
+         pipe_reference_init(&prog->stages[i]->reference, 1);
+         prog->stages[i]->shader = zink_shader_compile(zink_screen(ctx->base.screen), stages[i]);
+      } else if (stages[i]) /* reuse existing shader module */
+         zink_shader_module_reference(zink_screen(ctx->base.screen), &prog->stages[i], ctx->curr_program->stages[i]);
+      prog->shaders[i] = stages[i];
    }
+   ctx->dirty_shader_stages = 0;
 }
 
 static uint32_t
@@ -204,7 +242,7 @@ zink_destroy_gfx_program(struct zink_screen *screen,
       if (prog->shaders[i])
          gfx_program_remove_shader(prog, prog->shaders[i]);
       if (prog->stages[i])
-         vkDestroyShaderModule(screen->dev, prog->stages[i], NULL);
+         zink_shader_module_reference(screen, &prog->stages[i], NULL);
    }
 
    /* unref all used render-passes */
@@ -323,7 +361,7 @@ bind_stage(struct zink_context *ctx, enum pipe_shader_type stage,
 {
    assert(stage < PIPE_SHADER_COMPUTE);
    ctx->gfx_stages[stage] = shader;
-   ctx->dirty_program = true;
+   ctx->dirty_shader_stages |= 1 << stage;
 }
 
 static void