gallium/ttn: Use variable create/add helpers
[mesa.git] / src / gallium / auxiliary / nir / tgsi_to_nir.c
index 20d6c0bfb29a3e49f7e51b44944a1e56f57f1866..fb6bbb5026cefa478566c7163b37cce033bb23c8 100644 (file)
  * IN THE SOFTWARE.
  */
 
+#include "util/blob.h"
+#include "util/disk_cache.h"
+#include "util/u_memory.h"
 #include "util/ralloc.h"
 #include "pipe/p_screen.h"
 
 #include "compiler/nir/nir.h"
 #include "compiler/nir/nir_control_flow.h"
 #include "compiler/nir/nir_builder.h"
-#include "compiler/glsl/gl_nir.h"
-#include "compiler/glsl/list.h"
+#include "compiler/nir/nir_serialize.h"
 #include "compiler/shader_enums.h"
 
 #include "tgsi_to_nir.h"
@@ -74,6 +76,10 @@ struct ttn_compile {
    nir_variable *images[PIPE_MAX_SHADER_IMAGES];
    nir_variable *ssbo[PIPE_MAX_SHADER_BUFFERS];
 
+   unsigned num_samplers;
+   unsigned num_images;
+   unsigned num_msaa_images;
+
    nir_variable *input_var_face;
    nir_variable *input_var_position;
    nir_variable *input_var_point;
@@ -102,7 +108,6 @@ struct ttn_compile {
    /* How many TGSI_FILE_IMMEDIATE vec4s have been parsed so far. */
    unsigned next_imm;
 
-   bool cap_scalar;
    bool cap_face_is_sysval;
    bool cap_position_is_sysval;
    bool cap_point_is_sysval;
@@ -220,7 +225,7 @@ ttn_translate_interp_mode(unsigned tgsi_interp)
    case TGSI_INTERPOLATE_PERSPECTIVE:
       return INTERP_MODE_SMOOTH;
    case TGSI_INTERPOLATE_COLOR:
-      return INTERP_MODE_SMOOTH;
+      return INTERP_MODE_NONE;
    default:
       unreachable("bad TGSI interpolation mode");
    }
@@ -238,13 +243,11 @@ ttn_emit_declaration(struct ttn_compile *c)
    if (file == TGSI_FILE_TEMPORARY) {
       if (decl->Declaration.Array) {
          /* for arrays, we create variables instead of registers: */
-         nir_variable *var = rzalloc(b->shader, nir_variable);
-
-         var->type = glsl_array_type(glsl_vec4_type(), array_size, 0);
-         var->data.mode = nir_var_shader_temp;
-         var->name = ralloc_asprintf(var, "arr_%d", decl->Array.ArrayID);
-
-         exec_list_push_tail(&b->shader->globals, &var->node);
+         nir_variable *var =
+            nir_variable_create(b->shader, nir_var_shader_temp,
+                                glsl_array_type(glsl_vec4_type(), array_size, 0),
+                                ralloc_asprintf(b->shader, "arr_%d",
+                                                decl->Array.ArrayID));
 
          for (i = 0; i < array_size; i++) {
             /* point all the matching slots to the same var,
@@ -375,7 +378,6 @@ ttn_emit_declaration(struct ttn_compile *c)
             var->data.interpolation =
                ttn_translate_interp_mode(decl->Interp.Interpolate);
 
-            exec_list_push_tail(&b->shader->inputs, &var->node);
             c->inputs[idx] = var;
 
             for (int i = 0; i < array_size; i++)
@@ -439,6 +441,10 @@ ttn_emit_declaration(struct ttn_compile *c)
             } else {
                var->data.location =
                   tgsi_varying_semantic_to_slot(semantic_name, semantic_index);
+               if (var->data.location == VARYING_SLOT_FOGC ||
+                   var->data.location == VARYING_SLOT_PSIZ) {
+                  var->type = glsl_float_type();
+               }
             }
 
             if (is_array) {
@@ -452,7 +458,6 @@ ttn_emit_declaration(struct ttn_compile *c)
                c->output_regs[idx].reg = reg;
             }
 
-            exec_list_push_tail(&b->shader->outputs, &var->node);
             c->outputs[idx] = var;
 
             for (int i = 0; i < array_size; i++)
@@ -463,14 +468,14 @@ ttn_emit_declaration(struct ttn_compile *c)
             var->data.mode = nir_var_uniform;
             var->name = ralloc_asprintf(var, "uniform_%d", idx);
             var->data.location = idx;
-
-            exec_list_push_tail(&b->shader->uniforms, &var->node);
             break;
          default:
             unreachable("bad declaration file");
             return;
          }
 
+         nir_shader_add_variable(b->shader, var);
+
          if (is_array)
             break;
       }
@@ -641,6 +646,10 @@ ttn_src_for_file_and_index(struct ttn_compile *c, unsigned file, unsigned index,
          op = nir_intrinsic_load_work_group_id;
          load = nir_load_work_group_id(b);
          break;
+      case TGSI_SEMANTIC_BLOCK_SIZE:
+         op = nir_intrinsic_load_local_group_size;
+         load = nir_load_local_group_size(b);
+         break;
       case TGSI_SEMANTIC_CS_USER_DATA_AMD:
          op = nir_intrinsic_load_user_data_amd;
          load = nir_load_user_data_amd(b);
@@ -657,7 +666,9 @@ ttn_src_for_file_and_index(struct ttn_compile *c, unsigned file, unsigned index,
          unreachable("bad system value");
       }
 
-      if (load->num_components == 3)
+      if (load->num_components == 2)
+         load = nir_swizzle(b, load, SWIZ(X, Y, Y, Y), 4);
+      else if (load->num_components == 3)
          load = nir_swizzle(b, load, SWIZ(X, Y, Z, Z), 4);
 
       src = nir_src_for_ssa(load);
@@ -729,6 +740,7 @@ ttn_src_for_file_and_index(struct ttn_compile *c, unsigned file, unsigned index,
          }
          /* UBO offsets are in bytes, but TGSI gives them to us in vec4's */
          offset = nir_ishl(b, offset, nir_imm_int(b, 4));
+         nir_intrinsic_set_align(load, 16, 0);
       } else {
          nir_intrinsic_set_base(load, index);
          if (indirect) {
@@ -1099,6 +1111,14 @@ ttn_ucmp(nir_builder *b, nir_op op, nir_alu_dest dest, nir_ssa_def **src)
                                     src[1], src[2]));
 }
 
+static void
+ttn_barrier(nir_builder *b)
+{
+   nir_intrinsic_instr *barrier =
+      nir_intrinsic_instr_create(b->shader, nir_intrinsic_control_barrier);
+   nir_builder_instr_insert(b, &barrier->instr);
+}
+
 static void
 ttn_kill(nir_builder *b, nir_op op, nir_alu_dest dest, nir_ssa_def **src)
 {
@@ -1111,7 +1131,11 @@ ttn_kill(nir_builder *b, nir_op op, nir_alu_dest dest, nir_ssa_def **src)
 static void
 ttn_kill_if(nir_builder *b, nir_op op, nir_alu_dest dest, nir_ssa_def **src)
 {
+   /* flt must be exact, because NaN shouldn't discard. (apps rely on this) */
+   b->exact = true;
    nir_ssa_def *cmp = nir_bany(b, nir_flt(b, src[0], nir_imm_float(b, 0.0)));
+   b->exact = false;
+
    nir_intrinsic_instr *discard =
       nir_intrinsic_instr_create(b->shader, nir_intrinsic_discard_if);
    discard->src[0] = nir_src_for_ssa(cmp);
@@ -1307,7 +1331,8 @@ get_sampler_var(struct ttn_compile *c, int binding,
                 enum glsl_sampler_dim dim,
                 bool is_shadow,
                 bool is_array,
-                enum glsl_base_type base_type)
+                enum glsl_base_type base_type,
+                nir_texop op)
 {
    nir_variable *var = c->samplers[binding];
    if (!var) {
@@ -1317,7 +1342,17 @@ get_sampler_var(struct ttn_compile *c, int binding,
                                 "sampler");
       var->data.binding = binding;
       var->data.explicit_binding = true;
+
       c->samplers[binding] = var;
+      c->num_samplers = MAX2(c->num_samplers, binding + 1);
+
+      /* Record textures used */
+      unsigned mask = 1 << binding;
+      c->build.shader->info.textures_used |= mask;
+      if (op == nir_texop_txf ||
+          op == nir_texop_txf_ms ||
+          op == nir_texop_txf_ms_mcs)
+         c->build.shader->info.textures_used_by_txf |= mask;
    }
 
    return var;
@@ -1329,7 +1364,7 @@ get_image_var(struct ttn_compile *c, int binding,
               bool is_array,
               enum glsl_base_type base_type,
               enum gl_access_qualifier access,
-              GLenum format)
+              enum pipe_format format)
 {
    nir_variable *var = c->images[binding];
 
@@ -1339,9 +1374,13 @@ get_image_var(struct ttn_compile *c, int binding,
       var = nir_variable_create(c->build.shader, nir_var_uniform, type, "image");
       var->data.binding = binding;
       var->data.explicit_binding = true;
-      var->data.image.access = access;
+      var->data.access = access;
       var->data.image.format = format;
+
       c->images[binding] = var;
+      c->num_images = MAX2(c->num_images, binding + 1);
+      if (dim == GLSL_SAMPLER_DIM_MS)
+         c->num_msaa_images = c->num_images;
    }
 
    return var;
@@ -1459,25 +1498,8 @@ ttn_tex(struct ttn_compile *c, nir_alu_dest dest, nir_ssa_def **src)
    get_texture_info(tgsi_inst->Texture.Texture,
                     &instr->sampler_dim, &instr->is_shadow, &instr->is_array);
 
-   switch (instr->sampler_dim) {
-   case GLSL_SAMPLER_DIM_1D:
-   case GLSL_SAMPLER_DIM_BUF:
-      instr->coord_components = 1;
-      break;
-   case GLSL_SAMPLER_DIM_2D:
-   case GLSL_SAMPLER_DIM_RECT:
-   case GLSL_SAMPLER_DIM_EXTERNAL:
-   case GLSL_SAMPLER_DIM_MS:
-      instr->coord_components = 2;
-      break;
-   case GLSL_SAMPLER_DIM_3D:
-   case GLSL_SAMPLER_DIM_CUBE:
-      instr->coord_components = 3;
-      break;
-   case GLSL_SAMPLER_DIM_SUBPASS:
-   case GLSL_SAMPLER_DIM_SUBPASS_MS:
-      unreachable("invalid sampler_dim");
-   }
+   instr->coord_components =
+      glsl_get_sampler_dim_coordinate_components(instr->sampler_dim);
 
    if (instr->is_array)
       instr->coord_components++;
@@ -1502,7 +1524,8 @@ ttn_tex(struct ttn_compile *c, nir_alu_dest dest, nir_ssa_def **src)
       get_sampler_var(c, sview, instr->sampler_dim,
                       instr->is_shadow,
                       instr->is_array,
-                      base_type_for_alu_type(instr->dest_type));
+                      base_type_for_alu_type(instr->dest_type),
+                      op);
 
    nir_deref_instr *deref = nir_build_deref_var(b, var);
 
@@ -1666,7 +1689,8 @@ ttn_txq(struct ttn_compile *c, nir_alu_dest dest, nir_ssa_def **src)
       get_sampler_var(c, tex_index, txs->sampler_dim,
                       txs->is_shadow,
                       txs->is_array,
-                      base_type_for_alu_type(txs->dest_type));
+                      base_type_for_alu_type(txs->dest_type),
+                      nir_texop_txs);
 
    nir_deref_instr *deref = nir_build_deref_var(b, var);
 
@@ -1723,99 +1747,6 @@ get_mem_qualifier(struct tgsi_full_instruction *tgsi_inst)
    return access;
 }
 
-static GLenum
-get_image_format(struct tgsi_full_instruction *tgsi_inst)
-{
-   switch (tgsi_inst->Memory.Format) {
-   case PIPE_FORMAT_R8_UNORM:
-      return GL_R8;
-   case PIPE_FORMAT_R8G8_UNORM:
-      return GL_RG8;
-   case PIPE_FORMAT_R8G8B8A8_UNORM:
-      return GL_RGBA8;
-   case PIPE_FORMAT_R16_UNORM:
-      return GL_R16;
-   case PIPE_FORMAT_R16G16_UNORM:
-      return GL_RG16;
-   case PIPE_FORMAT_R16G16B16A16_UNORM:
-      return GL_RGBA16;
-
-   case PIPE_FORMAT_R8_SNORM:
-      return GL_R8_SNORM;
-   case PIPE_FORMAT_R8G8_SNORM:
-      return GL_RG8_SNORM;
-   case PIPE_FORMAT_R8G8B8A8_SNORM:
-      return GL_RGBA8_SNORM;
-   case PIPE_FORMAT_R16_SNORM:
-      return GL_R16_SNORM;
-   case PIPE_FORMAT_R16G16_SNORM:
-      return GL_RG16_SNORM;
-   case PIPE_FORMAT_R16G16B16A16_SNORM:
-      return GL_RGBA16_SNORM;
-
-   case PIPE_FORMAT_R8_UINT:
-      return GL_R8UI;
-   case PIPE_FORMAT_R8G8_UINT:
-      return GL_RG8UI;
-   case PIPE_FORMAT_R8G8B8A8_UINT:
-      return GL_RGBA8UI;
-   case PIPE_FORMAT_R16_UINT:
-      return GL_R16UI;
-   case PIPE_FORMAT_R16G16_UINT:
-      return GL_RG16UI;
-   case PIPE_FORMAT_R16G16B16A16_UINT:
-      return GL_RGBA16UI;
-   case PIPE_FORMAT_R32_UINT:
-      return GL_R32UI;
-   case PIPE_FORMAT_R32G32_UINT:
-      return GL_RG32UI;
-   case PIPE_FORMAT_R32G32B32A32_UINT:
-      return GL_RGBA32UI;
-
-   case PIPE_FORMAT_R8_SINT:
-      return GL_R8I;
-   case PIPE_FORMAT_R8G8_SINT:
-      return GL_RG8I;
-   case PIPE_FORMAT_R8G8B8A8_SINT:
-      return GL_RGBA8I;
-   case PIPE_FORMAT_R16_SINT:
-      return GL_R16I;
-   case PIPE_FORMAT_R16G16_SINT:
-      return GL_RG16I;
-   case PIPE_FORMAT_R16G16B16A16_SINT:
-      return GL_RGBA16I;
-   case PIPE_FORMAT_R32_SINT:
-      return GL_R32I;
-   case PIPE_FORMAT_R32G32_SINT:
-      return GL_RG32I;
-   case PIPE_FORMAT_R32G32B32A32_SINT:
-      return GL_RGBA32I;
-
-   case PIPE_FORMAT_R16_FLOAT:
-      return GL_R16F;
-   case PIPE_FORMAT_R16G16_FLOAT:
-      return GL_RG16F;
-   case PIPE_FORMAT_R16G16B16A16_FLOAT:
-      return GL_RGBA16F;
-   case PIPE_FORMAT_R32_FLOAT:
-      return GL_R32F;
-   case PIPE_FORMAT_R32G32_FLOAT:
-      return GL_RG32F;
-   case PIPE_FORMAT_R32G32B32A32_FLOAT:
-      return GL_RGBA32F;
-
-   case PIPE_FORMAT_R11G11B10_FLOAT:
-      return GL_R11F_G11F_B10F;
-   case PIPE_FORMAT_R10G10B10A2_UINT:
-      return GL_RGB10_A2UI;
-   case PIPE_FORMAT_R10G10B10A2_UNORM:
-      return GL_RGB10_A2;
-
-   default:
-      unreachable("unhandled image format");
-   }
-}
-
 static void
 ttn_mem(struct ttn_compile *c, nir_alu_dest dest, nir_ssa_def **src)
 {
@@ -1891,15 +1822,15 @@ ttn_mem(struct ttn_compile *c, nir_alu_dest dest, nir_ssa_def **src)
 
       enum glsl_base_type base_type = get_image_base_type(tgsi_inst);
       enum gl_access_qualifier access = get_mem_qualifier(tgsi_inst);
-      GLenum format = get_image_format(tgsi_inst);
 
       nir_variable *image =
          get_image_var(c, resource_index,
-                       dim, is_array, base_type, access, format);
+                       dim, is_array, base_type, access,
+                       tgsi_inst->Memory.Format);
       nir_deref_instr *image_deref = nir_build_deref_var(b, image);
       const struct glsl_type *type = image_deref->type;
 
-      nir_intrinsic_set_access(instr, image_deref->var->data.image.access);
+      nir_intrinsic_set_access(instr, image_deref->var->data.access);
 
       instr->src[0] = nir_src_for_ssa(&image_deref->dest.ssa);
       instr->src[1] = nir_src_for_ssa(src[addr_src_index]);
@@ -1911,19 +1842,26 @@ ttn_mem(struct ttn_compile *c, nir_alu_dest dest, nir_ssa_def **src)
          instr->src[2] = nir_src_for_ssa(nir_ssa_undef(b, 1, 32));
       }
 
