X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fcompiler%2Fnir%2Fnir_opt_if.c;h=8a971c43f24b7c4a85bc92dc3d840db703e0e1cd;hb=769ae9fb7f8cea1d4a03e31f7f4a1c988e424c03;hp=dacf2d6c667a5a2733747815aecb3196dd36d37d;hpb=6f0647c0b207cfe0805eb899c2d97703ae00b1e7;p=mesa.git diff --git a/src/compiler/nir/nir_opt_if.c b/src/compiler/nir/nir_opt_if.c index dacf2d6c667..8a971c43f24 100644 --- a/src/compiler/nir/nir_opt_if.c +++ b/src/compiler/nir/nir_opt_if.c @@ -23,6 +23,7 @@ #include "nir.h" #include "nir/nir_builder.h" +#include "nir_constant_expressions.h" #include "nir_control_flow.h" #include "nir_loop_analyze.h" @@ -39,7 +40,6 @@ find_continue_block(nir_loop *loop) assert(header_block->predecessors->entries == 2); - struct set_entry *pred_entry; set_foreach(header_block->predecessors, pred_entry) { if (pred_entry->key != prev_block) return (nir_block*)pred_entry->key; @@ -180,6 +180,13 @@ opt_peel_loop_initial_if(nir_loop *loop) } } + /* We're about to re-arrange a bunch of blocks so make sure that we don't + * have deref uses which cross block boundaries. We don't want a deref + * accidentally ending up in a phi. + */ + nir_rematerialize_derefs_in_use_blocks_impl( + nir_cf_node_get_function(&loop->cf_node)); + /* Before we do anything, convert the loop to LCSSA. We're about to * replace a bunch of SSA defs with registers and this will prevent any of * it from leaking outside the loop. @@ -369,6 +376,204 @@ opt_if_loop_terminator(nir_if *nif) return true; } +static bool +evaluate_if_condition(nir_if *nif, nir_cursor cursor, bool *value) +{ + nir_block *use_block = nir_cursor_current_block(cursor); + if (nir_block_dominates(nir_if_first_then_block(nif), use_block)) { + *value = true; + return true; + } else if (nir_block_dominates(nir_if_first_else_block(nif), use_block)) { + *value = false; + return true; + } else { + return false; + } +} + +static nir_ssa_def * +clone_alu_and_replace_src_defs(nir_builder *b, const nir_alu_instr *alu, + nir_ssa_def **src_defs) +{ + nir_alu_instr *nalu = nir_alu_instr_create(b->shader, alu->op); + nalu->exact = alu->exact; + + nir_ssa_dest_init(&nalu->instr, &nalu->dest.dest, + alu->dest.dest.ssa.num_components, + alu->dest.dest.ssa.bit_size, alu->dest.dest.ssa.name); + + nalu->dest.saturate = alu->dest.saturate; + nalu->dest.write_mask = alu->dest.write_mask; + + for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) { + assert(alu->src[i].src.is_ssa); + nalu->src[i].src = nir_src_for_ssa(src_defs[i]); + nalu->src[i].negate = alu->src[i].negate; + nalu->src[i].abs = alu->src[i].abs; + memcpy(nalu->src[i].swizzle, alu->src[i].swizzle, + sizeof(nalu->src[i].swizzle)); + } + + nir_builder_instr_insert(b, &nalu->instr); + + return &nalu->dest.dest.ssa;; +} + +/* + * This propagates if condition evaluation down the chain of some alu + * instructions. For example by checking the use of some of the following alu + * instruction we can eventually replace ssa_107 with NIR_TRUE. + * + * loop { + * block block_1: + * vec1 32 ssa_85 = load_const (0x00000002) + * vec1 32 ssa_86 = ieq ssa_48, ssa_85 + * vec1 32 ssa_87 = load_const (0x00000001) + * vec1 32 ssa_88 = ieq ssa_48, ssa_87 + * vec1 32 ssa_89 = ior ssa_86, ssa_88 + * vec1 32 ssa_90 = ieq ssa_48, ssa_0 + * vec1 32 ssa_91 = ior ssa_89, ssa_90 + * if ssa_86 { + * block block_2: + * ... + * break + * } else { + * block block_3: + * } + * block block_4: + * if ssa_88 { + * block block_5: + * ... + * break + * } else { + * block block_6: + * } + * block block_7: + * if ssa_90 { + * block block_8: + * ... + * break + * } else { + * block block_9: + * } + * block block_10: + * vec1 32 ssa_107 = inot ssa_91 + * if ssa_107 { + * block block_11: + * break + * } else { + * block block_12: + * } + * } + */ +static bool +propagate_condition_eval(nir_builder *b, nir_if *nif, nir_src *use_src, + nir_src *alu_use, nir_alu_instr *alu, + bool is_if_condition) +{ + bool bool_value; + b->cursor = nir_before_src(alu_use, is_if_condition); + if (!evaluate_if_condition(nif, b->cursor, &bool_value)) + return false; + + nir_ssa_def *def[4] = {0}; + for (unsigned i = 0; i < nir_op_infos[alu->op].num_inputs; i++) { + if (alu->src[i].src.ssa == use_src->ssa) { + def[i] = nir_imm_bool(b, bool_value); + } else { + def[i] = alu->src[i].src.ssa; + } + } + + nir_ssa_def *nalu = clone_alu_and_replace_src_defs(b, alu, def); + + /* Rewrite use to use new alu instruction */ + nir_src new_src = nir_src_for_ssa(nalu); + + if (is_if_condition) + nir_if_rewrite_condition(alu_use->parent_if, new_src); + else + nir_instr_rewrite_src(alu_use->parent_instr, alu_use, new_src); + + return true; +} + +static bool +can_propagate_through_alu(nir_src *src) +{ + if (src->parent_instr->type != nir_instr_type_alu) + return false; + + nir_alu_instr *alu = nir_instr_as_alu(src->parent_instr); + switch (alu->op) { + case nir_op_ior: + case nir_op_iand: + case nir_op_inot: + case nir_op_b2i: + return true; + case nir_op_bcsel: + return src == &alu->src[0].src; + default: + return false; + } +} + +static bool +evaluate_condition_use(nir_builder *b, nir_if *nif, nir_src *use_src, + bool is_if_condition) +{ + bool progress = false; + + b->cursor = nir_before_src(use_src, is_if_condition); + + bool bool_value; + if (evaluate_if_condition(nif, b->cursor, &bool_value)) { + /* Rewrite use to use const */ + nir_src imm_src = nir_src_for_ssa(nir_imm_bool(b, bool_value)); + if (is_if_condition) + nir_if_rewrite_condition(use_src->parent_if, imm_src); + else + nir_instr_rewrite_src(use_src->parent_instr, use_src, imm_src); + + progress = true; + } + + if (!is_if_condition && can_propagate_through_alu(use_src)) { + nir_alu_instr *alu = nir_instr_as_alu(use_src->parent_instr); + + nir_foreach_use_safe(alu_use, &alu->dest.dest.ssa) { + progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu, + false); + } + + nir_foreach_if_use_safe(alu_use, &alu->dest.dest.ssa) { + progress |= propagate_condition_eval(b, nif, use_src, alu_use, alu, + true); + } + } + + return progress; +} + +static bool +opt_if_evaluate_condition_use(nir_builder *b, nir_if *nif) +{ + bool progress = false; + + /* Evaluate any uses of the if condition inside the if branches */ + assert(nif->condition.is_ssa); + nir_foreach_use_safe(use_src, nif->condition.ssa) { + progress |= evaluate_condition_use(b, nif, use_src, false); + } + + nir_foreach_if_use_safe(use_src, nif->condition.ssa) { + if (use_src->parent_if != nif) + progress |= evaluate_condition_use(b, nif, use_src, true); + } + + return progress; +} + static bool opt_if_cf_list(nir_builder *b, struct exec_list *cf_list) { @@ -402,6 +607,41 @@ opt_if_cf_list(nir_builder *b, struct exec_list *cf_list) return progress; } +/** + * These optimisations depend on nir_metadata_block_index and therefore must + * not do anything to cause the metadata to become invalid. + */ +static bool +opt_if_safe_cf_list(nir_builder *b, struct exec_list *cf_list) +{ + bool progress = false; + foreach_list_typed(nir_cf_node, cf_node, node, cf_list) { + switch (cf_node->type) { + case nir_cf_node_block: + break; + + case nir_cf_node_if: { + nir_if *nif = nir_cf_node_as_if(cf_node); + progress |= opt_if_safe_cf_list(b, &nif->then_list); + progress |= opt_if_safe_cf_list(b, &nif->else_list); + progress |= opt_if_evaluate_condition_use(b, nif); + break; + } + + case nir_cf_node_loop: { + nir_loop *loop = nir_cf_node_as_loop(cf_node); + progress |= opt_if_safe_cf_list(b, &loop->body); + break; + } + + case nir_cf_node_function: + unreachable("Invalid cf type"); + } + } + + return progress; +} + bool nir_opt_if(nir_shader *shader) { @@ -414,6 +654,12 @@ nir_opt_if(nir_shader *shader) nir_builder b; nir_builder_init(&b, function->impl); + nir_metadata_require(function->impl, nir_metadata_block_index | + nir_metadata_dominance); + progress = opt_if_safe_cf_list(&b, &function->impl->body); + nir_metadata_preserve(function->impl, nir_metadata_block_index | + nir_metadata_dominance); + if (opt_if_cf_list(&b, &function->impl->body)) { nir_metadata_preserve(function->impl, nir_metadata_none); @@ -423,12 +669,6 @@ nir_opt_if(nir_shader *shader) */ nir_lower_regs_to_ssa_impl(function->impl); - /* Calling nir_convert_loop_to_lcssa() in opt_peel_loop_initial_if() - * adds extra phi nodes which may not be valid if they're used for - * something such as a deref. Remove any unneeded phis. - */ - nir_opt_remove_phis_impl(function->impl); - progress = true; } }