X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fcompiler%2Fnir%2Fnir_instr_set.c;h=cb0f2befa86d1482436724214485a7448da9e92e;hb=60097cc840e33af8506d7d4d621fefdca1a77695;hp=d106e9ebcae9e30432d01991128cf2ef5375879e;hpb=71c66c254b8021e2c01b1af9b4d16e18bbd26b48;p=mesa.git diff --git a/src/compiler/nir/nir_instr_set.c b/src/compiler/nir/nir_instr_set.c index d106e9ebcae..cb0f2befa86 100644 --- a/src/compiler/nir/nir_instr_set.c +++ b/src/compiler/nir/nir_instr_set.c @@ -23,6 +23,65 @@ #include "nir_instr_set.h" #include "nir_vla.h" +#include "util/half_float.h" + +static bool +src_is_ssa(nir_src *src, void *data) +{ + (void) data; + return src->is_ssa; +} + +static bool +dest_is_ssa(nir_dest *dest, void *data) +{ + (void) data; + return dest->is_ssa; +} + +static inline bool +instr_each_src_and_dest_is_ssa(const nir_instr *instr) +{ + if (!nir_foreach_dest((nir_instr *)instr, dest_is_ssa, NULL) || + !nir_foreach_src((nir_instr *)instr, src_is_ssa, NULL)) + return false; + + return true; +} + +/* This function determines if uses of an instruction can safely be rewritten + * to use another identical instruction instead. Note that this function must + * be kept in sync with hash_instr() and nir_instrs_equal() -- only + * instructions that pass this test will be handed on to those functions, and + * conversely they must handle everything that this function returns true for. + */ +static bool +instr_can_rewrite(const nir_instr *instr) +{ + /* We only handle SSA. */ + assert(instr_each_src_and_dest_is_ssa(instr)); + + switch (instr->type) { + case nir_instr_type_alu: + case nir_instr_type_deref: + case nir_instr_type_tex: + case nir_instr_type_load_const: + case nir_instr_type_phi: + return true; + case nir_instr_type_intrinsic: + return nir_intrinsic_can_reorder(nir_instr_as_intrinsic(instr)); + case nir_instr_type_call: + case nir_instr_type_jump: + case nir_instr_type_ssa_undef: + return false; + case nir_instr_type_parallel_copy: + default: + unreachable("Invalid instruction type"); + } + + return false; +} + #define HASH(hash, data) _mesa_fnv32_1a_accumulate((hash), (data)) @@ -51,12 +110,18 @@ static uint32_t hash_alu(uint32_t hash, const nir_alu_instr *instr) { hash = HASH(hash, instr->op); + + /* We explicitly don't hash instr->exact. */ + uint8_t flags = instr->no_signed_wrap | + instr->no_unsigned_wrap << 1; + hash = HASH(hash, flags); + hash = HASH(hash, instr->dest.dest.ssa.num_components); hash = HASH(hash, instr->dest.dest.ssa.bit_size); - /* We explicitly don't hash instr->dest.dest.exact */ - if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) { - assert(nir_op_infos[instr->op].num_inputs == 2); + if (nir_op_infos[instr->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) { + assert(nir_op_infos[instr->op].num_inputs >= 2); + uint32_t hash0 = hash_alu_src(hash, &instr->src[0], nir_ssa_alu_instr_src_components(instr, 0)); uint32_t hash1 = hash_alu_src(hash, &instr->src[1], @@ -68,6 +133,11 @@ hash_alu(uint32_t hash, const nir_alu_instr *instr) * collision. Either addition or multiplication will also work. */ hash = hash0 * hash1; + + for (unsigned i = 2; i < nir_op_infos[instr->op].num_inputs; i++) { + hash = hash_alu_src(hash, &instr->src[i], + nir_ssa_alu_instr_src_components(instr, i)); + } } else { for (unsigned i = 0; i < nir_op_infos[instr->op].num_inputs; i++) { hash = hash_alu_src(hash, &instr->src[i], @@ -123,12 +193,12 @@ hash_load_const(uint32_t hash, const nir_load_const_instr *instr) if (instr->def.bit_size == 1) { for (unsigned i = 0; i < instr->def.num_components; i++) { - uint8_t b = instr->value.b[i]; + uint8_t b = instr->value[i].b; hash = HASH(hash, b); } } else { - unsigned size = instr->def.num_components * (instr->def.bit_size / 8); - hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value.f32, size); + unsigned size = instr->def.num_components * sizeof(*instr->value); + hash = _mesa_fnv32_1a_accumulate_block(hash, instr->value, size); } return hash; @@ -206,6 +276,8 @@ hash_tex(uint32_t hash, const nir_tex_instr *instr) hash = HASH(hash, instr->texture_index); hash = HASH(hash, instr->texture_array_size); hash = HASH(hash, instr->sampler_index); + hash = HASH(hash, instr->texture_non_uniform); + hash = HASH(hash, instr->sampler_non_uniform); return hash; } @@ -275,6 +347,179 @@ nir_srcs_equal(nir_src src1, nir_src src2) } } +/** + * If the \p s is an SSA value that was generated by a negation instruction, + * that instruction is returned as a \c nir_alu_instr. Otherwise \c NULL is + * returned. + */ +static nir_alu_instr * +get_neg_instr(nir_src s) +{ + nir_alu_instr *alu = nir_src_as_alu_instr(s); + + return alu != NULL && (alu->op == nir_op_fneg || alu->op == nir_op_ineg) + ? alu : NULL; +} + +bool +nir_const_value_negative_equal(nir_const_value c1, + nir_const_value c2, + nir_alu_type full_type) +{ + assert(nir_alu_type_get_base_type(full_type) != nir_type_invalid); + assert(nir_alu_type_get_type_size(full_type) != 0); + + switch (full_type) { + case nir_type_float16: + return _mesa_half_to_float(c1.u16) == -_mesa_half_to_float(c2.u16); + + case nir_type_float32: + return c1.f32 == -c2.f32; + + case nir_type_float64: + return c1.f64 == -c2.f64; + + case nir_type_int8: + case nir_type_uint8: + return c1.i8 == -c2.i8; + + case nir_type_int16: + case nir_type_uint16: + return c1.i16 == -c2.i16; + + case nir_type_int32: + case nir_type_uint32: + return c1.i32 == -c2.i32; + + case nir_type_int64: + case nir_type_uint64: + return c1.i64 == -c2.i64; + + default: + break; + } + + return false; +} + +/** + * Shallow compare of ALU srcs to determine if one is the negation of the other + * + * This function detects cases where \p alu1 is a constant and \p alu2 is a + * constant that is its negation. It will also detect cases where \p alu2 is + * an SSA value that is a \c nir_op_fneg applied to \p alu1 (and vice versa). + * + * This function does not detect the general case when \p alu1 and \p alu2 are + * SSA values that are the negations of each other (e.g., \p alu1 represents + * (a * b) and \p alu2 represents (-a * b)). + * + * \warning + * It is the responsibility of the caller to ensure that the component counts, + * write masks, and base types of the sources being compared are compatible. + */ +bool +nir_alu_srcs_negative_equal(const nir_alu_instr *alu1, + const nir_alu_instr *alu2, + unsigned src1, unsigned src2) +{ +#ifndef NDEBUG + for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) { + assert(nir_alu_instr_channel_used(alu1, src1, i) == + nir_alu_instr_channel_used(alu2, src2, i)); + } + + if (nir_op_infos[alu1->op].input_types[src1] == nir_type_float) { + assert(nir_op_infos[alu1->op].input_types[src1] == + nir_op_infos[alu2->op].input_types[src2]); + } else { + assert(nir_op_infos[alu1->op].input_types[src1] == nir_type_int); + assert(nir_op_infos[alu2->op].input_types[src2] == nir_type_int); + } +#endif + + if (alu1->src[src1].abs != alu2->src[src2].abs) + return false; + + bool parity = alu1->src[src1].negate != alu2->src[src2].negate; + + /* Handling load_const instructions is tricky. */ + + const nir_const_value *const const1 = + nir_src_as_const_value(alu1->src[src1].src); + + if (const1 != NULL) { + /* Assume that constant folding will eliminate source mods and unary + * ops. + */ + if (parity) + return false; + + const nir_const_value *const const2 = + nir_src_as_const_value(alu2->src[src2].src); + + if (const2 == NULL) + return false; + + if (nir_src_bit_size(alu1->src[src1].src) != + nir_src_bit_size(alu2->src[src2].src)) + return false; + + const nir_alu_type full_type = nir_op_infos[alu1->op].input_types[src1] | + nir_src_bit_size(alu1->src[src1].src); + for (unsigned i = 0; i < NIR_MAX_VEC_COMPONENTS; i++) { + if (nir_alu_instr_channel_used(alu1, src1, i) && + !nir_const_value_negative_equal(const1[alu1->src[src1].swizzle[i]], + const2[alu2->src[src2].swizzle[i]], + full_type)) + return false; + } + + return true; + } + + uint8_t alu1_swizzle[4] = {0}; + nir_src alu1_actual_src; + nir_alu_instr *neg1 = get_neg_instr(alu1->src[src1].src); + + if (neg1) { + parity = !parity; + alu1_actual_src = neg1->src[0].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg1, 0); i++) + alu1_swizzle[i] = neg1->src[0].swizzle[i]; + } else { + alu1_actual_src = alu1->src[src1].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu1, src1); i++) + alu1_swizzle[i] = i; + } + + uint8_t alu2_swizzle[4] = {0}; + nir_src alu2_actual_src; + nir_alu_instr *neg2 = get_neg_instr(alu2->src[src2].src); + + if (neg2) { + parity = !parity; + alu2_actual_src = neg2->src[0].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(neg2, 0); i++) + alu2_swizzle[i] = neg2->src[0].swizzle[i]; + } else { + alu2_actual_src = alu2->src[src2].src; + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu2, src2); i++) + alu2_swizzle[i] = i; + } + + for (unsigned i = 0; i < nir_ssa_alu_instr_src_components(alu1, src1); i++) { + if (alu1_swizzle[alu1->src[src1].swizzle[i]] != + alu2_swizzle[alu2->src[src2].swizzle[i]]) + return false; + } + + return parity && nir_srcs_equal(alu1_actual_src, alu2_actual_src); +} + bool nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2, unsigned src1, unsigned src2) @@ -297,9 +542,11 @@ nir_alu_srcs_equal(const nir_alu_instr *alu1, const nir_alu_instr *alu2, * the same hash for (ignoring collisions, of course). */ -static bool +bool nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2) { + assert(instr_can_rewrite(instr1) && instr_can_rewrite(instr2)); + if (instr1->type != instr2->type) return false; @@ -311,6 +558,14 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2) if (alu1->op != alu2->op) return false; + /* We explicitly don't compare instr->exact. */ + + if (alu1->no_signed_wrap != alu2->no_signed_wrap) + return false; + + if (alu1->no_unsigned_wrap != alu2->no_unsigned_wrap) + return false; + /* TODO: We can probably acutally do something more inteligent such * as allowing different numbers and taking a maximum or something * here */ @@ -320,14 +575,17 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2) if (alu1->dest.dest.ssa.bit_size != alu2->dest.dest.ssa.bit_size) return false; - /* We explicitly don't hash instr->dest.dest.exact */ + if (nir_op_infos[alu1->op].algebraic_properties & NIR_OP_IS_2SRC_COMMUTATIVE) { + if ((!nir_alu_srcs_equal(alu1, alu2, 0, 0) || + !nir_alu_srcs_equal(alu1, alu2, 1, 1)) && + (!nir_alu_srcs_equal(alu1, alu2, 0, 1) || + !nir_alu_srcs_equal(alu1, alu2, 1, 0))) + return false; - if (nir_op_infos[alu1->op].algebraic_properties & NIR_OP_IS_COMMUTATIVE) { - assert(nir_op_infos[alu1->op].num_inputs == 2); - return (nir_alu_srcs_equal(alu1, alu2, 0, 0) && - nir_alu_srcs_equal(alu1, alu2, 1, 1)) || - (nir_alu_srcs_equal(alu1, alu2, 0, 1) && - nir_alu_srcs_equal(alu1, alu2, 1, 0)); + for (unsigned i = 2; i < nir_op_infos[alu1->op].num_inputs; i++) { + if (!nir_alu_srcs_equal(alu1, alu2, i, i)) + return false; + } } else { for (unsigned i = 0; i < nir_op_infos[alu1->op].num_inputs; i++) { if (!nir_alu_srcs_equal(alu1, alu2, i, i)) @@ -423,12 +681,16 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2) return false; if (load1->def.bit_size == 1) { - unsigned size = load1->def.num_components * sizeof(bool); - return memcmp(load1->value.b, load2->value.b, size) == 0; + for (unsigned i = 0; i < load1->def.num_components; ++i) { + if (load1->value[i].b != load2->value[i].b) + return false; + } } else { - unsigned size = load1->def.num_components * (load1->def.bit_size / 8); - return memcmp(load1->value.f32, load2->value.f32, size) == 0; + unsigned size = load1->def.num_components * sizeof(*load1->value); + if (memcmp(load1->value, load2->value, size) != 0) + return false; } + return true; } case nir_instr_type_phi: { nir_phi_instr *phi1 = nir_instr_as_phi(instr1); @@ -491,68 +753,6 @@ nir_instrs_equal(const nir_instr *instr1, const nir_instr *instr2) unreachable("All cases in the above switch should return"); } -static bool -src_is_ssa(nir_src *src, void *data) -{ - (void) data; - return src->is_ssa; -} - -static bool -dest_is_ssa(nir_dest *dest, void *data) -{ - (void) data; - return dest->is_ssa; -} - -static inline bool -instr_each_src_and_dest_is_ssa(nir_instr *instr) -{ - if (!nir_foreach_dest(instr, dest_is_ssa, NULL) || - !nir_foreach_src(instr, src_is_ssa, NULL)) - return false; - - return true; -} - -/* This function determines if uses of an instruction can safely be rewritten - * to use another identical instruction instead. Note that this function must - * be kept in sync with hash_instr() and nir_instrs_equal() -- only - * instructions that pass this test will be handed on to those functions, and - * conversely they must handle everything that this function returns true for. - */ - -static bool -instr_can_rewrite(nir_instr *instr) -{ - /* We only handle SSA. */ - assert(instr_each_src_and_dest_is_ssa(instr)); - - switch (instr->type) { - case nir_instr_type_alu: - case nir_instr_type_deref: - case nir_instr_type_tex: - case nir_instr_type_load_const: - case nir_instr_type_phi: - return true; - case nir_instr_type_intrinsic: { - const nir_intrinsic_info *info = - &nir_intrinsic_infos[nir_instr_as_intrinsic(instr)->intrinsic]; - return (info->flags & NIR_INTRINSIC_CAN_ELIMINATE) && - (info->flags & NIR_INTRINSIC_CAN_REORDER); - } - case nir_instr_type_call: - case nir_instr_type_jump: - case nir_instr_type_ssa_undef: - return false; - case nir_instr_type_parallel_copy: - default: - unreachable("Invalid instruction type"); - } - - return false; -} - static nir_ssa_def * nir_instr_get_dest_ssa_def(nir_instr *instr) { @@ -603,11 +803,10 @@ nir_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr) if (!instr_can_rewrite(instr)) return false; - uint32_t hash = hash_instr(instr); - struct set_entry *e = _mesa_set_search_pre_hashed(instr_set, hash, instr); - if (e) { + struct set_entry *e = _mesa_set_search_or_add(instr_set, instr); + nir_instr *match = (nir_instr *) e->key; + if (match != instr) { nir_ssa_def *def = nir_instr_get_dest_ssa_def(instr); - nir_instr *match = (nir_instr *) e->key; nir_ssa_def *new_def = nir_instr_get_dest_ssa_def(match); /* It's safe to replace an exact instruction with an inexact one as @@ -622,7 +821,6 @@ nir_instr_set_add_or_rewrite(struct set *instr_set, nir_instr *instr) return true; } - _mesa_set_add_pre_hashed(instr_set, hash, instr); return false; }