From: Ian Romanick Date: Mon, 10 Jun 2019 22:05:14 +0000 (-0700) Subject: nir: Handle swizzle in nir_alu_srcs_negative_equal X-Git-Url: https://git.libre-soc.org/?a=commitdiff_plain;h=12217de08cb1fd3dcedcaacb8757ee2f26fc3002;p=mesa.git nir: Handle swizzle in nir_alu_srcs_negative_equal When I added this function, I was not sure if swizzles of immediate values were a thing that occurred in NIR. The only existing user of these functions is the partial redundancy elimination for compares. Since comparison instructions are inherently scalar, this does not occur. However, a couple later patches, "nir/algebraic: Recognize open-coded flrp(-1, 1, a) and flrp(1, -1, a)" combined with "intel/vec4: Try to emit a single load for multiple 3-src instruction operands", collaborate to create a few thousand instances. No shader-db changes on any Intel platform. v2: Handle the swizzle in nir_alu_srcs_negative_equal and leave nir_const_value_negative_equal unchanged. Suggested by Jason. v3: Correctly handle write masks. Add note (and assertion) that the caller is responsible for various compatibility checks. The single existing caller only calls this for combinations of scalar fadd and float comparison instructions, so all of the requirements are met. A later patch (intel/vec4: Try to emit a single load for multiple 3-src instruction operands) will call this for sources of the same instruction, so all of the requirements are met. v4: Add unit test for nir_opt_comparison_pre that is fixed by this commit. Reviewed-by: Matt Turner --- diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c index 6796fcaad5b..d200412dc9c 100644 --- a/src/compiler/nir/nir_instr_set.c +++ b/src/compiler/nir/nir_instr_set.c @@ -352,12 +352,31 @@ nir_const_value_negative_equal(nir_const_value c1, * This function does not detect the general case when \p alu1 and \p alu2 are * SSA values that are the negations of each other (e.g., \p alu1 represents * (a * b) and \p alu2 represents (-a * b)). + * + * \warning + * It is the responsibility of the caller to ensure that the component counts, + * write masks, and base types of the sources being compared are compatible. */ bool nir_alu_srcs_negative_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2, unsigned src1, unsigned src2) { +#ifndef NDEBUG + for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) { + assert(nir_alu_instr_channel_used(alu1, src1, i) == + nir_alu_instr_channel_used(alu2, src2, i)); + } + + if (nir_op_infos[alu1->op].input_types[src1] == nir_type_float) { + assert(nir_op_infos[alu1->op].input_types[src1] == + nir_op_infos[alu2->op].input_types[src2]); + } else { + assert(nir_op_infos[alu1->op].input_types[src1] == nir_type_int); + assert(nir_op_infos[alu2->op].input_types[src2] == nir_type_int); + } +#endif + if (alu1->src[src1].abs != alu2->src[src2].abs) return false; @@ -385,12 +404,13 @@ nir_alu_srcs_negative_equal(const nir_alu_instr *alu1, nir_src_bit_size(alu2->src[src2].src)) return false; - /* FINISHME: Apply the swizzle? */ - const unsigned components = nir_ssa_alu_instr_src_components(alu1, src1); const nir_alu_type full_type = nir_op_infos[alu1->op].input_types[src1] | nir_src_bit_size(alu1->src[src1].src); - for (unsigned i = 0; i < components; i++) { - if (!nir_const_value_negative_equal(const1[i], const2[i], full_type)) + for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) { + if (nir_alu_instr_channel_used(alu1, src1, i) && + !nir_const_value_negative_equal(const1[alu1->src[src1].swizzle[i]], + const2[alu2->src[src2].swizzle[i]], + full_type)) return false; } diff --git a/src/compiler/nir/tests/comparison_pre_tests.cpp b/src/compiler/nir/tests/comparison_pre_tests.cpp index fe1cc23fb3b..a48aeca8da4 100644 --- a/src/compiler/nir/tests/comparison_pre_tests.cpp +++ b/src/compiler/nir/tests/comparison_pre_tests.cpp @@ -473,6 +473,56 @@ TEST_F(comparison_pre_test, a_lt_neg_imm_vs_a_plus_imm) EXPECT_TRUE(nir_opt_comparison_pre_impl(bld.impl)); } +TEST_F(comparison_pre_test, swizzle_of_same_immediate_vector) +{ + /* Before: + * + * vec4 32 ssa_0 = load_const (-2.0, -1.0, 1.0, 2.0) + * vec4 32 ssa_1 = load_const ( 2.0, 1.0, -1.0, -2.0) + * vec4 32 ssa_2 = load_const ( 3.0, 4.0, 5.0, 6.0) + * vec4 32 ssa_3 = fadd ssa_0, ssa_2 + * vec1 1 ssa_4 = flt ssa_0.x, ssa_3.x + * + * if ssa_4 { + * vec1 32 ssa_5 = fadd ssa_0.w, ssa_3.x + * } else { + * } + */ + nir_ssa_def *a = nir_fadd(&bld, v1, v3); + + nir_alu_instr *flt = nir_alu_instr_create(bld.shader, nir_op_flt); + + flt->src[0].src = nir_src_for_ssa(v1); + flt->src[1].src = nir_src_for_ssa(a); + + memcpy(&flt->src[0].swizzle, xxxx, sizeof(xxxx)); + memcpy(&flt->src[1].swizzle, xxxx, sizeof(xxxx)); + + nir_builder_alu_instr_finish_and_insert(&bld, flt); + + flt->dest.dest.ssa.num_components = 1; + flt->dest.write_mask = 1; + + nir_if *nif = nir_push_if(&bld, &flt->dest.dest.ssa); + + nir_alu_instr *fadd = nir_alu_instr_create(bld.shader, nir_op_fadd); + + fadd->src[0].src = nir_src_for_ssa(v1); + fadd->src[1].src = nir_src_for_ssa(a); + + memcpy(&fadd->src[0].swizzle, wwww, sizeof(wwww)); + memcpy(&fadd->src[1].swizzle, xxxx, sizeof(xxxx)); + + nir_builder_alu_instr_finish_and_insert(&bld, fadd); + + fadd->dest.dest.ssa.num_components = 1; + fadd->dest.write_mask = 1; + + nir_pop_if(&bld, nif); + + EXPECT_TRUE(nir_opt_comparison_pre_impl(bld.impl)); +} + TEST_F(comparison_pre_test, non_scalar_add_result) { /* The optimization pass should not do anything because the result of the diff --git a/src/compiler/nir/tests/negative_equal_tests.cpp b/src/compiler/nir/tests/negative_equal_tests.cpp index 5e13c8fd28a..9fedb987166 100644 --- a/src/compiler/nir/tests/negative_equal_tests.cpp +++ b/src/compiler/nir/tests/negative_equal_tests.cpp @@ -270,6 +270,42 @@ compare_with_negation(nir_type_uint32) compare_with_negation(nir_type_int64) compare_with_negation(nir_type_uint64) +TEST_F(alu_srcs_negative_equal_test, swizzle_scalar_to_vector) +{ + nir_ssa_def *v = nir_imm_vec2(&bld, 1.0, -1.0); + const uint8_t s0[4] = { 0, 0, 0, 0 }; + const uint8_t s1[4] = { 1, 1, 1, 1 }; + + /* We can't use nir_swizzle here because it inserts an extra MOV. */ + nir_alu_instr *instr = nir_alu_instr_create(bld.shader, nir_op_fadd); + + instr->src[0].src = nir_src_for_ssa(v); + instr->src[1].src = nir_src_for_ssa(v); + + memcpy(&instr->src[0].swizzle, s0, sizeof(s0)); + memcpy(&instr->src[1].swizzle, s1, sizeof(s1)); + + nir_builder_alu_instr_finish_and_insert(&bld, instr); + + EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); +} + +TEST_F(alu_srcs_negative_equal_test, unused_components_mismatch) +{ + nir_ssa_def *v1 = nir_imm_vec4(&bld, -2.0, 18.0, 43.0, 1.0); + nir_ssa_def *v2 = nir_imm_vec4(&bld, 2.0, 99.0, 76.0, -1.0); + + nir_ssa_def *result = nir_fadd(&bld, v1, v2); + + nir_alu_instr *instr = nir_instr_as_alu(result->parent_instr); + + /* Disable the channels that aren't negations of each other. */ + instr->dest.dest.is_ssa = false; + instr->dest.write_mask = 8 + 1; + + EXPECT_TRUE(nir_alu_srcs_negative_equal(instr, instr, 0, 1)); +} + static void count_sequence(nir_const_value c[NIR_MAX_VEC_COMPONENTS], nir_alu_type full_type, int first)