nir: add support for flushing to zero denorm constants
authorSamuel Iglesias Gonsálvez <siglesias@igalia.com>
Wed, 20 Jun 2018 07:11:14 +0000 (09:11 +0200)
committerAndres Gomez <agomez@igalia.com>
Tue, 17 Sep 2019 20:39:18 +0000 (23:39 +0300)
v2:
- Refactor conditions and shared function (Connor).
- Move code to nir_eval_const_opcode() (Connor).
- Don't flush to zero on fquantize2f16
  From Vulkan spec, VK_KHR_shader_float_controls section:

  "3) Do denorm and rounding mode controls apply to OpSpecConstantOp?

  RESOLVED: Yes, except when the opcode is OpQuantizeToF16."

v3:
- Fix bit size (Connor).
- Fix execution mode on nir_loop_analize (Connor).

v4:
- Adapt after API changes to nir_eval_const_opcode (Andres).

v5:
- Simplify constant_denorm_flush_to_zero (Caio).

v6:
- Adapt after API changes and to use the new constant
  constructors (Andres).
- Replace MAYBE_UNUSED with UNUSED as the first is going
  away (Andres).

v7:
- Adapt to newly added calls (Andres).
- Simplified the auxiliary to flush denorms to zero (Caio).
- Updated to renamed supported capabilities member (Andres).

Signed-off-by: Samuel Iglesias Gonsálvez <siglesias@igalia.com>
Signed-off-by: Andres Gomez <agomez@igalia.com>
Reviewed-by: Connor Abbott <cwabbott0@gmail.com> [v4]
Reviewed-by: Caio Marcelo de Oliveira Filho <caio.oliveira@intel.com>
src/compiler/nir/nir_constant_expressions.h
src/compiler/nir/nir_constant_expressions.py
src/compiler/nir/nir_loop_analyze.c
src/compiler/nir/nir_opt_constant_folding.c
src/compiler/spirv/spirv_to_nir.c

index 087663f74803e32fd670f1e482545302d7bbc8b3..6450d5c3d03f38f20e81cf2959f9351ed8d1be04 100644 (file)
@@ -32,6 +32,7 @@
 
 void nir_eval_const_opcode(nir_op op, nir_const_value *dest,
                            unsigned num_components, unsigned bit_size,
-                           nir_const_value **src);
+                           nir_const_value **src,
+                           unsigned float_controls_execution_mode);
 
 #endif /* NIR_CONSTANT_EXPRESSIONS_H */
