nir: Add a nir_foreach_function_temp_variable helper
[mesa.git] / src / compiler / nir / nir_opt_large_constants.c
index aa22f05d8099d5a8082e5485ca9114854344e439..2575407ca0b5b4e3bb72b745a1a1e2f30e364cd2 100644 (file)
 #include "nir_deref.h"
 
 struct var_info {
+   nir_variable *var; /* the variable this entry tracks */
+   /* is_constant is cleared once the variable is found not to qualify for
+    * lowering; found_read is set once a load of the variable is seen.
+    */
    bool is_constant;
    bool found_read;
+   bool duplicate; /* constant data matches the previous sorted entry, whose location is reused */
+
+   /* Block that has all the variable stores.  All the blocks with reads
+    * should be dominated by this block.
+    */
+   nir_block *block;
+
+   /* If is_constant, hold the collected constant data for this var.
+    * constant_data_size stays 0 (and constant_data unset) until the first
+    * constant store to the variable has been recorded.
+    */
+   uint32_t constant_data_size;
+   void *constant_data;
 };
 
+/**
+ * qsort() comparator for struct var_info.
+ *
+ * Entries are ordered by is_constant first (non-constant entries sort before
+ * constant ones), then by constant_data_size, then by the bytes of
+ * constant_data.  A result of 0 therefore means "same is_constant flag and
+ * byte-identical constant data".
+ *
+ * The is_constant key matters because a 0 result is also used to detect
+ * duplicates of the preceding sorted entry and to reuse that entry's
+ * location.  An entry that is not (or is no longer) constant never gets a
+ * location assigned, so it must never compare equal to a constant entry,
+ * even when the data collected for it so far happens to be identical.
+ *
+ * \param _a  pointer to the first struct var_info
+ * \param _b  pointer to the second struct var_info
+ * \return    negative, zero or positive, as required by qsort()
+ */
+static int
+var_info_cmp(const void *_a, const void *_b)
+{
+   const struct var_info *a = _a;
+   const struct var_info *b = _b;
+   uint32_t a_size = a->constant_data_size;
+   uint32_t b_size = b->constant_data_size;
+
+   if (a->is_constant != b->is_constant) {
+      /* Never report a constant and a non-constant entry as equal. */
+      return (int)a->is_constant - (int)b->is_constant;
+   } else if (a_size < b_size) {
+      return -1;
+   } else if (a_size > b_size) {
+      return 1;
+   } else if (a_size == 0) {
+      /* Don't call memcmp with invalid pointers. */
+      return 0;
+   } else {
+      return memcmp(a->constant_data, b->constant_data, a_size);
+   }
+}
+
 static nir_ssa_def *
 build_constant_load(nir_builder *b, nir_deref_instr *deref,
                     glsl_type_size_align_func size_align)
@@ -43,81 +75,81 @@ build_constant_load(nir_builder *b, nir_deref_instr *deref,
    size_align(var->type, &var_size, &var_align);
    assert(var->data.location % var_align == 0);
 
+   UNUSED unsigned deref_size, deref_align;
+   size_align(deref->type, &deref_size, &deref_align);
+
    nir_intrinsic_instr *load =
       nir_intrinsic_instr_create(b->shader, nir_intrinsic_load_constant);
    load->num_components = num_components;
    nir_intrinsic_set_base(load, var->data.location);
    nir_intrinsic_set_range(load, var_size);
+   nir_intrinsic_set_align(load, deref_align, 0);
    load->src[0] = nir_src_for_ssa(nir_build_deref_offset(b, deref, size_align));
    nir_ssa_dest_init(&load->instr, &load->dest,
                      num_components, bit_size, NULL);
    nir_builder_instr_insert(b, &load->instr);
 
    if (load->dest.ssa.bit_size < 8) {
-      /* Booleans are special-cased to be 32-bit
-       *
-       * Ideally, for drivers that can handle 32-bit booleans, we wouldn't
-       * emit the i2b here.  However, at this point, the driver is likely to
-       * still have 1-bit booleans so we need to at least convert bit sizes.
-       * Unfortunately, we don't have a good way to annotate the load as
-       * loading a known boolean value so the optimizer isn't going to be
-       * able to get rid of the conversion.  Some day, we may solve that
-       * problem but not today.
-       */
+      /* Booleans are special-cased to be 32-bit */
       assert(glsl_type_is_boolean(deref->type));
+      assert(deref_size == num_components * 4);
       load->dest.ssa.bit_size = 32;
-      return nir_i2b(b, &load->dest.ssa);
+      return nir_b2b1(b, &load->dest.ssa);
    } else {
+      assert(deref_size == num_components * bit_size / 8);
       return &load->dest.ssa;
    }
 }
 
 static void