+      if (tgsi_inst->Instruction.Opcode == TGSI_OPCODE_LOAD) {
+         instr->src[3] = nir_src_for_ssa(nir_imm_int(b, 0)); /* LOD */
+      }
+
+      unsigned num_components = util_last_bit(tgsi_inst->Dst[0].Register.WriteMask);
+
       if (tgsi_inst->Instruction.Opcode == TGSI_OPCODE_STORE) {
-         instr->src[3] = nir_src_for_ssa(nir_swizzle(b, src[1], SWIZ(X, Y, Z, W), 4));
+         instr->src[3] = nir_src_for_ssa(nir_swizzle(b, src[1], SWIZ(X, Y, Z, W),
+                                                     num_components));
+         instr->src[4] = nir_src_for_ssa(nir_imm_int(b, 0)); /* LOD */
       }
 
-      instr->num_components = 4;
+      instr->num_components = num_components;
    } else {
       unreachable("unexpected file");
    }
 
 
    if (tgsi_inst->Instruction.Opcode == TGSI_OPCODE_LOAD) {
-      nir_ssa_dest_init(&instr->instr, &instr->dest,
-                        util_last_bit(tgsi_inst->Dst[0].Register.WriteMask),
+      nir_ssa_dest_init(&instr->instr, &instr->dest, instr->num_components,
                         32, NULL);
       nir_builder_instr_insert(b, &instr->instr);
       ttn_move_dest(b, dest, &instr->dest.ssa);
@@ -2298,6 +2236,10 @@ ttn_emit_instruction(struct ttn_compile *c)
       ttn_endloop(c);
       break;
 
+   case TGSI_OPCODE_BARRIER:
+      ttn_barrier(b);
+      break;
+
    default:
       if (op_trans[tgsi_op] != 0 || tgsi_op == TGSI_OPCODE_MOV) {
          ttn_alu(b, op_trans[tgsi_op], dest, dst_bitsize, src);
@@ -2360,6 +2302,12 @@ ttn_add_output_stores(struct ttn_compile *c)
             store_value = nir_channel(b, store_value, 2);
          else if (var->data.location == FRAG_RESULT_STENCIL)
             store_value = nir_channel(b, store_value, 1);
+      } else {
+         /* FOGC and PSIZ are scalar values */
+         if (var->data.location == VARYING_SLOT_FOGC ||
+             var->data.location == VARYING_SLOT_PSIZ) {
+            store_value = nir_channel(b, store_value, 0);
+         }
       }
 
       nir_store_deref(b, nir_build_deref_var(b, var), store_value,
@@ -2374,7 +2322,7 @@ static void
 ttn_parse_tgsi(struct ttn_compile *c, const void *tgsi_tokens)
 {
    struct tgsi_parse_context parser;
-   int ret;
+   ASSERTED int ret;
 
    ret = tgsi_parse_init(&parser, tgsi_tokens);
    assert(ret == TGSI_PARSE_OK);
@@ -2405,7 +2353,6 @@ static void
 ttn_read_pipe_caps(struct ttn_compile *c,
                    struct pipe_screen *screen)
 {
-   c->cap_scalar = screen->get_shader_param(screen, c->scan->processor, PIPE_SHADER_CAP_SCALAR_ISA);
    c->cap_packed_uniforms = screen->get_param(screen, PIPE_CAP_PACKED_UNIFORMS);
    c->cap_samplers_as_deref = screen->get_param(screen, PIPE_CAP_NIR_SAMPLERS_AS_DEREF);
    c->cap_face_is_sysval = screen->get_param(screen, PIPE_CAP_TGSI_FS_FACE_IS_INTEGER_SYSVAL);
@@ -2550,7 +2497,7 @@ ttn_compile_init(const void *tgsi_tokens,
 }
 
 static void
-ttn_optimize_nir(nir_shader *nir, bool scalar)
+ttn_optimize_nir(nir_shader *nir)
 {
    bool progress;
    do {
@@ -2558,8 +2505,8 @@ ttn_optimize_nir(nir_shader *nir, bool scalar)
 
       NIR_PASS_V(nir, nir_lower_vars_to_ssa);
 
-      if (scalar) {
-         NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL);
+      if (nir->options->lower_to_scalar) {
+         NIR_PASS_V(nir, nir_lower_alu_to_scalar, NULL, NULL);
          NIR_PASS_V(nir, nir_lower_phis_to_scalar);
       }
 
@@ -2601,7 +2548,7 @@ ttn_optimize_nir(nir_shader *nir, bool scalar)
  * so we have to do it here too.
  */
 static void
-ttn_finalize_nir(struct ttn_compile *c)
+ttn_finalize_nir(struct ttn_compile *c, struct pipe_screen *screen)
 {
    struct nir_shader *nir = c->build.shader;
 
@@ -2616,28 +2563,113 @@ ttn_finalize_nir(struct ttn_compile *c)
    if (c->cap_packed_uniforms)
       NIR_PASS_V(nir, nir_lower_uniforms_to_ubo, 16);
 
-   if (c->cap_samplers_as_deref)
-      NIR_PASS_V(nir, gl_nir_lower_samplers_as_deref, NULL);
-   else
-      NIR_PASS_V(nir, gl_nir_lower_samplers, NULL);
+   if (!c->cap_samplers_as_deref)
+      NIR_PASS_V(nir, nir_lower_samplers);
+
+   if (screen->finalize_nir) {
+      screen->finalize_nir(screen, nir, true);
+   } else {
+      ttn_optimize_nir(nir);
+      nir_shader_gather_info(nir, c->build.impl);
+   }
+
+   nir->info.num_images = c->num_images;
+   nir->info.num_textures = c->num_samplers;
 
-   ttn_optimize_nir(nir, c->cap_scalar);
-   nir_shader_gather_info(nir, c->build.impl);
    nir_validate_shader(nir, "TTN: after all optimizations");
 }
 
+static void save_nir_to_disk_cache(struct disk_cache *cache,
+                                   uint8_t key[CACHE_KEY_SIZE],
+                                   const nir_shader *s)
+{
+   struct blob blob = {0};
+
+   blob_init(&blob);
+   /* Because we cannot fully trust disk_cache_put
+    * (EGL_ANDROID_blob_cache) we add the shader size,
+    * which we'll check after disk_cache_get().
+    */
+   if (blob_reserve_uint32(&blob) != 0) {
+      blob_finish(&blob);
+      return;
+   }
+
+   nir_serialize(&blob, s, true);
+   *(uint32_t *)blob.data = blob.size;
+
+   disk_cache_put(cache, key, blob.data, blob.size, NULL);
+   blob_finish(&blob);
+}
+
+static nir_shader *
+load_nir_from_disk_cache(struct disk_cache *cache,
+                         struct pipe_screen *screen,
+                         uint8_t key[CACHE_KEY_SIZE],
+                         unsigned processor)
+{
+   const nir_shader_compiler_options *options =
+      screen->get_compiler_options(screen, PIPE_SHADER_IR_NIR, processor);
+   struct blob_reader blob_reader;
+   size_t size;
+   nir_shader *s;
+
+   uint32_t *buffer = (uint32_t *)disk_cache_get(cache, key, &size);
+   if (!buffer)
+      return NULL;
+
+   /* Match found. No need to check crc32 or other things.
+    * disk_cache_get is supposed to do that for us.
+    * However we do still check if the first element is indeed the size,
+    * as we cannot fully trust disk_cache_get (EGL_ANDROID_blob_cache) */
+   if (buffer[0] != size) {
+      return NULL;
+   }
+
+   size -= 4;
+   blob_reader_init(&blob_reader, buffer + 1, size);
+   s = nir_deserialize(NULL, options, &blob_reader);
+   free(buffer); /* buffer was malloc-ed */
+   return s;
+}
+
 struct nir_shader *
 tgsi_to_nir(const void *tgsi_tokens,
-            struct pipe_screen *screen)
+            struct pipe_screen *screen,
+            bool allow_disk_cache)
 {
+   struct disk_cache *cache = NULL;
    struct ttn_compile *c;
-   struct nir_shader *s;
+   struct nir_shader *s = NULL;
+   uint8_t key[CACHE_KEY_SIZE];
+   unsigned processor;
+
+   if (allow_disk_cache)
+      cache = screen->get_disk_shader_cache(screen);
+
+   /* Look first in the cache */
+   if (cache) {
+      disk_cache_compute_key(cache,
+                             tgsi_tokens,
+                             tgsi_num_tokens(tgsi_tokens) * sizeof(struct tgsi_token),
+                             key);
+      processor = tgsi_get_processor_type(tgsi_tokens);
+      s = load_nir_from_disk_cache(cache, screen, key, processor);
+   }
+
+   if (s)
+      return s;
+
+   /* Not in the cache */
 
    c = ttn_compile_init(tgsi_tokens, NULL, screen);
    s = c->build.shader;
-   ttn_finalize_nir(c);
+   ttn_finalize_nir(c, screen);
    ralloc_free(c);
 
+   if (cache)
+      save_nir_to_disk_cache(cache, key, s);
+
    return s;
 }