index fca14602855acca5abebec049126614906b15bee..1df97aa100094ef1bf1976ca6d8277e5264c82e0 100644 (file)
@@ -68,6 +68,27 @@ template = """\
 
 #define MAX_UINT_FOR_SIZE(bits) (UINT64_MAX >> (64 - (bits)))
 
+/**
+ * \brief Checks if the provided value is a denorm and flushes it to zero.
+ */
+static void
+constant_denorm_flush_to_zero(nir_const_value *value, unsigned bit_size)
+{
+    switch(bit_size) {
+    case 64:
+        if (0 == (value->u64 & 0x7ff0000000000000))
+            value->u64 &= 0x8000000000000000;
+        break;
+    case 32:
+        if (0 == (value->u32 & 0x7f800000))
+            value->u32 &= 0x80000000;
+        break;
+    case 16:
+        if (0 == (value->u16 & 0x7c00))
+            value->u16 &= 0x8000;
+    }
+}
+
 /**
  * Evaluate one component of packSnorm4x8.
  */
@@ -262,7 +283,7 @@ struct ${type}${width}_vec {
 % endfor
 % endfor
 
-<%def name="evaluate_op(op, bit_size)">
+<%def name="evaluate_op(op, bit_size, execution_mode)">
    <%
    output_type = type_add_size(op.output_type, bit_size)
    input_types = [type_add_size(type_, bit_size) for type_ in op.input_types]
@@ -345,6 +366,18 @@ struct ${type}${width}_vec {
          % else:
             _dst_val[_i].${get_const_field(output_type)} = dst;
          % endif
+
+         % if op.name != "fquantize2f16" and type_base_type(output_type) == "float":
+            % if type_has_size(output_type):
+               if (nir_is_denorm_flush_to_zero(execution_mode, ${type_size(output_type)})) {
+                  constant_denorm_flush_to_zero(&_dst_val[_i], ${type_size(output_type)});
+               }
+            % else:
+               if (nir_is_denorm_flush_to_zero(execution_mode, ${bit_size})) {
+                  constant_denorm_flush_to_zero(&_dst_val[i], bit_size);
+               }
+            %endif
+         % endif
       }
    % else:
       ## In the non-per-component case, create a struct dst with
@@ -377,6 +410,18 @@ struct ${type}${width}_vec {
          % else:
             _dst_val[${k}].${get_const_field(output_type)} = dst.${"xyzw"[k]};
          % endif
+
+         % if op.name != "fquantize2f16" and type_base_type(output_type) == "float":
+            % if type_has_size(output_type):
+               if (nir_is_denorm_flush_to_zero(execution_mode, ${type_size(output_type)})) {
+                  constant_denorm_flush_to_zero(&_dst_val[${k}], ${type_size(output_type)});
+               }
+            % else:
+               if (nir_is_denorm_flush_to_zero(execution_mode, ${bit_size})) {
+                  constant_denorm_flush_to_zero(&_dst_val[${k}], bit_size);
+               }
+            % endif
+         % endif
       % endfor
    % endif
 </%def>
@@ -386,13 +431,14 @@ static void
 evaluate_${name}(nir_const_value *_dst_val,
                  UNUSED unsigned num_components,
                  ${"UNUSED" if op_bit_sizes(op) is None else ""} unsigned bit_size,
-                 UNUSED nir_const_value **_src)
+                 UNUSED nir_const_value **_src,
+                 UNUSED unsigned execution_mode)
 {
    % if op_bit_sizes(op) is not None:
       switch (bit_size) {
       % for bit_size in op_bit_sizes(op):
       case ${bit_size}: {
-         ${evaluate_op(op, bit_size)}
+         ${evaluate_op(op, bit_size, execution_mode)}
          break;
       }
       % endfor
@@ -401,7 +447,7 @@ evaluate_${name}(nir_const_value *_dst_val,
          unreachable("unknown bit width");
       }
    % else:
-      ${evaluate_op(op, 0)}
+      ${evaluate_op(op, 0, execution_mode)}
    % endif
 }
 % endfor
@@ -409,12 +455,13 @@ evaluate_${name}(nir_const_value *_dst_val,
 void
 nir_eval_const_opcode(nir_op op, nir_const_value *dest,
                       unsigned num_components, unsigned bit_width,
-                      nir_const_value **src)
+                      nir_const_value **src,
+                      unsigned float_controls_execution_mode)
 {
    switch (op) {
 % for name in sorted(opcodes.keys()):
    case nir_op_${name}:
-      evaluate_${name}(dest, num_components, bit_width, src);
+      evaluate_${name}(dest, num_components, bit_width, src, float_controls_execution_mode);
       return;
 % endfor
    default:
@@ -425,6 +472,8 @@ nir_eval_const_opcode(nir_op op, nir_const_value *dest,
 from mako.template import Template
 
 print(Template(template).render(opcodes=opcodes, type_sizes=type_sizes,
+                                type_base_type=type_base_type,
+                                type_size=type_size,
                                 type_has_size=type_has_size,
                                 type_add_size=type_add_size,
                                 op_bit_sizes=op_bit_sizes,
index 4689d2230af46a09912c667c2e48b79bf655d64e..c2473421215e4a8cedce85a0647b189df4f9d3ba 100644 (file)
@@ -589,29 +589,32 @@ try_find_limit_of_alu(nir_ssa_scalar limit, nir_const_value *limit_val,
 }
 
 static nir_const_value
-eval_const_unop(nir_op op, unsigned bit_size, nir_const_value src0)
+eval_const_unop(nir_op op, unsigned bit_size, nir_const_value src0,
+                unsigned execution_mode)
 {
    assert(nir_op_infos[op].num_inputs == 1);
    nir_const_value dest;
    nir_const_value *src[1] = { &src0 };
-   nir_eval_const_opcode(op, &dest, 1, bit_size, src);
+   nir_eval_const_opcode(op, &dest, 1, bit_size, src, execution_mode);
    return dest;
 }
 
 static nir_const_value
 eval_const_binop(nir_op op, unsigned bit_size,
-                 nir_const_value src0, nir_const_value src1)
+                 nir_const_value src0, nir_const_value src1,
+                 unsigned execution_mode)
 {
    assert(nir_op_infos[op].num_inputs == 2);
    nir_const_value dest;
    nir_const_value *src[2] = { &src0, &src1 };
-   nir_eval_const_opcode(op, &dest, 1, bit_size, src);
+   nir_eval_const_opcode(op, &dest, 1, bit_size, src, execution_mode);
    return dest;
 }
 
 static int32_t
 get_iteration(nir_op cond_op, nir_const_value initial, nir_const_value step,
-              nir_const_value limit, unsigned bit_size)
+              nir_const_value limit, unsigned bit_size,
+              unsigned execution_mode)
 {
    nir_const_value span, iter;
 
@@ -620,23 +623,29 @@ get_iteration(nir_op cond_op, nir_const_value initial, nir_const_value step,
    case nir_op_ilt:
    case nir_op_ieq:
    case nir_op_ine:
-      span = eval_const_binop(nir_op_isub, bit_size, limit, initial);
-      iter = eval_const_binop(nir_op_idiv, bit_size, span, step);
+      span = eval_const_binop(nir_op_isub, bit_size, limit, initial,
+                              execution_mode);
+      iter = eval_const_binop(nir_op_idiv, bit_size, span, step,
+                              execution_mode);
       break;
 
    case nir_op_uge:
    case nir_op_ult:
-      span = eval_const_binop(nir_op_isub, bit_size, limit, initial);
-      iter = eval_const_binop(nir_op_udiv, bit_size, span, step);
+      span = eval_const_binop(nir_op_isub, bit_size, limit, initial,
+                              execution_mode);
+      iter = eval_const_binop(nir_op_udiv, bit_size, span, step,
+                              execution_mode);
       break;
 
    case nir_op_fge:
    case nir_op_flt:
    case nir_op_feq:
    case nir_op_fne:
-      span = eval_const_binop(nir_op_fsub, bit_size, limit, initial);
-      iter = eval_const_binop(nir_op_fdiv, bit_size, span, step);
-      iter = eval_const_unop(nir_op_f2i64, bit_size, iter);
+      span = eval_const_binop(nir_op_fsub, bit_size, limit, initial,
+                              execution_mode);
+      iter = eval_const_binop(nir_op_fdiv, bit_size, span,
+                              step, execution_mode);
+      iter = eval_const_unop(nir_op_f2i64, bit_size, iter, execution_mode);
       break;
 
    default:
@@ -654,7 +663,8 @@ will_break_on_first_iteration(nir_const_value step,
                               nir_op cond_op, unsigned bit_size,
                               nir_const_value initial,
                               nir_const_value limit,
-                              bool limit_rhs, bool invert_cond)
+                              bool limit_rhs, bool invert_cond,
+                              unsigned execution_mode)
 {
    if (trip_offset == 1) {
       nir_op add_op;
@@ -670,7 +680,8 @@ will_break_on_first_iteration(nir_const_value step,
          unreachable("Unhandled induction variable base type!");
       }
 
-      initial = eval_const_binop(add_op, bit_size, initial, step);
+      initial = eval_const_binop(add_op, bit_size, initial, step,
+                                 execution_mode);
    }
 
    nir_const_value *src[2];
@@ -679,7 +690,7 @@ will_break_on_first_iteration(nir_const_value step,
 
    /* Evaluate the loop exit condition */
    nir_const_value result;
-   nir_eval_const_opcode(cond_op, &result, 1, bit_size, src);
+   nir_eval_const_opcode(cond_op, &result, 1, bit_size, src, execution_mode);
 
    return invert_cond ? !result.b : result.b;
 }
@@ -688,7 +699,8 @@ static bool
 test_iterations(int32_t iter_int, nir_const_value step,
                 nir_const_value limit, nir_op cond_op, unsigned bit_size,
                 nir_alu_type induction_base_type,
-                nir_const_value initial, bool limit_rhs, bool invert_cond)
+                nir_const_value initial, bool limit_rhs, bool invert_cond,
+                unsigned execution_mode)
 {
    assert(nir_op_infos[cond_op].num_inputs == 2);
 
@@ -715,11 +727,11 @@ test_iterations(int32_t iter_int, nir_const_value step,
     * step the induction variable each iteration.
     */
    nir_const_value mul_result =
-      eval_const_binop(mul_op, bit_size, iter_src, step);
+      eval_const_binop(mul_op, bit_size, iter_src, step, execution_mode);
 
    /* Add the initial value to the accumulated induction variable total */
    nir_const_value add_result =
-      eval_const_binop(add_op, bit_size, mul_result, initial);
+      eval_const_binop(add_op, bit_size, mul_result, initial, execution_mode);
 
    nir_const_value *src[2];
    src[limit_rhs ? 0 : 1] = &add_result;
@@ -727,7 +739,7 @@ test_iterations(int32_t iter_int, nir_const_value step,
 
    /* Evaluate the loop exit condition */
    nir_const_value result;
-   nir_eval_const_opcode(cond_op, &result, 1, bit_size, src);
+   nir_eval_const_opcode(cond_op, &result, 1, bit_size, src, execution_mode);
 
    return invert_cond ? !result.b : result.b;
 }
@@ -736,7 +748,7 @@ static int
 calculate_iterations(nir_const_value initial, nir_const_value step,
                      nir_const_value limit, nir_alu_instr *alu,
                      nir_ssa_scalar cond, nir_op alu_op, bool limit_rhs,
-                     bool invert_cond)
+                     bool invert_cond, unsigned execution_mode)
 {
    /* nir_op_isub should have been lowered away by this point */
    assert(alu->op != nir_op_isub);
@@ -786,11 +798,13 @@ calculate_iterations(nir_const_value initial, nir_const_value step,
     */
    if (will_break_on_first_iteration(step, induction_base_type, trip_offset,
                                      alu_op, bit_size, initial,
-                                     limit, limit_rhs, invert_cond)) {
+                                     limit, limit_rhs, invert_cond,
+                                     execution_mode)) {
       return 0;
    }
 
-   int iter_int = get_iteration(alu_op, initial, step, limit, bit_size);
+   int iter_int = get_iteration(alu_op, initial, step, limit, bit_size,
+                                execution_mode);
 
    /* If iter_int is negative the loop is ill-formed or is the conditional is
     * unsigned with a huge iteration count so don't bother going any further.
@@ -812,7 +826,7 @@ calculate_iterations(nir_const_value initial, nir_const_value step,
 
       if (test_iterations(iter_bias, step, limit, alu_op, bit_size,
                           induction_base_type, initial,
-                          limit_rhs, invert_cond)) {
+                          limit_rhs, invert_cond, execution_mode)) {
          return iter_bias > 0 ? iter_bias - trip_offset : iter_bias;
       }
    }
@@ -950,7 +964,7 @@ try_find_trip_count_vars_in_iand(nir_ssa_scalar *cond,
  * loop.
  */
 static void
-find_trip_count(loop_info_state *state)
+find_trip_count(loop_info_state *state, unsigned execution_mode)
 {
    bool trip_count_known = true;
    bool guessed_trip_count = false;
@@ -1063,7 +1077,8 @@ find_trip_count(loop_info_state *state)
       int iterations = calculate_iterations(initial_val, step_val, limit_val,
                                             ind_var->alu, cond,
                                             alu_op, limit_rhs,
-                                            terminator->continue_from_then);
+                                            terminator->continue_from_then,
+                                            execution_mode);
 
       /* Where we not able to calculate the iteration count */
       if (iterations == -1) {
@@ -1203,7 +1218,7 @@ get_loop_info(loop_info_state *state, nir_function_impl *impl)
       return;
 
    /* Run through each of the terminators and try to compute a trip-count */
-   find_trip_count(state);
+   find_trip_count(state, impl->function->shader->info.float_controls_execution_mode);
 
    nir_foreach_block_in_cf_node(block, &state->loop->cf_node) {
       if (force_unroll_heuristics(state, block)) {
index 4123bc8af6d5d2f8d35ca16f70c8f64696f2ad8d..80d9c613e5b4adb2e00a186f3522463d44a05c86 100644 (file)
@@ -33,7 +33,7 @@
  */
 
 static bool
-constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx)
+constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx, unsigned execution_mode)
 {
    nir_const_value src[NIR_MAX_VEC_COMPONENTS][NIR_MAX_VEC_COMPONENTS];
 
@@ -88,7 +88,7 @@ constant_fold_alu_instr(nir_alu_instr *instr, void *mem_ctx)
    for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; ++i)
       srcs[i] = src[i];
    nir_eval_const_opcode(instr->op, dest, instr->dest.dest.ssa.num_components,
-                         bit_size, srcs);
+                         bit_size, srcs, execution_mode);
 
    nir_load_const_instr *new_instr =
       nir_load_const_instr_create(mem_ctx,
@@ -144,14 +144,14 @@ constant_fold_intrinsic_instr(nir_intrinsic_instr *instr)
 }
 
 static bool
-constant_fold_block(nir_block *block, void *mem_ctx)
+constant_fold_block(nir_block *block, void *mem_ctx, unsigned execution_mode)
 {
    bool progress = false;
 
    nir_foreach_instr_safe(instr, block) {
       switch (instr->type) {
       case nir_instr_type_alu:
-         progress |= constant_fold_alu_instr(nir_instr_as_alu(instr), mem_ctx);
+         progress |= constant_fold_alu_instr(nir_instr_as_alu(instr), mem_ctx, execution_mode);
          break;
       case nir_instr_type_intrinsic:
          progress |=
@@ -167,13 +167,13 @@ constant_fold_block(nir_block *block, void *mem_ctx)
 }
 
 static bool
-nir_opt_constant_folding_impl(nir_function_impl *impl)
+nir_opt_constant_folding_impl(nir_function_impl *impl, unsigned execution_mode)
 {
    void *mem_ctx = ralloc_parent(impl);
    bool progress = false;
 
    nir_foreach_block(block, impl) {
-      progress |= constant_fold_block(block, mem_ctx);
+      progress |= constant_fold_block(block, mem_ctx, execution_mode);
    }
 
    if (progress) {
@@ -192,10 +192,11 @@ bool
 nir_opt_constant_folding(nir_shader *shader)
 {
    bool progress = false;
+   unsigned execution_mode = shader->info.float_controls_execution_mode;
 
    nir_foreach_function(function, shader) {
       if (function->impl)
-         progress |= nir_opt_constant_folding_impl(function->impl);
+         progress |= nir_opt_constant_folding_impl(function->impl, execution_mode);
    }
 
    return progress;
index f46b0a9bb02654a785e51a131a3007d95046991c..54052062392e44fe9b42f3ed3535dde4237da2ea 100644 (file)
@@ -1896,7 +1896,9 @@ vtn_handle_constant(struct vtn_builder *b, SpvOp opcode,
          nir_const_value *srcs[3] = {
             src[0], src[1], src[2],
          };
-         nir_eval_const_opcode(op, val->constant->values, num_components, bit_size, srcs);
+         nir_eval_const_opcode(op, val->constant->values,
+                               num_components, bit_size, srcs,
+                               b->shader->info.float_controls_execution_mode);
          break;
       } /* default */
       }