-handle_constant_store(nir_builder *b, nir_intrinsic_instr *store,
+handle_constant_store(void *mem_ctx, struct var_info *info,
+                      nir_deref_instr *deref, nir_const_value *val,
+                      unsigned writemask, /* bit i set => val[i] is stored.
+                       * Summary: copies the selected components of a constant
+                       * store into info->constant_data at deref's constant
+                       * offset, allocating that buffer on first use. */
                       glsl_type_size_align_func size_align)
 {
-   nir_deref_instr *deref = nir_src_as_deref(store->src[0]);
    assert(!nir_deref_instr_has_indirect(deref));
-
-   nir_variable *var = nir_deref_instr_get_variable(deref);
-
    const unsigned bit_size = glsl_get_bit_size(deref->type);
    const unsigned num_components = glsl_get_vector_elements(deref->type);
 
-   char *dst = (char *)b->shader->constant_data +
-               var->data.location +
+   if (info->constant_data_size == 0) { /* nothing recorded yet: allocate a buffer sized for the whole variable */
+      unsigned var_size, var_align;
+      size_align(info->var->type, &var_size, &var_align);
+      info->constant_data_size = var_size;
+      info->constant_data = rzalloc_size(mem_ctx, var_size);
+   }
+   /* dst is the deref's constant byte offset inside the per-variable buffer. */
+   char *dst = (char *)info->constant_data +
                nir_deref_instr_get_const_offset(deref, size_align);
 
-   nir_const_value *val = nir_src_as_const_value(store->src[1]);
-   switch (bit_size) {
-   case 1:
-      /* Booleans are special-cased to be 32-bit */
-      for (unsigned i = 0; i < num_components; i++)
-         ((int32_t *)dst)[i] = -(int)val->b[i];
-      break;
-
-   case 8:
-      for (unsigned i = 0; i < num_components; i++)
-         ((uint8_t *)dst)[i] = val->u8[i];
-      break;
-
-   case 16:
-      for (unsigned i = 0; i < num_components; i++)
-         ((uint16_t *)dst)[i] = val->u16[i];
-      break;
-
-   case 32:
-      for (unsigned i = 0; i < num_components; i++)
-         ((uint32_t *)dst)[i] = val->u32[i];
-      break;
-
-   case 64:
-      for (unsigned i = 0; i < num_components; i++)
-         ((uint64_t *)dst)[i] = val->u64[i];
-      break;
-
-   default:
-      unreachable("Invalid bit size");
+   for (unsigned i = 0; i < num_components; i++) {
+      if (!(writemask & (1 << i))) /* component i is not written by this store */
+         continue;
+      /* Store component i at the width selected by bit_size. */
+      switch (bit_size) {
+      case 1:
+         /* Booleans are special-cased to be 32-bit */
+         ((int32_t *)dst)[i] = -(int)val[i].b; /* -(int)1 == -1, so a set flag is stored as all ones */
+         break;
+
+      case 8:
+         ((uint8_t *)dst)[i] = val[i].u8;
+         break;
+
+      case 16:
+         ((uint16_t *)dst)[i] = val[i].u16;
+         break;
+
+      case 32:
+         ((uint32_t *)dst)[i] = val[i].u32;
+         break;
+
+      case 64:
+         ((uint64_t *)dst)[i] = val[i].u64;
+         break;
+
+      default:
+         unreachable("Invalid bit size");
+      }
    }
 }
 
@@ -144,25 +176,28 @@ nir_opt_large_constants(nir_shader *shader,
    /* This pass can only be run once */
    assert(shader->constant_data == NULL && shader->constant_data_size == 0);
 
-   /* The index parameter is unused for local variables so we'll use it for
-    * indexing into our array of variable metadata.
-    */
-   unsigned num_locals = 0;
-   nir_foreach_variable(var, &impl->locals)
-      var->data.index = num_locals++;
+   unsigned num_locals = exec_list_length(&impl->locals);
+   nir_index_vars(shader, impl, nir_var_function_temp);
+
+   if (num_locals == 0) {
+      nir_shader_preserve_all_metadata(shader);
+      return false;
+   }
 
-   struct var_info *var_infos = malloc(num_locals * sizeof(struct var_info));
-   for (unsigned i = 0; i < num_locals; i++) {
-      var_infos[i] = (struct var_info) {
+   struct var_info *var_infos = ralloc_array(NULL, struct var_info, num_locals);
+   nir_foreach_function_temp_variable(var, impl) {
+      var_infos[var->index] = (struct var_info) {
+         .var = var,
          .is_constant = true,
          .found_read = false,
       };
    }
 
+   nir_metadata_require(impl, nir_metadata_dominance);
+
    /* First, walk through the shader and figure out what variables we can
     * lower to the constant blob.
     */
-   bool first_block = true;
    nir_foreach_block(block, impl) {
       nir_foreach_instr(instr, block) {
          if (instr->type != nir_instr_type_intrinsic)
@@ -172,10 +207,12 @@ nir_opt_large_constants(nir_shader *shader,
 
          bool src_is_const = false;
          nir_deref_instr *src_deref = NULL, *dst_deref = NULL;
+         unsigned writemask = 0;
          switch (intrin->intrinsic) {
          case nir_intrinsic_store_deref:
             dst_deref = nir_src_as_deref(intrin->src[0]);
             src_is_const = nir_src_is_const(intrin->src[1]);
+            writemask = nir_intrinsic_write_mask(intrin);
             break;
 
          case nir_intrinsic_load_deref:
@@ -183,67 +220,103 @@ nir_opt_large_constants(nir_shader *shader,
             break;
 
          case nir_intrinsic_copy_deref:
-            /* We always assume the src and therefore the dst are not
-             * constants here. Copy and constant propagation passes should
-             * have taken care of this in most cases anyway.
-             */
-            dst_deref = nir_src_as_deref(intrin->src[0]);
-            src_deref = nir_src_as_deref(intrin->src[1]);
-            src_is_const = false;
+            assert(!"Lowering of copy_deref with large constants is prohibited");
             break;
 
          default:
             continue;
          }
 
-         if (dst_deref && dst_deref->mode == nir_var_local) {
+         if (dst_deref && dst_deref->mode == nir_var_function_temp) {
             nir_variable *var = nir_deref_instr_get_variable(dst_deref);
-            assert(var->data.mode == nir_var_local);
+            assert(var->data.mode == nir_var_function_temp);
+
+            struct var_info *info = &var_infos[var->index];
+            if (!info->is_constant)
+               continue;
+
+            if (!info->block)
+               info->block = block;
 
             /* We only consider variables constant if they only have constant
              * stores, all the stores come before any reads, and all stores
-             * come in the first block.  We also can't handle indirect stores.
+             * come from the same block.  We also can't handle indirect stores.
              */
-            struct var_info *info = &var_infos[var->data.index];
-            if (!src_is_const || info->found_read || !first_block ||
-                nir_deref_instr_has_indirect(dst_deref))
+            if (!src_is_const || info->found_read || block != info->block ||
+                nir_deref_instr_has_indirect(dst_deref)) {
                info->is_constant = false;
+            } else {
+               nir_const_value *val = nir_src_as_const_value(intrin->src[1]);
+               handle_constant_store(var_infos, info, dst_deref, val, writemask,
+                                     size_align);
+            }
          }
 
-         if (src_deref && src_deref->mode == nir_var_local) {
+         if (src_deref && src_deref->mode == nir_var_function_temp) {
             nir_variable *var = nir_deref_instr_get_variable(src_deref);
-            assert(var->data.mode == nir_var_local);
+            assert(var->data.mode == nir_var_function_temp);
+
+            /* We only consider variables constant if all the reads are
+             * dominated by the block that writes to it.
+             */
+            struct var_info *info = &var_infos[var->index];
+            if (!info->is_constant)
+               continue;
 
-            var_infos[var->data.index].found_read = true;
+            if (!info->block || !nir_block_dominates(info->block, block))
+               info->is_constant = false;
+
+            info->found_read = true;
          }
       }
-      first_block = false;
    }
 
+   /* Allocate constant data space for each variable that just has constant
+    * data.  We sort them by size and content so we can easily find
+    * duplicates.
+    */
    shader->constant_data_size = 0;
-   nir_foreach_variable(var, &impl->locals) {
-      struct var_info *info = &var_infos[var->data.index];
+   qsort(var_infos, num_locals, sizeof(struct var_info), var_info_cmp);
+   for (int i = 0; i < num_locals; i++) {
+      struct var_info *info = &var_infos[i];
+
+      /* Fix up indices after we sorted. */
+      info->var->index = i;
+
       if (!info->is_constant)
          continue;
 
       unsigned var_size, var_align;
-      size_align(var->type, &var_size, &var_align);
+      size_align(info->var->type, &var_size, &var_align);
       if (var_size <= threshold || !info->found_read) {
          /* Don't bother lowering small stuff or data that's never read */
          info->is_constant = false;
          continue;
       }
 
-      var->data.location = ALIGN_POT(shader->constant_data_size, var_align);
-      shader->constant_data_size = var->data.location + var_size;
+      if (i > 0 && var_info_cmp(info, &var_infos[i - 1]) == 0) {
+         info->var->data.location = var_infos[i - 1].var->data.location;
+         info->duplicate = true;
+      } else {
+         info->var->data.location = ALIGN_POT(shader->constant_data_size, var_align);
+         shader->constant_data_size = info->var->data.location + var_size;
+      }
    }
 
    if (shader->constant_data_size == 0) {
-      free(var_infos);
+      nir_shader_preserve_all_metadata(shader);
+      ralloc_free(var_infos);
       return false;
    }
 
    shader->constant_data = rzalloc_size(shader, shader->constant_data_size);
+   for (int i = 0; i < num_locals; i++) {
+      struct var_info *info = &var_infos[i];
+      if (!info->duplicate && info->is_constant) {
+         memcpy((char *)shader->constant_data + info->var->data.location,
+                info->constant_data, info->constant_data_size);
+      }
+   }
 
    nir_builder b;
    nir_builder_init(&b, impl);
@@ -258,11 +331,11 @@ nir_opt_large_constants(nir_shader *shader,
          switch (intrin->intrinsic) {
          case nir_intrinsic_load_deref: {
             nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
-            if (deref->mode != nir_var_local)
+            if (deref->mode != nir_var_function_temp)
                continue;
 
             nir_variable *var = nir_deref_instr_get_variable(deref);
-            struct var_info *info = &var_infos[var->data.index];
+            struct var_info *info = &var_infos[var->index];
             if (info->is_constant) {
                b.cursor = nir_after_instr(&intrin->instr);
                nir_ssa_def *val = build_constant_load(&b, deref, size_align);
@@ -276,37 +349,18 @@ nir_opt_large_constants(nir_shader *shader,
 
          case nir_intrinsic_store_deref: {
             nir_deref_instr *deref = nir_src_as_deref(intrin->src[0]);
-            if (deref->mode != nir_var_local)
+            if (deref->mode != nir_var_function_temp)
                continue;
 
             nir_variable *var = nir_deref_instr_get_variable(deref);
-            struct var_info *info = &var_infos[var->data.index];
+            struct var_info *info = &var_infos[var->index];
             if (info->is_constant) {
-               b.cursor = nir_after_instr(&intrin->instr);
-               handle_constant_store(&b, intrin, size_align);
                nir_instr_remove(&intrin->instr);
                nir_deref_instr_remove_if_unused(deref);
             }
             break;
          }
-
-         case nir_intrinsic_copy_deref: {
-            nir_deref_instr *deref = nir_src_as_deref(intrin->src[1]);
-            if (deref->mode != nir_var_local)
-               continue;
-
-            nir_variable *var = nir_deref_instr_get_variable(deref);
-            struct var_info *info = &var_infos[var->data.index];
-            if (info->is_constant) {
-               b.cursor = nir_after_instr(&intrin->instr);
-               nir_ssa_def *val = build_constant_load(&b, deref, size_align);
-               nir_store_deref(&b, nir_src_as_deref(intrin->src[0]), val, ~0);
-               nir_instr_remove(&intrin->instr);
-               nir_deref_instr_remove_if_unused(deref);
-            }
-            break;
-         }
-
+         case nir_intrinsic_copy_deref:
          default:
             continue;
          }
@@ -314,12 +368,13 @@ nir_opt_large_constants(nir_shader *shader,
    }
 
    /* Clean up the now unused variables */
-   nir_foreach_variable_safe(var, &impl->locals) {
-      if (var_infos[var->data.index].is_constant)
-         exec_node_remove(&var->node);
+   for (int i = 0; i < num_locals; i++) {
+      struct var_info *info = &var_infos[i];
+      if (info->is_constant)
+         exec_node_remove(&info->var->node);
    }
 
-   free(var_infos);
+   ralloc_free(var_infos);
 
    nir_metadata_preserve(impl, nir_metadata_block_index |
                                nir_metadata_dominance);