draw: pass number of images to image soa create
[mesa.git] / src / gallium / auxiliary / draw / draw_llvm.c
index e7cf576cbd13cdb140af3a1f6338dacc03a6658c..41018e82dcf55078288a61d2dc1253fdc279be7a 100644 (file)
@@ -51,7 +51,7 @@
 #include "gallivm/lp_bld_type.h"
 #include "gallivm/lp_bld_pack.h"
 #include "gallivm/lp_bld_format.h"
-
+#include "gallivm/lp_bld_misc.h"
 #include "tgsi/tgsi_exec.h"
 #include "tgsi/tgsi_dump.h"
 
@@ -59,8 +59,8 @@
 #include "util/u_pointer.h"
 #include "util/u_string.h"
 #include "util/simple_list.h"
-
-
+#include "nir_serialize.h"
+#include "util/mesa-sha1.h"
 #define DEBUG_STORE 0
 
 
@@ -152,6 +152,8 @@ create_jit_texture_type(struct gallivm_state *gallivm, const char *struct_name)
    elem_types[DRAW_JIT_TEXTURE_WIDTH]  =
    elem_types[DRAW_JIT_TEXTURE_HEIGHT] =
    elem_types[DRAW_JIT_TEXTURE_DEPTH] =
+   elem_types[DRAW_JIT_TEXTURE_NUM_SAMPLES] =
+   elem_types[DRAW_JIT_TEXTURE_SAMPLE_STRIDE] =
    elem_types[DRAW_JIT_TEXTURE_FIRST_LEVEL] =
    elem_types[DRAW_JIT_TEXTURE_LAST_LEVEL] = int32_type;
    elem_types[DRAW_JIT_TEXTURE_BASE] =
@@ -192,6 +194,12 @@ create_jit_texture_type(struct gallivm_state *gallivm, const char *struct_name)
    LP_CHECK_MEMBER_OFFSET(struct draw_jit_texture, mip_offsets,
                           target, texture_type,
                           DRAW_JIT_TEXTURE_MIP_OFFSETS);
+   LP_CHECK_MEMBER_OFFSET(struct draw_jit_texture, num_samples,
+                          target, texture_type,
+                          DRAW_JIT_TEXTURE_NUM_SAMPLES);
+   LP_CHECK_MEMBER_OFFSET(struct draw_jit_texture, sample_stride,
+                          target, texture_type,
+                          DRAW_JIT_TEXTURE_SAMPLE_STRIDE);
 
    LP_CHECK_STRUCT_SIZE(struct draw_jit_texture, target, texture_type);
 
@@ -252,7 +260,9 @@ create_jit_image_type(struct gallivm_state *gallivm, const char *struct_name)
    elem_types[DRAW_JIT_IMAGE_HEIGHT] =
    elem_types[DRAW_JIT_IMAGE_DEPTH] =
    elem_types[DRAW_JIT_IMAGE_ROW_STRIDE] =
-   elem_types[DRAW_JIT_IMAGE_IMG_STRIDE] = int32_type;
+   elem_types[DRAW_JIT_IMAGE_IMG_STRIDE] =
+   elem_types[DRAW_JIT_IMAGE_NUM_SAMPLES] =
+   elem_types[DRAW_JIT_IMAGE_SAMPLE_STRIDE] = int32_type;
    elem_types[DRAW_JIT_IMAGE_BASE] =
       LLVMPointerType(LLVMInt8TypeInContext(gallivm->context), 0);
 
@@ -278,6 +288,12 @@ create_jit_image_type(struct gallivm_state *gallivm, const char *struct_name)
    LP_CHECK_MEMBER_OFFSET(struct draw_jit_image, img_stride,
                           target, image_type,
                           DRAW_JIT_IMAGE_IMG_STRIDE);
+   LP_CHECK_MEMBER_OFFSET(struct draw_jit_image, num_samples,
+                          target, image_type,
+                          DRAW_JIT_IMAGE_NUM_SAMPLES);
+   LP_CHECK_MEMBER_OFFSET(struct draw_jit_image, sample_stride,
+                          target, image_type,
+                          DRAW_JIT_IMAGE_SAMPLE_STRIDE);
 
    LP_CHECK_STRUCT_SIZE(struct draw_jit_image, target, image_type);
 
@@ -807,6 +823,30 @@ draw_llvm_destroy(struct draw_llvm *llvm)
    FREE(llvm);
 }
 
