spirv: Add a workaround for OpControlBarrier on old GLSLang
[mesa.git] / src / compiler / spirv / vtn_cfg.c
index aef1b7e18fb68a372828acb75323ba4f91cf4a94..25e2f285e79942f1c9f9d37f3300748b4bf807d4 100644 (file)
@@ -232,7 +232,7 @@ vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
    if (ret_type->base_type == vtn_base_type_void) {
       vtn_push_value(b, w[2], vtn_value_type_undef);
    } else {
-      vtn_push_ssa(b, w[2], res_type, vtn_local_load(b, ret_deref));
+      vtn_push_ssa(b, w[2], res_type, vtn_local_load(b, ret_deref, 0));
    }
 }
 
@@ -248,7 +248,7 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
       list_inithead(&b->func->body);
       b->func->control = w[3];
 
-      MAYBE_UNUSED const struct glsl_type *result_type =
+      UNUSED const struct glsl_type *result_type =
          vtn_value(b, w[1], vtn_value_type_type)->type->type;
       struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_function);
       val->func = b->func;
@@ -274,9 +274,12 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
 
       unsigned idx = 0;
       if (func_type->return_type->base_type != vtn_base_type_void) {
+         nir_address_format addr_format =
+            vtn_mode_to_address_format(b, vtn_variable_mode_function);
          /* The return value is a regular pointer */
          func->params[idx++] = (nir_parameter) {
-            .num_components = 1, .bit_size = 32,
+            .num_components = nir_address_format_num_components(addr_format),
+            .bit_size = nir_address_format_bit_size(addr_format),
          };
       }
 
@@ -287,6 +290,7 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
       b->func->impl = nir_function_impl_create(func);
       nir_builder_init(&b->nb, func->impl);
       b->nb.cursor = nir_before_cf_list(&b->func->impl->body);
+      b->nb.exact = b->exact;
 
       b->func_param_idx = 0;
 
@@ -314,7 +318,6 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
             vtn_push_value(b, w[2], vtn_value_type_sampled_image);
 
          val->sampled_image = ralloc(b, struct vtn_sampled_image);
-         val->sampled_image->type = type;
 
          struct vtn_type *sampler_type = rzalloc(b, struct vtn_type);
          sampler_type->base_type = vtn_base_type_sampler;