+static void /* Derive a shader-variant cache key: SHA-1 over the key bytes, the serialized NIR, and val_32bit. */
+draw_get_ir_cache_key(struct nir_shader *nir, /* NIR shader that gets serialized and hashed */
+                      const void *key, size_t key_size, /* raw key bytes and their length; hashed first */
+                      uint32_t val_32bit, /* extra 32-bit value mixed in last (callers in this patch pass num_inputs / num_outputs) */
+                      unsigned char ir_sha1_cache_key[20]) /* out: receives the finished SHA-1 digest */
+{
+   struct blob blob = { 0 };
+   unsigned ir_size; /* NOTE(review): blob.size is narrowed to unsigned here - confirm its type in util/blob.h */
+   void *ir_binary;
+
+   blob_init(&blob);
+   nir_serialize(&blob, nir, true); /* third argument is presumably the "strip" flag - TODO confirm in nir_serialize.h */
+   ir_binary = blob.data; /* aliases the blob's buffer; no copy is made */
+   ir_size = blob.size;
+
+   struct mesa_sha1 ctx;
+   _mesa_sha1_init(&ctx);
+   _mesa_sha1_update(&ctx, key, key_size); /* 1st: variant key */
+   _mesa_sha1_update(&ctx, ir_binary, ir_size); /* 2nd: serialized shader */
+   _mesa_sha1_update(&ctx, &val_32bit, 4); /* 3rd: the 4 bytes of val_32bit, in host byte order */
+   _mesa_sha1_final(&ctx, ir_sha1_cache_key);
+
+   blob_finish(&blob); /* called only after hashing because ir_binary points into the blob; presumably releases blob.data - verify */
+}
 
 /**
  * Create LLVM-generated code for a vertex shader.
@@ -821,7 +861,9 @@ draw_llvm_create_variant(struct draw_llvm *llvm,
       llvm_vertex_shader(llvm->draw->vs.vertex_shader);
    LLVMTypeRef vertex_header;
    char module_name[64];
-
+   unsigned char ir_sha1_cache_key[20];
+   struct lp_cached_code cached = { 0 };
+   bool needs_caching = false;
    variant = MALLOC(sizeof *variant +
                     shader->variant_key_size -
                     sizeof variant->key);
@@ -830,16 +872,28 @@ draw_llvm_create_variant(struct draw_llvm *llvm,
 
    variant->llvm = llvm;
    variant->shader = shader;
+   memcpy(&variant->key, key, shader->variant_key_size);
 
    snprintf(module_name, sizeof(module_name), "draw_llvm_vs_variant%u",
             variant->shader->variants_cached);
 
-   variant->gallivm = gallivm_create(module_name, llvm->context);
+   if (shader->base.state.ir.nir && llvm->draw->disk_cache_cookie) {
+      draw_get_ir_cache_key(shader->base.state.ir.nir,
+                            key,
+                            shader->variant_key_size,
+                            num_inputs,
+                            ir_sha1_cache_key);
+
+      llvm->draw->disk_cache_find_shader(llvm->draw->disk_cache_cookie,
+                                         &cached,
+                                         ir_sha1_cache_key);
+      if (!cached.data_size)
+         needs_caching = true;
+   }
+   variant->gallivm = gallivm_create(module_name, llvm->context, &cached);
 
    create_jit_types(variant);
 
-   memcpy(&variant->key, key, shader->variant_key_size);
-
    if (gallivm_debug & (GALLIVM_DEBUG_TGSI | GALLIVM_DEBUG_IR)) {
       if (llvm->draw->vs.vertex_shader->state.type == PIPE_SHADER_IR_TGSI)
          tgsi_dump(llvm->draw->vs.vertex_shader->state.tokens, 0);
@@ -859,6 +913,10 @@ draw_llvm_create_variant(struct draw_llvm *llvm,
    variant->jit_func = (draw_jit_vert_func)
          gallivm_jit_function(variant->gallivm, variant->function);
 
+   if (needs_caching)
+      llvm->draw->disk_cache_insert_shader(llvm->draw->disk_cache_cookie,
+                                           &cached,
+                                           ir_sha1_cache_key);
    gallivm_free_ir(variant->gallivm);
 
    variant->list_item_global.base = variant;
@@ -1123,8 +1181,8 @@ store_aos(struct gallivm_state *gallivm,
  * {
  *   return (x >> 16) |              // vertex_id
  *          ((x & 0x3fff) << 18) |   // clipmask
- *          ((x & 0x4000) << 3) |    // pad
- *          ((x & 0x8000) << 1);     // edgeflag
+ *          ((x & 0x4000) << 3) |    // edgeflag
+ *          ((x & 0x8000) << 1);     // pad
  * }
  */
 static LLVMValueRef
@@ -1142,11 +1200,11 @@ adjust_mask(struct gallivm_state *gallivm,
    clipmask  = LLVMBuildAnd(builder, mask, lp_build_const_int32(gallivm, 0x3fff), "");
    clipmask  = LLVMBuildShl(builder, clipmask, lp_build_const_int32(gallivm, 18), "");
    if (0) {
-      pad = LLVMBuildAnd(builder, mask, lp_build_const_int32(gallivm, 0x4000), "");
-      pad = LLVMBuildShl(builder, pad, lp_build_const_int32(gallivm, 3), "");
+      pad = LLVMBuildAnd(builder, mask, lp_build_const_int32(gallivm, 0x8000), "");
+      pad = LLVMBuildShl(builder, pad, lp_build_const_int32(gallivm, 1), "");
    }
-   edgeflag = LLVMBuildAnd(builder, mask, lp_build_const_int32(gallivm, 0x8000), "");
-   edgeflag = LLVMBuildShl(builder, edgeflag, lp_build_const_int32(gallivm, 1), "");
+   edgeflag = LLVMBuildAnd(builder, mask, lp_build_const_int32(gallivm, 0x4000), "");
+   edgeflag = LLVMBuildShl(builder, edgeflag, lp_build_const_int32(gallivm, 3), "");
 
    mask = LLVMBuildOr(builder, vertex_id, clipmask, "");
    if (0) {
@@ -1780,6 +1838,7 @@ draw_gs_llvm_end_primitive(const struct lp_build_gs_iface *gs_base,
       draw_gs_jit_prim_lengths(variant->gallivm, variant->context_ptr);
    unsigned i;
 
+   LLVMValueRef cond = LLVMBuildICmp(gallivm->builder, LLVMIntNE, mask_vec, lp_build_const_int_vec(gallivm, bld->type, 0), "");
    for (i = 0; i < bld->type.length; ++i) {
       LLVMValueRef ind = lp_build_const_int32(gallivm, i);
       LLVMValueRef prims_emitted =
@@ -1788,10 +1847,14 @@ draw_gs_llvm_end_primitive(const struct lp_build_gs_iface *gs_base,
       LLVMValueRef num_vertices =
          LLVMBuildExtractElement(builder, verts_per_prim_vec, ind, "");
 
+      LLVMValueRef this_cond = LLVMBuildExtractElement(gallivm->builder, cond, ind, "");
+      struct lp_build_if_state ifthen;
+      lp_build_if(&ifthen, gallivm, this_cond);
       store_ptr = LLVMBuildGEP(builder, prim_lengts_ptr, &prims_emitted, 1, "");
       store_ptr = LLVMBuildLoad(builder, store_ptr, "");
       store_ptr = LLVMBuildGEP(builder, store_ptr, &ind, 1, "");
       LLVMBuildStore(builder, num_vertices, store_ptr);
+      lp_build_endif(&ifthen);
    }
 }
 
@@ -1875,8 +1938,7 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *variant)
 
    memset(&system_values, 0, sizeof(system_values));
    memset(&outputs, 0, sizeof(outputs));
-   snprintf(func_name, sizeof(func_name), "draw_llvm_vs_variant%u",
-            variant->shader->variants_cached);
+   snprintf(func_name, sizeof(func_name), "draw_llvm_vs_variant");
 
    i = 0;
    arg_types[i++] = get_context_ptr_type(variant);       /* context */
@@ -1903,6 +1965,8 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *variant)
       if (LLVMGetTypeKind(arg_types[i]) == LLVMPointerTypeKind)
          lp_add_function_attr(variant_func, i + 1, LP_FUNC_ATTR_NOALIAS);
 
+   if (gallivm->cache && gallivm->cache->data_size)
+      return;
    context_ptr               = LLVMGetParam(variant_func, 0);
    io_ptr                    = LLVMGetParam(variant_func, 1);
    vbuffers_ptr              = LLVMGetParam(variant_func, 2);
@@ -1968,9 +2032,10 @@ draw_llvm_generate(struct draw_llvm *llvm, struct draw_llvm_variant *variant)
    fake_buf_ptr = LLVMBuildGEP(builder, fake_buf, &bld.zero, 1, "");
 
    /* code generated texture sampling */
-   sampler = draw_llvm_sampler_soa_create(draw_llvm_variant_key_samplers(key));
+   sampler = draw_llvm_sampler_soa_create(draw_llvm_variant_key_samplers(key), key->nr_samplers);
 
-   image = draw_llvm_image_soa_create(draw_llvm_variant_key_images(key));
+   image = draw_llvm_image_soa_create(draw_llvm_variant_key_images(key),
+                                      key->nr_images);
 
    step = lp_build_const_int32(gallivm, vector_length);
 
@@ -2404,6 +2469,8 @@ draw_llvm_set_mapped_texture(struct draw_context *draw,
                              unsigned sview_idx,
                              uint32_t width, uint32_t height, uint32_t depth,
                              uint32_t first_level, uint32_t last_level,
+                             uint32_t num_samples,
+                             uint32_t sample_stride,
                              const void *base_ptr,
                              uint32_t row_stride[PIPE_MAX_TEXTURE_LEVELS],
                              uint32_t img_stride[PIPE_MAX_TEXTURE_LEVELS],
@@ -2412,18 +2479,24 @@ draw_llvm_set_mapped_texture(struct draw_context *draw,
    unsigned j;
    struct draw_jit_texture *jit_tex;
 
-   assert(shader_stage == PIPE_SHADER_VERTEX ||
-          shader_stage == PIPE_SHADER_GEOMETRY);
-
-   if (shader_stage == PIPE_SHADER_VERTEX) {
+   switch (shader_stage) {
+   case PIPE_SHADER_VERTEX:
       assert(sview_idx < ARRAY_SIZE(draw->llvm->jit_context.textures));
-
       jit_tex = &draw->llvm->jit_context.textures[sview_idx];
-   } else if (shader_stage == PIPE_SHADER_GEOMETRY) {
+      break;
+   case PIPE_SHADER_GEOMETRY:
       assert(sview_idx < ARRAY_SIZE(draw->llvm->gs_jit_context.textures));
-
       jit_tex = &draw->llvm->gs_jit_context.textures[sview_idx];
-   } else {
+      break;
+   case PIPE_SHADER_TESS_CTRL:
+      assert(sview_idx < ARRAY_SIZE(draw->llvm->tcs_jit_context.textures));
+      jit_tex = &draw->llvm->tcs_jit_context.textures[sview_idx];
+      break;
+   case PIPE_SHADER_TESS_EVAL:
+      assert(sview_idx < ARRAY_SIZE(draw->llvm->tes_jit_context.textures));
+      jit_tex = &draw->llvm->tes_jit_context.textures[sview_idx];
+      break;
+   default:
       assert(0);
       return;
    }
@@ -2434,6 +2507,8 @@ draw_llvm_set_mapped_texture(struct draw_context *draw,
    jit_tex->first_level = first_level;
    jit_tex->last_level = last_level;
    jit_tex->base = base_ptr;
+   jit_tex->num_samples = num_samples;
+   jit_tex->sample_stride = sample_stride;
 
    for (j = first_level; j <= last_level; j++) {
       jit_tex->mip_offsets[j] = mip_offsets[j];
@@ -2449,22 +2524,30 @@ draw_llvm_set_mapped_image(struct draw_context *draw,
                            uint32_t width, uint32_t height, uint32_t depth,
                            const void *base_ptr,
                            uint32_t row_stride,
-                           uint32_t img_stride)
+                           uint32_t img_stride,
+                           uint32_t num_samples,
+                           uint32_t sample_stride)
 {
    struct draw_jit_image *jit_image;
 
-   assert(shader_stage == PIPE_SHADER_VERTEX ||
-          shader_stage == PIPE_SHADER_GEOMETRY);
-
-   if (shader_stage == PIPE_SHADER_VERTEX) {
+   switch (shader_stage) {
+   case PIPE_SHADER_VERTEX:
       assert(idx < ARRAY_SIZE(draw->llvm->jit_context.images));
-
       jit_image = &draw->llvm->jit_context.images[idx];
-   } else if (shader_stage == PIPE_SHADER_GEOMETRY) {
+      break;
+   case PIPE_SHADER_GEOMETRY:
       assert(idx < ARRAY_SIZE(draw->llvm->gs_jit_context.images));
-
       jit_image = &draw->llvm->gs_jit_context.images[idx];
-   } else {
+      break;
+   case PIPE_SHADER_TESS_CTRL:
+      assert(idx < ARRAY_SIZE(draw->llvm->tcs_jit_context.images));
+      jit_image = &draw->llvm->tcs_jit_context.images[idx];
+      break;
+   case PIPE_SHADER_TESS_EVAL:
+      assert(idx < ARRAY_SIZE(draw->llvm->tes_jit_context.images));
+      jit_image = &draw->llvm->tes_jit_context.images[idx];
+      break;
+   default:
       assert(0);
       return;
    }
@@ -2476,6 +2559,8 @@ draw_llvm_set_mapped_image(struct draw_context *draw,
 
    jit_image->row_stride = row_stride;
    jit_image->img_stride = img_stride;
+   jit_image->num_samples = num_samples;
+   jit_image->sample_stride = sample_stride;
 }
 
 
@@ -2485,7 +2570,8 @@ draw_llvm_set_sampler_state(struct draw_context *draw,
 {
    unsigned i;
 
-   if (shader_type == PIPE_SHADER_VERTEX) {
+   switch (shader_type) {
+   case PIPE_SHADER_VERTEX:
       for (i = 0; i < draw->num_samplers[PIPE_SHADER_VERTEX]; i++) {
          struct draw_jit_sampler *jit_sam = &draw->llvm->jit_context.samplers[i];
 
@@ -2498,7 +2584,8 @@ draw_llvm_set_sampler_state(struct draw_context *draw,
             COPY_4V(jit_sam->border_color, s->border_color.f);
          }
       }
-   } else if (shader_type == PIPE_SHADER_GEOMETRY) {
+      break;
+   case PIPE_SHADER_GEOMETRY:
       for (i = 0; i < draw->num_samplers[PIPE_SHADER_GEOMETRY]; i++) {
          struct draw_jit_sampler *jit_sam = &draw->llvm->gs_jit_context.samplers[i];
 
@@ -2511,6 +2598,38 @@ draw_llvm_set_sampler_state(struct draw_context *draw,
             COPY_4V(jit_sam->border_color, s->border_color.f);
          }
       }
+      break;
+   case PIPE_SHADER_TESS_CTRL:
+      for (i = 0; i < draw->num_samplers[PIPE_SHADER_TESS_CTRL]; i++) {
+         struct draw_jit_sampler *jit_sam = &draw->llvm->tcs_jit_context.samplers[i];
+
+         if (draw->samplers[PIPE_SHADER_TESS_CTRL][i]) {
+            const struct pipe_sampler_state *s
+               = draw->samplers[PIPE_SHADER_TESS_CTRL][i];
+            jit_sam->min_lod = s->min_lod;
+            jit_sam->max_lod = s->max_lod;
+            jit_sam->lod_bias = s->lod_bias;
+            COPY_4V(jit_sam->border_color, s->border_color.f);
+         }
+      }
+      break;
+   case PIPE_SHADER_TESS_EVAL:
+      for (i = 0; i < draw->num_samplers[PIPE_SHADER_TESS_EVAL]; i++) {
+         struct draw_jit_sampler *jit_sam = &draw->llvm->tes_jit_context.samplers[i];
+
+         if (draw->samplers[PIPE_SHADER_TESS_EVAL][i]) {
+            const struct pipe_sampler_state *s
+               = draw->samplers[PIPE_SHADER_TESS_EVAL][i];
+            jit_sam->min_lod = s->min_lod;
+            jit_sam->max_lod = s->max_lod;
+            jit_sam->lod_bias = s->lod_bias;
+            COPY_4V(jit_sam->border_color, s->border_color.f);
+         }
+      }
+      break;
+   default:
+      assert(0);
+      break;
    }
 }
 
@@ -2623,8 +2742,7 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    memset(&system_values, 0, sizeof(system_values));
    memset(&outputs, 0, sizeof(outputs));
 
-   snprintf(func_name, sizeof(func_name), "draw_llvm_gs_variant%u",
-            variant->shader->variants_cached);
+   snprintf(func_name, sizeof(func_name), "draw_llvm_gs_variant");
 
    assert(variant->vertex_header_ptr_type);
 
@@ -2649,6 +2767,8 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
       if (LLVMGetTypeKind(arg_types[i]) == LLVMPointerTypeKind)
          lp_add_function_attr(variant_func, i + 1, LP_FUNC_ATTR_NOALIAS);
 
+   if (gallivm->cache && gallivm->cache->data_size)
+      return;
    context_ptr               = LLVMGetParam(variant_func, 0);
    input_array               = LLVMGetParam(variant_func, 1);
    io_ptr                    = LLVMGetParam(variant_func, 2);
@@ -2702,8 +2822,9 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
       draw_gs_jit_context_num_ssbos(variant->gallivm, context_ptr);
 
    /* code generated texture sampling */
-   sampler = draw_llvm_sampler_soa_create(variant->key.samplers);
-   image = draw_llvm_image_soa_create(draw_gs_llvm_variant_key_images(&variant->key));
+   sampler = draw_llvm_sampler_soa_create(variant->key.samplers, variant->key.nr_samplers);
+   image = draw_llvm_image_soa_create(draw_gs_llvm_variant_key_images(&variant->key),
+                                      variant->key.nr_images);
    mask_val = generate_mask_value(variant, gs_type);
    lp_build_mask_begin(&mask, gallivm, gs_type, mask_val);
 
@@ -2756,7 +2877,6 @@ draw_gs_llvm_generate(struct draw_llvm *llvm,
    gallivm_verify_function(gallivm, variant_func);
 }
 
-
 struct draw_gs_llvm_variant *
 draw_gs_llvm_create_variant(struct draw_llvm *llvm,
                             unsigned num_outputs,
@@ -2767,6 +2887,9 @@ draw_gs_llvm_create_variant(struct draw_llvm *llvm,
       llvm_geometry_shader(llvm->draw->gs.geometry_shader);
    LLVMTypeRef vertex_header;
    char module_name[64];
+   unsigned char ir_sha1_cache_key[20];
+   struct lp_cached_code cached = { 0 };
+   bool needs_caching = false;
 
    variant = MALLOC(sizeof *variant +
                     shader->variant_key_size -
@@ -2780,11 +2903,24 @@ draw_gs_llvm_create_variant(struct draw_llvm *llvm,
    snprintf(module_name, sizeof(module_name), "draw_llvm_gs_variant%u",
             variant->shader->variants_cached);
 
-   variant->gallivm = gallivm_create(module_name, llvm->context);
+   memcpy(&variant->key, key, shader->variant_key_size);
 
-   create_gs_jit_types(variant);
+   if (shader->base.state.ir.nir && llvm->draw->disk_cache_cookie) {
+      draw_get_ir_cache_key(shader->base.state.ir.nir,
+                            key,
+                            shader->variant_key_size,
+                            num_outputs,
+                            ir_sha1_cache_key);
+
+      llvm->draw->disk_cache_find_shader(llvm->draw->disk_cache_cookie,
+                                         &cached,
+                                         ir_sha1_cache_key);
+      if (!cached.data_size)
+         needs_caching = true;
+   }
+   variant->gallivm = gallivm_create(module_name, llvm->context, &cached);
 
-   memcpy(&variant->key, key, shader->variant_key_size);
+   create_gs_jit_types(variant);
 
    vertex_header = create_jit_vertex_header(variant->gallivm, num_outputs);
 
@@ -2797,6 +2933,10 @@ draw_gs_llvm_create_variant(struct draw_llvm *llvm,
    variant->jit_func = (draw_gs_jit_func)
          gallivm_jit_function(variant->gallivm, variant->function);
 
+   if (needs_caching)
+      llvm->draw->disk_cache_insert_shader(llvm->draw->disk_cache_cookie,
+                                           &cached,
+                                           ir_sha1_cache_key);
    gallivm_free_ir(variant->gallivm);
 
    variant->list_item_global.base = variant;
@@ -3160,11 +3300,9 @@ draw_tcs_llvm_generate(struct draw_llvm *llvm,
 
    memset(&system_values, 0, sizeof(system_values));
 
-   snprintf(func_name, sizeof(func_name), "draw_llvm_tcs_variant%u",
-            variant->shader->variants_cached);
+   snprintf(func_name, sizeof(func_name), "draw_llvm_tcs_variant");
 
-   snprintf(func_name_coro, sizeof(func_name_coro), "draw_llvm_tcs_coro_variant%u",
-            variant->shader->variants_cached);
+   snprintf(func_name_coro, sizeof(func_name_coro), "draw_llvm_tcs_coro_variant");
 
    arg_types[0] = get_tcs_context_ptr_type(variant);    /* context */
    arg_types[1] = variant->input_array_type;           /* input */
@@ -3193,6 +3331,8 @@ draw_tcs_llvm_generate(struct draw_llvm *llvm,
       }
    }
 
+   if (gallivm->cache && gallivm->cache->data_size)
+      return;
    context_ptr               = LLVMGetParam(variant_func, 0);
    input_array               = LLVMGetParam(variant_func, 1);
    output_array              = LLVMGetParam(variant_func, 2);
@@ -3291,8 +3431,9 @@ draw_tcs_llvm_generate(struct draw_llvm *llvm,
    ssbos_ptr = draw_tcs_jit_context_ssbos(variant->gallivm, context_ptr);
    num_ssbos_ptr =
       draw_tcs_jit_context_num_ssbos(variant->gallivm, context_ptr);
-   sampler = draw_llvm_sampler_soa_create(variant->key.samplers);
-   image = draw_llvm_image_soa_create(draw_tcs_llvm_variant_key_images(&variant->key));
+   sampler = draw_llvm_sampler_soa_create(variant->key.samplers, variant->key.nr_samplers);
+   image = draw_llvm_image_soa_create(draw_tcs_llvm_variant_key_images(&variant->key),
+                                      variant->key.nr_images);
 
    LLVMValueRef counter = LLVMGetParam(variant_coro, 5);
    LLVMValueRef invocvec = LLVMGetUndef(LLVMVectorType(int32_type, vector_length));
@@ -3375,6 +3516,9 @@ draw_tcs_llvm_create_variant(struct draw_llvm *llvm,
    struct draw_tcs_llvm_variant *variant;
    struct llvm_tess_ctrl_shader *shader = llvm_tess_ctrl_shader(llvm->draw->tcs.tess_ctrl_shader);
    char module_name[64];
+   unsigned char ir_sha1_cache_key[20];
+   struct lp_cached_code cached = { 0 };
+   bool needs_caching = false;
 
    variant = MALLOC(sizeof *variant +
                     shader->variant_key_size - sizeof variant->key);
@@ -3387,24 +3531,44 @@ draw_tcs_llvm_create_variant(struct draw_llvm *llvm,
    snprintf(module_name, sizeof(module_name), "draw_llvm_tcs_variant%u",
             variant->shader->variants_cached);
 
-   variant->gallivm = gallivm_create(module_name, llvm->context);
+   memcpy(&variant->key, key, shader->variant_key_size);
 
-   create_tcs_jit_types(variant);
+   if (shader->base.state.ir.nir && llvm->draw->disk_cache_cookie) {
+      draw_get_ir_cache_key(shader->base.state.ir.nir,
+                            key,
+                            shader->variant_key_size,
+                            num_outputs,
+                            ir_sha1_cache_key);
+
+      llvm->draw->disk_cache_find_shader(llvm->draw->disk_cache_cookie,
+                                         &cached,
+                                         ir_sha1_cache_key);
+      if (!cached.data_size)
+         needs_caching = true;
+   }
 
-   memcpy(&variant->key, key, shader->variant_key_size);
+   variant->gallivm = gallivm_create(module_name, llvm->context, &cached);
+
+   create_tcs_jit_types(variant);
 
    if (gallivm_debug & (GALLIVM_DEBUG_TGSI | GALLIVM_DEBUG_IR)) {
       nir_print_shader(llvm->draw->tcs.tess_ctrl_shader->state.ir.nir, stderr);
       draw_tcs_llvm_dump_variant_key(&variant->key);
    }
 
+   lp_build_coro_declare_malloc_hooks(variant->gallivm);
    draw_tcs_llvm_generate(llvm, variant);
 
    gallivm_compile_module(variant->gallivm);
 
+   lp_build_coro_add_malloc_hooks(variant->gallivm);
    variant->jit_func = (draw_tcs_jit_func)
       gallivm_jit_function(variant->gallivm, variant->function);
 
+   if (needs_caching)
+      llvm->draw->disk_cache_insert_shader(llvm->draw->disk_cache_cookie,
+                                           &cached,
+                                           ir_sha1_cache_key);
    gallivm_free_ir(variant->gallivm);
 
    variant->list_item_global.base = variant;
@@ -3665,12 +3829,12 @@ draw_tes_llvm_generate(struct draw_llvm *llvm,
    LLVMContextRef context = gallivm->context;
    LLVMTypeRef int32_type = LLVMInt32TypeInContext(context);
    LLVMTypeRef flt_type = LLVMFloatTypeInContext(context);
-   LLVMTypeRef arg_types[9];
+   LLVMTypeRef arg_types[10];
    LLVMTypeRef func_type;
    LLVMValueRef variant_func;
    LLVMValueRef context_ptr;
    LLVMValueRef tess_coord[2], io_ptr, input_array, num_tess_coord;
-   LLVMValueRef tess_inner, tess_outer, prim_id;
+   LLVMValueRef tess_inner, tess_outer, prim_id, patch_vertices_in;
    LLVMBasicBlockRef block;
    LLVMBuilderRef builder;
    LLVMValueRef mask_val;
@@ -3692,8 +3856,7 @@ draw_tes_llvm_generate(struct draw_llvm *llvm,
    memset(&system_values, 0, sizeof(system_values));
    memset(&outputs, 0, sizeof(outputs));
 
-   snprintf(func_name, sizeof(func_name), "draw_llvm_tes_variant%u",
-            variant->shader->variants_cached);
+   snprintf(func_name, sizeof(func_name), "draw_llvm_tes_variant");
 
    arg_types[0] = get_tes_context_ptr_type(variant);    /* context */
    arg_types[1] = variant->input_array_type;           /* input */
@@ -3704,6 +3867,7 @@ draw_tes_llvm_generate(struct draw_llvm *llvm,
    arg_types[6] = LLVMPointerType(flt_type, 0);
    arg_types[7] = LLVMPointerType(LLVMArrayType(flt_type, 4), 0);
    arg_types[8] = LLVMPointerType(LLVMArrayType(flt_type, 2), 0);
+   arg_types[9] = int32_type;
 
    func_type = LLVMFunctionType(int32_type, arg_types, ARRAY_SIZE(arg_types), 0);
    variant_func = LLVMAddFunction(gallivm->module, func_name, func_type);
@@ -3715,6 +3879,8 @@ draw_tes_llvm_generate(struct draw_llvm *llvm,
       if (LLVMGetTypeKind(arg_types[i]) == LLVMPointerTypeKind)
          lp_add_function_attr(variant_func, i + 1, LP_FUNC_ATTR_NOALIAS);
 
+   if (gallivm->cache && gallivm->cache->data_size)
+      return;
    context_ptr               = LLVMGetParam(variant_func, 0);
    input_array               = LLVMGetParam(variant_func, 1);
    io_ptr                    = LLVMGetParam(variant_func, 2);
@@ -3724,6 +3890,7 @@ draw_tes_llvm_generate(struct draw_llvm *llvm,
    tess_coord[1]             = LLVMGetParam(variant_func, 6);
    tess_outer                = LLVMGetParam(variant_func, 7);
    tess_inner                = LLVMGetParam(variant_func, 8);
+   patch_vertices_in         = LLVMGetParam(variant_func, 9);
 
    lp_build_name(context_ptr, "context");
    lp_build_name(input_array, "input");
@@ -3734,6 +3901,7 @@ draw_tes_llvm_generate(struct draw_llvm *llvm,
    lp_build_name(tess_coord[1], "tess_coord[1]");
    lp_build_name(tess_outer, "tess_outer");
    lp_build_name(tess_inner, "tess_inner");
+   lp_build_name(patch_vertices_in, "patch_vertices_in");
 
    tes_iface.base.fetch_vertex_input = draw_tes_llvm_fetch_vertex_input;
    tes_iface.base.fetch_patch_input = draw_tes_llvm_fetch_patch_input;
@@ -3761,14 +3929,17 @@ draw_tes_llvm_generate(struct draw_llvm *llvm,
    ssbos_ptr = draw_tes_jit_context_ssbos(variant->gallivm, context_ptr);
    num_ssbos_ptr =
       draw_tes_jit_context_num_ssbos(variant->gallivm, context_ptr);
-   sampler = draw_llvm_sampler_soa_create(variant->key.samplers);
-   image = draw_llvm_image_soa_create(draw_tes_llvm_variant_key_images(&variant->key));
+   sampler = draw_llvm_sampler_soa_create(variant->key.samplers, variant->key.nr_samplers);
+   image = draw_llvm_image_soa_create(draw_tes_llvm_variant_key_images(&variant->key),
+                                      variant->key.nr_images);
    step = lp_build_const_int32(gallivm, vector_length);
 
    system_values.tess_outer = LLVMBuildLoad(builder, tess_outer, "");
    system_values.tess_inner = LLVMBuildLoad(builder, tess_inner, "");
 
    system_values.prim_id = lp_build_broadcast_scalar(&bldvec, prim_id);
+
+   system_values.vertices_in = lp_build_broadcast_scalar(&bldvec, patch_vertices_in);
    struct lp_build_loop_state lp_loop;
    lp_build_loop_begin(&lp_loop, gallivm, bld.zero);
    {
@@ -3844,6 +4015,9 @@ draw_tes_llvm_create_variant(struct draw_llvm *llvm,
    struct llvm_tess_eval_shader *shader = llvm_tess_eval_shader(llvm->draw->tes.tess_eval_shader);
    LLVMTypeRef vertex_header;
    char module_name[64];
+   unsigned char ir_sha1_cache_key[20];
+   struct lp_cached_code cached = { 0 };
+   bool needs_caching = false;
 
    variant = MALLOC(sizeof *variant +
                     shader->variant_key_size - sizeof variant->key);
@@ -3856,12 +4030,24 @@ draw_tes_llvm_create_variant(struct draw_llvm *llvm,
    snprintf(module_name, sizeof(module_name), "draw_llvm_tes_variant%u",
             variant->shader->variants_cached);
 
-   variant->gallivm = gallivm_create(module_name, llvm->context);
+   memcpy(&variant->key, key, shader->variant_key_size);
+   if (shader->base.state.ir.nir && llvm->draw->disk_cache_cookie) {
+      draw_get_ir_cache_key(shader->base.state.ir.nir,
+                            key,
+                            shader->variant_key_size,
+                            num_outputs,
+                            ir_sha1_cache_key);
+
+      llvm->draw->disk_cache_find_shader(llvm->draw->disk_cache_cookie,
+                                         &cached,
+                                         ir_sha1_cache_key);
+      if (!cached.data_size)
+         needs_caching = true;
+   }
+   variant->gallivm = gallivm_create(module_name, llvm->context, &cached);
 
    create_tes_jit_types(variant);
 
-   memcpy(&variant->key, key, shader->variant_key_size);
-
    vertex_header = create_jit_vertex_header(variant->gallivm, num_outputs);
 
    variant->vertex_header_ptr_type = LLVMPointerType(vertex_header, 0);
@@ -3878,6 +4064,10 @@ draw_tes_llvm_create_variant(struct draw_llvm *llvm,
    variant->jit_func = (draw_tes_jit_func)
       gallivm_jit_function(variant->gallivm, variant->function);
 
+   if (needs_caching)
+      llvm->draw->disk_cache_insert_shader(llvm->draw->disk_cache_cookie,
+                                           &cached,
+                                           ir_sha1_cache_key);
    gallivm_free_ir(variant->gallivm);
 
    variant->list_item_global.base = variant;