@@ -327,17 +330,12 @@ vtn_cfg_handle_prepass_instruction(struct vtn_builder *b, SpvOp opcode,
       } else if (type->base_type == vtn_base_type_pointer &&
                  type->type != NULL) {
          /* This is a pointer with an actual storage type */
-         struct vtn_value *val =
-            vtn_push_value(b, w[2], vtn_value_type_pointer);
          nir_ssa_def *ssa_ptr = nir_load_param(&b->nb, b->func_param_idx++);
-         val->pointer = vtn_pointer_from_ssa(b, ssa_ptr, type);
+         vtn_push_value_pointer(b, w[2], vtn_pointer_from_ssa(b, ssa_ptr, type));
       } else if (type->base_type == vtn_base_type_pointer ||
                  type->base_type == vtn_base_type_image ||
                  type->base_type == vtn_base_type_sampler) {
-         struct vtn_value *val =
-            vtn_push_value(b, w[2], vtn_value_type_pointer);
-         val->pointer =
-            vtn_load_param_pointer(b, type, b->func_param_idx++);
+         vtn_push_value_pointer(b, w[2], vtn_load_param_pointer(b, type, b->func_param_idx++));
       } else {
          /* We're a regular SSA value. */
          struct vtn_ssa_value *value = vtn_create_ssa_value(b, type->type);
@@ -585,6 +583,8 @@ vtn_cfg_walk_blocks(struct vtn_builder *b, struct list_head *cf_list,
          if (block->merge &&
              (*block->merge & SpvOpCodeMask) == SpvOpSelectionMerge) {
             if_stmt->control = block->merge[2];
+         } else {
+            if_stmt->control = SpvSelectionControlMaskNone;
          }
 
          if_stmt->then_type = vtn_get_branch_type(b, then_block,
@@ -791,7 +791,7 @@ vtn_handle_phis_first_pass(struct vtn_builder *b, SpvOp opcode,
    _mesa_hash_table_insert(b->phi_table, w, phi_var);
 
    vtn_push_ssa(b, w[2], type,
-                vtn_local_load(b, nir_build_deref_var(&b->nb, phi_var)));
+                vtn_local_load(b, nir_build_deref_var(&b->nb, phi_var), 0));
 
    return true;
 }
@@ -811,11 +811,16 @@ vtn_handle_phi_second_pass(struct vtn_builder *b, SpvOp opcode,
       struct vtn_block *pred =
          vtn_value(b, w[i + 1], vtn_value_type_block)->block;
 
+      /* If block does not have end_nop, that is because it is an unreacheable
+       * block, and hence it is not worth to handle it */
+      if (!pred->end_nop)
+         continue;
+
       b->nb.cursor = nir_after_instr(&pred->end_nop->instr);
 
       struct vtn_ssa_value *src = vtn_ssa_value(b, w[i]);
 
-      vtn_local_store(b, src, nir_build_deref_var(&b->nb, phi_var));
+      vtn_local_store(b, src, nir_build_deref_var(&b->nb, phi_var), 0);
    }
 
    return true;
@@ -852,6 +857,66 @@ vtn_emit_branch(struct vtn_builder *b, enum vtn_branch_type branch_type,
    }
 }
 
+static nir_ssa_def *
+vtn_switch_case_condition(struct vtn_builder *b, struct vtn_switch *swtch,
+                          nir_ssa_def *sel, struct vtn_case *cse)
+{
+   if (cse->is_default) {
+      nir_ssa_def *any = nir_imm_false(&b->nb);
+      list_for_each_entry(struct vtn_case, other, &swtch->cases, link) {
+         if (other->is_default)
+            continue;
+
+         any = nir_ior(&b->nb, any,
+                       vtn_switch_case_condition(b, swtch, sel, other));
+      }
+      return nir_inot(&b->nb, any);
+   } else {
+      nir_ssa_def *cond = nir_imm_false(&b->nb);
+      util_dynarray_foreach(&cse->values, uint64_t, val) {
+         nir_ssa_def *imm = nir_imm_intN_t(&b->nb, *val, sel->bit_size);
+         cond = nir_ior(&b->nb, cond, nir_ieq(&b->nb, sel, imm));
+      }
+      return cond;
+   }
+}
+
+static nir_loop_control
+vtn_loop_control(struct vtn_builder *b, struct vtn_loop *vtn_loop)
+{
+   if (vtn_loop->control == SpvLoopControlMaskNone)
+      return nir_loop_control_none;
+   else if (vtn_loop->control & SpvLoopControlDontUnrollMask)
+      return nir_loop_control_dont_unroll;
+   else if (vtn_loop->control & SpvLoopControlUnrollMask)
+      return nir_loop_control_unroll;
+   else if (vtn_loop->control & SpvLoopControlDependencyInfiniteMask ||
+            vtn_loop->control & SpvLoopControlDependencyLengthMask ||
+            vtn_loop->control & SpvLoopControlMinIterationsMask ||
+            vtn_loop->control & SpvLoopControlMaxIterationsMask ||
+            vtn_loop->control & SpvLoopControlIterationMultipleMask ||
+            vtn_loop->control & SpvLoopControlPeelCountMask ||
+            vtn_loop->control & SpvLoopControlPartialCountMask) {
+      /* We do not do anything special with these yet. */
+      return nir_loop_control_none;
+   } else {
+      vtn_fail("Invalid loop control");
+   }
+}
+
+static nir_selection_control
+vtn_selection_control(struct vtn_builder *b, struct vtn_if *vtn_if)
+{
+   if (vtn_if->control == SpvSelectionControlMaskNone)
+      return nir_selection_control_none;
+   else if (vtn_if->control & SpvSelectionControlDontFlattenMask)
+      return nir_selection_control_dont_flatten;
+   else if (vtn_if->control & SpvSelectionControlFlattenMask)
+      return nir_selection_control_flatten;
+   else
+      vtn_fail("Invalid selection control");
+}
+
 static void
 vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
                  nir_variable *switch_fall_var, bool *has_switch_break,
@@ -884,13 +949,14 @@ vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
                glsl_get_bare_type(b->func->type->return_type->type);
             nir_deref_instr *ret_deref =
                nir_build_deref_cast(&b->nb, nir_load_param(&b->nb, 0),
-                                    nir_var_function, ret_type, 0);
-            vtn_local_store(b, src, ret_deref);
+                                    nir_var_function_temp, ret_type, 0);
+            vtn_local_store(b, src, ret_deref, 0);
          }
 
          if (block->branch_type != vtn_branch_type_none) {
             vtn_emit_branch(b, block->branch_type,
                             switch_fall_var, has_switch_break);
+            return;
          }
 
          break;
@@ -902,6 +968,9 @@ vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
 
          nir_if *nif =
             nir_push_if(&b->nb, vtn_ssa_value(b, vtn_if->condition)->def);
+
+         nif->control = vtn_selection_control(b, vtn_if);
+
          if (vtn_if->then_type == vtn_branch_type_none) {
             vtn_emit_cf_list(b, &vtn_if->then_body,
                              switch_fall_var, &sw_break, handler);
@@ -936,9 +1005,11 @@ vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
          struct vtn_loop *vtn_loop = (struct vtn_loop *)node;
 
          nir_loop *loop = nir_push_loop(&b->nb);
+         loop->control = vtn_loop_control(b, vtn_loop);
+
          vtn_emit_cf_list(b, &vtn_loop->body, NULL, NULL, handler);
 
-         if (!list_empty(&vtn_loop->cont_body)) {
+         if (!list_is_empty(&vtn_loop->cont_body)) {
             /* If we have a non-trivial continue body then we need to put
              * it at the beginning of the loop with a flag to ensure that
              * it doesn't get executed in the first iteration.
@@ -978,46 +1049,13 @@ vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
             nir_local_variable_create(b->nb.impl, glsl_bool_type(), "fall");
          nir_store_var(&b->nb, fall_var, nir_imm_false(&b->nb), 1);
 
-         /* Next, we gather up all of the conditions.  We have to do this
-          * up-front because we also need to build an "any" condition so
-          * that we can use !any for default.
-          */
-         const int num_cases = list_length(&vtn_switch->cases);
-         NIR_VLA(nir_ssa_def *, conditions, num_cases);
-
          nir_ssa_def *sel = vtn_ssa_value(b, vtn_switch->selector)->def;
-         /* An accumulation of all conditions.  Used for the default */
-         nir_ssa_def *any = NULL;
-
-         int i = 0;
-         list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
-            if (cse->is_default) {
-               conditions[i++] = NULL;
-               continue;
-            }
-
-            nir_ssa_def *cond = NULL;
-            util_dynarray_foreach(&cse->values, uint64_t, val) {
-               nir_ssa_def *imm = nir_imm_intN_t(&b->nb, *val, sel->bit_size);
-               nir_ssa_def *is_val = nir_ieq(&b->nb, sel, imm);
-
-               cond = cond ? nir_ior(&b->nb, cond, is_val) : is_val;
-            }
-
-            any = any ? nir_ior(&b->nb, any, cond) : cond;
-            conditions[i++] = cond;
-         }
-         vtn_assert(i == num_cases);
 
          /* Now we can walk the list of cases and actually emit code */
-         i = 0;
          list_for_each_entry(struct vtn_case, cse, &vtn_switch->cases, link) {
             /* Figure out the condition */
-            nir_ssa_def *cond = conditions[i++];
-            if (cse->is_default) {
-               vtn_assert(cond == NULL);
-               cond = nir_inot(&b->nb, any);
-            }
+            nir_ssa_def *cond =
+               vtn_switch_case_condition(b, vtn_switch, sel, cse);
             /* Take fallthrough into account */
             cond = nir_ior(&b->nb, cond, nir_load_var(&b->nb, fall_var));
 
@@ -1030,7 +1068,6 @@ vtn_emit_cf_list(struct vtn_builder *b, struct list_head *cf_list,
 
             nir_pop_if(&b->nb, case_if);
          }
-         vtn_assert(i == num_cases);
 
          break;
       }
@@ -1048,9 +1085,9 @@ vtn_function_emit(struct vtn_builder *b, struct vtn_function *func,
    nir_builder_init(&b->nb, func->impl);
    b->func = func;
    b->nb.cursor = nir_after_cf_list(&func->impl->body);
+   b->nb.exact = b->exact;
    b->has_loop_continue = false;
-   b->phi_table = _mesa_hash_table_create(b, _mesa_hash_pointer,
-                                          _mesa_key_pointer_equal);
+   b->phi_table = _mesa_pointer_hash_table_create(b);
 
    vtn_emit_cf_list(b, &func->body, NULL, NULL, instruction_handler);