From 393562f47b3a0a5f06f3c031e8e777dab6e6cc97 Mon Sep 17 00:00:00 2001
From: Jason Ekstrand
Date: Fri, 8 Jan 2016 11:02:17 -0800
Subject: [PATCH] nir/spirv: Split ALU operations out into their own file

---
 src/glsl/Makefile.sources         |   4 +-
 src/glsl/nir/spirv/spirv_to_nir.c | 423 +-----------------------------
 src/glsl/nir/spirv/vtn_alu.c      | 420 +++++++++++++++++++++++++++++
 src/glsl/nir/spirv/vtn_private.h  |   9 +
 4 files changed, 438 insertions(+), 418 deletions(-)
 create mode 100644 src/glsl/nir/spirv/vtn_alu.c

diff --git a/src/glsl/Makefile.sources b/src/glsl/Makefile.sources
index 97fac8609b6..89113bcd1c6 100644
--- a/src/glsl/Makefile.sources
+++ b/src/glsl/Makefile.sources
@@ -95,8 +95,10 @@ NIR_FILES = \
 SPIRV_FILES = \
 	nir/spirv/nir_spirv.h \
 	nir/spirv/spirv_to_nir.c \
+	nir/spirv/vtn_alu.c \
 	nir/spirv/vtn_cfg.c \
-	nir/spirv/vtn_glsl450.c
+	nir/spirv/vtn_glsl450.c \
+	nir/spirv/vtn_private.h
 
 
 # libglsl
diff --git a/src/glsl/nir/spirv/spirv_to_nir.c b/src/glsl/nir/spirv/spirv_to_nir.c
index 919be098fbb..191d35d9102 100644
--- a/src/glsl/nir/spirv/spirv_to_nir.c
+++ b/src/glsl/nir/spirv/spirv_to_nir.c
@@ -1327,9 +1327,6 @@ _vtn_load_store_tail(struct vtn_builder *b, nir_intrinsic_op op, bool load,
       (*inout)->def = nir_ine(&b->nb, (*inout)->def, nir_imm_int(&b->nb, 0));
 }
 
-static struct vtn_ssa_value *
-vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src);
-
 static void
 _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
                       nir_ssa_def *index, nir_ssa_def *offset, nir_deref *deref,
@@ -1365,7 +1362,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
             (*inout)->type = glsl_matrix_type(base_type, vec_width,
                                               num_ops);
          } else {
-            transpose = vtn_transpose(b, *inout);
+            transpose = vtn_ssa_transpose(b, *inout);
             inout = &transpose;
          }
       } else {
@@ -1383,7 +1380,7 @@ _vtn_block_load_store(struct vtn_builder *b, nir_intrinsic_op op, bool load,
       }
 
       if (load && type->row_major)
-         *inout = vtn_transpose(b, *inout);
+         *inout = vtn_ssa_transpose(b, *inout);
 
       return;
    } else if (type->row_major) {
@@ -2074,7 +2071,7 @@ vtn_handle_function_call(struct vtn_builder *b, SpvOp opcode,
    }
 }
 
-static struct vtn_ssa_value *
+struct vtn_ssa_value *
 vtn_create_ssa_value(struct vtn_builder *b, const struct glsl_type *type)
 {
    struct vtn_ssa_value *val = rzalloc(b, struct vtn_ssa_value);
@@ -2598,8 +2595,8 @@ create_vec(nir_shader *shader, unsigned num_components)
    return vec;
 }
 
-static struct vtn_ssa_value *
-vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
+struct vtn_ssa_value *
+vtn_ssa_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
 {
    if (src->transposed)
       return src->transposed;
@@ -2628,411 +2625,6 @@ vtn_transpose(struct vtn_builder *b, struct vtn_ssa_value *src)
    return dest;
 }
 
-/*
- * Normally, column vectors in SPIR-V correspond to a single NIR SSA
- * definition. But for matrix multiplies, we want to do one routine for
- * multiplying a matrix by a matrix and then pretend that vectors are matrices
- * with one column. So we "wrap" these things, and unwrap the result before we
- * send it off.
- */
-
-static struct vtn_ssa_value *
-vtn_wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
-{
-   if (val == NULL)
-      return NULL;
-
-   if (glsl_type_is_matrix(val->type))
-      return val;
-
-   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
-   dest->type = val->type;
-   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
-   dest->elems[0] = val;
-
-   return dest;
-}
-
-static struct vtn_ssa_value *
-vtn_unwrap_matrix(struct vtn_ssa_value *val)
-{
-   if (glsl_type_is_matrix(val->type))
-      return val;
-
-   return val->elems[0];
-}
-
-static struct vtn_ssa_value *
-vtn_matrix_multiply(struct vtn_builder *b,
-                    struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
-{
-
-   struct vtn_ssa_value *src0 = vtn_wrap_matrix(b, _src0);
-   struct vtn_ssa_value *src1 = vtn_wrap_matrix(b, _src1);
-   struct vtn_ssa_value *src0_transpose = vtn_wrap_matrix(b, _src0->transposed);
-   struct vtn_ssa_value *src1_transpose = vtn_wrap_matrix(b, _src1->transposed);
-
-   unsigned src0_rows = glsl_get_vector_elements(src0->type);
-   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
-   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
-
-   const struct glsl_type *dest_type;
-   if (src1_columns > 1) {
-      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
-                                   src0_rows, src1_columns);
-   } else {
-      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
-   }
-   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
-
-   dest = vtn_wrap_matrix(b, dest);
-
-   bool transpose_result = false;
-   if (src0_transpose && src1_transpose) {
-      /* transpose(A) * transpose(B) = transpose(B * A) */
-      src1 = src0_transpose;
-      src0 = src1_transpose;
-      src0_transpose = NULL;
-      src1_transpose = NULL;
-      transpose_result = true;
-   }
-
-   if (src0_transpose && !src1_transpose &&
-       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
-      /* We already have the rows of src0 and the columns of src1 available,
-       * so we can just take the dot product of each row with each column to
-       * get the result.
-       */
-
-      for (unsigned i = 0; i < src1_columns; i++) {
-         nir_alu_instr *vec = create_vec(b->shader, src0_rows);
-         for (unsigned j = 0; j < src0_rows; j++) {
-            vec->src[j].src =
-               nir_src_for_ssa(nir_fdot(&b->nb, src0_transpose->elems[j]->def,
-                                        src1->elems[i]->def));
-         }
-
-         nir_builder_instr_insert(&b->nb, &vec->instr);
-         dest->elems[i]->def = &vec->dest.dest.ssa;
-      }
-   } else {
-      /* We don't handle the case where src1 is transposed but not src0, since
-       * the general case only uses individual components of src1 so the
-       * optimizer should chew through the transpose we emitted for src1.
-       */
-
-      for (unsigned i = 0; i < src1_columns; i++) {
-         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
-         dest->elems[i]->def =
-            nir_fmul(&b->nb, src0->elems[0]->def,
-                     vtn_vector_extract(b, src1->elems[i]->def, 0));
-         for (unsigned j = 1; j < src0_columns; j++) {
-            dest->elems[i]->def =
-               nir_fadd(&b->nb, dest->elems[i]->def,
-                        nir_fmul(&b->nb, src0->elems[j]->def,
-                                 vtn_vector_extract(b,
-                                                    src1->elems[i]->def, j)));
-         }
-      }
-   }
-
-   dest = vtn_unwrap_matrix(dest);
-
-   if (transpose_result)
-      dest = vtn_transpose(b, dest);
-
-   return dest;
-}
-
-static struct vtn_ssa_value *
-vtn_mat_times_scalar(struct vtn_builder *b,
-                     struct vtn_ssa_value *mat,
-                     nir_ssa_def *scalar)
-{
-   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
-   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
-      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
-         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
-      else
-         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
-   }
-
-   return dest;
-}
-
-static void
-vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
-                      const uint32_t *w, unsigned count)
-{
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-
-   switch (opcode) {
-   case SpvOpTranspose: {
-      struct vtn_ssa_value *src = vtn_ssa_value(b, w[3]);
-      val->ssa = vtn_transpose(b, src);
-      break;
-   }
-
-   case SpvOpOuterProduct: {
-      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
-      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
-
-      val->ssa = vtn_matrix_multiply(b, src0, vtn_transpose(b, src1));
-      break;
-   }
-
-   case SpvOpMatrixTimesScalar: {
-      struct vtn_ssa_value *mat = vtn_ssa_value(b, w[3]);
-      struct vtn_ssa_value *scalar = vtn_ssa_value(b, w[4]);
-
-      if (mat->transposed) {
-         val->ssa = vtn_transpose(b, vtn_mat_times_scalar(b, mat->transposed,
-                                                          scalar->def));
-      } else {
-         val->ssa = vtn_mat_times_scalar(b, mat, scalar->def);
-      }
-      break;
-   }
-
-   case SpvOpVectorTimesMatrix:
-   case SpvOpMatrixTimesVector:
-   case SpvOpMatrixTimesMatrix: {
-      struct vtn_ssa_value *src0 = vtn_ssa_value(b, w[3]);
-      struct vtn_ssa_value *src1 = vtn_ssa_value(b, w[4]);
-
-      if (opcode == SpvOpVectorTimesMatrix) {
-         val->ssa = vtn_matrix_multiply(b, vtn_transpose(b, src1), src0);
-      } else {
-         val->ssa = vtn_matrix_multiply(b, src0, src1);
-      }
-      break;
-   }
-
-   default: unreachable("unknown matrix opcode");
-   }
-}
-
-static void
-vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
-               const uint32_t *w, unsigned count)
-{
-   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
-   const struct glsl_type *type =
-      vtn_value(b, w[1], vtn_value_type_type)->type->type;
-   val->ssa = vtn_create_ssa_value(b, type);
-
-   /* Collect the various SSA sources */
-   const unsigned num_inputs = count - 3;
-   nir_ssa_def *src[4];
-   for (unsigned i = 0; i < num_inputs; i++)
-      src[i] = vtn_ssa_value(b, w[i + 3])->def;
-   for (unsigned i = num_inputs; i < 4; i++)
-      src[i] = NULL;
-
-   /* Indicates that the first two arguments should be swapped. This is
-    * used for implementing greater-than and less-than-or-equal.
-    */
-   bool swap = false;
-
-   nir_op op;
-   switch (opcode) {
-   /* Basic ALU operations */
-   case SpvOpSNegate: op = nir_op_ineg; break;
-   case SpvOpFNegate: op = nir_op_fneg; break;
-   case SpvOpNot: op = nir_op_inot; break;
-
-   case SpvOpAny:
-      if (src[0]->num_components == 1) {
-         op = nir_op_imov;
-      } else {
-         switch (src[0]->num_components) {
-         case 2: op = nir_op_bany_inequal2; break;
-         case 3: op = nir_op_bany_inequal3; break;
-         case 4: op = nir_op_bany_inequal4; break;
-         }
-         src[1] = nir_imm_int(&b->nb, NIR_FALSE);
-      }
-      break;
-
-   case SpvOpAll:
-      if (src[0]->num_components == 1) {
-         op = nir_op_imov;
-      } else {
-         switch (src[0]->num_components) {
-         case 2: op = nir_op_ball_iequal2; break;
-         case 3: op = nir_op_ball_iequal3; break;
-         case 4: op = nir_op_ball_iequal4; break;
-         }
-         src[1] = nir_imm_int(&b->nb, NIR_TRUE);
-      }
-      break;
-
-   case SpvOpIAdd: op = nir_op_iadd; break;
-   case SpvOpFAdd: op = nir_op_fadd; break;
-   case SpvOpISub: op = nir_op_isub; break;
-   case SpvOpFSub: op = nir_op_fsub; break;
-   case SpvOpIMul: op = nir_op_imul; break;
-   case SpvOpFMul: op = nir_op_fmul; break;
-   case SpvOpUDiv: op = nir_op_udiv; break;
-   case SpvOpSDiv: op = nir_op_idiv; break;
-   case SpvOpFDiv: op = nir_op_fdiv; break;
-   case SpvOpUMod: op = nir_op_umod; break;
-   case SpvOpSMod: op = nir_op_umod; break; /* FIXME? */
-   case SpvOpFMod: op = nir_op_fmod; break;
-
-   case SpvOpDot:
-      assert(src[0]->num_components == src[1]->num_components);
-      switch (src[0]->num_components) {
-      case 1: op = nir_op_fmul; break;
-      case 2: op = nir_op_fdot2; break;
-      case 3: op = nir_op_fdot3; break;
-      case 4: op = nir_op_fdot4; break;
-      }
-      break;
-
-   case SpvOpIAddCarry:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def =
-         nir_b2i(&b->nb, nir_uadd_carry(&b->nb, src[0], src[1]));
-      return;
-
-   case SpvOpISubBorrow:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def =
-         nir_b2i(&b->nb, nir_usub_borrow(&b->nb, src[0], src[1]));
-      return;
-
-   case SpvOpUMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
-      return;
-
-   case SpvOpSMulExtended:
-      assert(glsl_type_is_struct(val->ssa->type));
-      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
-      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
-      return;
-
-   case SpvOpShiftRightLogical: op = nir_op_ushr; break;
-   case SpvOpShiftRightArithmetic: op = nir_op_ishr; break;
-   case SpvOpShiftLeftLogical: op = nir_op_ishl; break;
-   case SpvOpLogicalOr: op = nir_op_ior; break;
-   case SpvOpLogicalEqual: op = nir_op_ieq; break;
-   case SpvOpLogicalNotEqual: op = nir_op_ine; break;
-   case SpvOpLogicalAnd: op = nir_op_iand; break;
-   case SpvOpLogicalNot: op = nir_op_inot; break;
-   case SpvOpBitwiseOr: op = nir_op_ior; break;
-   case SpvOpBitwiseXor: op = nir_op_ixor; break;
-   case SpvOpBitwiseAnd: op = nir_op_iand; break;
-   case SpvOpSelect: op = nir_op_bcsel; break;
-   case SpvOpIEqual: op = nir_op_ieq; break;
-
-   case SpvOpBitFieldInsert: op = nir_op_bitfield_insert; break;
-   case SpvOpBitFieldSExtract: op = nir_op_ibitfield_extract; break;
-   case SpvOpBitFieldUExtract: op = nir_op_ubitfield_extract; break;
-   case SpvOpBitReverse: op = nir_op_bitfield_reverse; break;
-   case SpvOpBitCount: op = nir_op_bit_count; break;
-
-   /* Comparisons: (TODO: How do we want to handle ordered/unordered?) */
-   case SpvOpFOrdEqual: op = nir_op_feq; break;
-   case SpvOpFUnordEqual: op = nir_op_feq; break;
-   case SpvOpINotEqual: op = nir_op_ine; break;
-   case SpvOpFOrdNotEqual: op = nir_op_fne; break;
-   case SpvOpFUnordNotEqual: op = nir_op_fne; break;
-   case SpvOpULessThan: op = nir_op_ult; break;
-   case SpvOpSLessThan: op = nir_op_ilt; break;
-   case SpvOpFOrdLessThan: op = nir_op_flt; break;
-   case SpvOpFUnordLessThan: op = nir_op_flt; break;
-   case SpvOpUGreaterThan: op = nir_op_ult; swap = true; break;
-   case SpvOpSGreaterThan: op = nir_op_ilt; swap = true; break;
-   case SpvOpFOrdGreaterThan: op = nir_op_flt; swap = true; break;
-   case SpvOpFUnordGreaterThan: op = nir_op_flt; swap = true; break;
-   case SpvOpULessThanEqual: op = nir_op_uge; swap = true; break;
-   case SpvOpSLessThanEqual: op = nir_op_ige; swap = true; break;
-   case SpvOpFOrdLessThanEqual: op = nir_op_fge; swap = true; break;
-   case SpvOpFUnordLessThanEqual: op = nir_op_fge; swap = true; break;
-   case SpvOpUGreaterThanEqual: op = nir_op_uge; break;
-   case SpvOpSGreaterThanEqual: op = nir_op_ige; break;
-   case SpvOpFOrdGreaterThanEqual: op = nir_op_fge; break;
-   case SpvOpFUnordGreaterThanEqual:op = nir_op_fge; break;
-
-   /* Conversions: */
-   case SpvOpConvertFToU: op = nir_op_f2u; break;
-   case SpvOpConvertFToS: op = nir_op_f2i; break;
-   case SpvOpConvertSToF: op = nir_op_i2f; break;
-   case SpvOpConvertUToF: op = nir_op_u2f; break;
-   case SpvOpBitcast: op = nir_op_imov; break;
-   case SpvOpUConvert:
-   case SpvOpSConvert:
-      op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
-      break;
-   case SpvOpFConvert:
-      op = nir_op_fmov;
-      break;
-
-   /* Derivatives: */
-   case SpvOpDPdx: op = nir_op_fddx; break;
-   case SpvOpDPdy: op = nir_op_fddy; break;
-   case SpvOpDPdxFine: op = nir_op_fddx_fine; break;
-   case SpvOpDPdyFine: op = nir_op_fddy_fine; break;
-   case SpvOpDPdxCoarse: op = nir_op_fddx_coarse; break;
-   case SpvOpDPdyCoarse: op = nir_op_fddy_coarse; break;
-   case SpvOpFwidth:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
-      return;
-   case SpvOpFwidthFine:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
-      return;
-   case SpvOpFwidthCoarse:
-      val->ssa->def = nir_fadd(&b->nb,
-                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
-                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
-      return;
-
-   case SpvOpVectorTimesScalar:
-      /* The builder will take care of splatting for us. */
-      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
-      return;
-
-   case SpvOpSRem:
-   case SpvOpFRem:
-      unreachable("No NIR equivalent");
-
-   case SpvOpIsNan:
-      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
-      return;
-
-   case SpvOpIsInf:
-      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
-                              nir_imm_float(&b->nb, INFINITY));
-      return;
-
-   case SpvOpIsFinite:
-   case SpvOpIsNormal:
-   case SpvOpSignBitSet:
-   case SpvOpLessOrGreater:
-   case SpvOpOrdered:
-   case SpvOpUnordered:
-   default:
-      unreachable("Unhandled opcode");
-   }
-
-   if (swap) {
-      nir_ssa_def *tmp = src[0];
-      src[0] = src[1];
-      src[1] = tmp;
-   }
-
-   val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
-}
-
 static nir_ssa_def *
 vtn_vector_extract(struct vtn_builder *b, nir_ssa_def *src, unsigned index)
 {
@@ -3835,16 +3427,13 @@ vtn_handle_body_instruction(struct vtn_builder *b, SpvOp opcode,
    case SpvOpBitFieldUExtract:
    case SpvOpBitReverse:
    case SpvOpBitCount:
-      vtn_handle_alu(b, opcode, w, count);
-      break;
-
    case SpvOpTranspose:
    case SpvOpOuterProduct:
    case SpvOpMatrixTimesScalar:
    case SpvOpVectorTimesMatrix:
    case SpvOpMatrixTimesVector:
    case SpvOpMatrixTimesMatrix:
-      vtn_handle_matrix_alu(b, opcode, w, count);
+      vtn_handle_alu(b, opcode, w, count);
       break;
 
    case SpvOpVectorExtractDynamic:
diff --git a/src/glsl/nir/spirv/vtn_alu.c b/src/glsl/nir/spirv/vtn_alu.c
new file mode 100644
index 00000000000..a8c6e5cd890
--- /dev/null
+++ b/src/glsl/nir/spirv/vtn_alu.c
@@ -0,0 +1,420 @@
+/*
+ * Copyright © 2016 Intel Corporation
+ *
+ * Permission is hereby granted, free of charge, to any person obtaining a
+ * copy of this software and associated documentation files (the "Software"),
+ * to deal in the Software without restriction, including without limitation
+ * the rights to use, copy, modify, merge, publish, distribute, sublicense,
+ * and/or sell copies of the Software, and to permit persons to whom the
+ * Software is furnished to do so, subject to the following conditions:
+ *
+ * The above copyright notice and this permission notice (including the next
+ * paragraph) shall be included in all copies or substantial portions of the
+ * Software.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL
+ * THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
+ * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
+ * IN THE SOFTWARE.
+ */
+
+#include "vtn_private.h"
+
+/*
+ * Normally, column vectors in SPIR-V correspond to a single NIR SSA
+ * definition. But for matrix multiplies, we want to do one routine for
+ * multiplying a matrix by a matrix and then pretend that vectors are matrices
+ * with one column. So we "wrap" these things, and unwrap the result before we
+ * send it off.
+ */
+
+static struct vtn_ssa_value *
+wrap_matrix(struct vtn_builder *b, struct vtn_ssa_value *val)
+{
+   if (val == NULL)
+      return NULL;
+
+   if (glsl_type_is_matrix(val->type))
+      return val;
+
+   struct vtn_ssa_value *dest = rzalloc(b, struct vtn_ssa_value);
+   dest->type = val->type;
+   dest->elems = ralloc_array(b, struct vtn_ssa_value *, 1);
+   dest->elems[0] = val;
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+unwrap_matrix(struct vtn_ssa_value *val)
+{
+   if (glsl_type_is_matrix(val->type))
+      return val;
+
+   return val->elems[0];
+}
+
+static struct vtn_ssa_value *
+matrix_multiply(struct vtn_builder *b,
+                struct vtn_ssa_value *_src0, struct vtn_ssa_value *_src1)
+{
+
+   struct vtn_ssa_value *src0 = wrap_matrix(b, _src0);
+   struct vtn_ssa_value *src1 = wrap_matrix(b, _src1);
+   struct vtn_ssa_value *src0_transpose = wrap_matrix(b, _src0->transposed);
+   struct vtn_ssa_value *src1_transpose = wrap_matrix(b, _src1->transposed);
+
+   unsigned src0_rows = glsl_get_vector_elements(src0->type);
+   unsigned src0_columns = glsl_get_matrix_columns(src0->type);
+   unsigned src1_columns = glsl_get_matrix_columns(src1->type);
+
+   const struct glsl_type *dest_type;
+   if (src1_columns > 1) {
+      dest_type = glsl_matrix_type(glsl_get_base_type(src0->type),
+                                   src0_rows, src1_columns);
+   } else {
+      dest_type = glsl_vector_type(glsl_get_base_type(src0->type), src0_rows);
+   }
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, dest_type);
+
+   dest = wrap_matrix(b, dest);
+
+   bool transpose_result = false;
+   if (src0_transpose && src1_transpose) {
+      /* transpose(A) * transpose(B) = transpose(B * A) */
+      src1 = src0_transpose;
+      src0 = src1_transpose;
+      src0_transpose = NULL;
+      src1_transpose = NULL;
+      transpose_result = true;
+   }
+
+   if (src0_transpose && !src1_transpose &&
+       glsl_get_base_type(src0->type) == GLSL_TYPE_FLOAT) {
+      /* We already have the rows of src0 and the columns of src1 available,
+       * so we can just take the dot product of each row with each column to
+       * get the result.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         nir_ssa_def *vec_src[4];
+         for (unsigned j = 0; j < src0_rows; j++) {
+            vec_src[j] = nir_fdot(&b->nb, src0_transpose->elems[j]->def,
+                                  src1->elems[i]->def);
+         }
+         dest->elems[i]->def = nir_vec(&b->nb, vec_src, src0_rows);
+      }
+   } else {
+      /* We don't handle the case where src1 is transposed but not src0, since
+       * the general case only uses individual components of src1 so the
+       * optimizer should chew through the transpose we emitted for src1.
+       */
+
+      for (unsigned i = 0; i < src1_columns; i++) {
+         /* dest[i] = sum(src0[j] * src1[i][j] for all j) */
+         dest->elems[i]->def =
+            nir_fmul(&b->nb, src0->elems[0]->def,
+                     nir_channel(&b->nb, src1->elems[i]->def, 0));
+         for (unsigned j = 1; j < src0_columns; j++) {
+            dest->elems[i]->def =
+               nir_fadd(&b->nb, dest->elems[i]->def,
+                        nir_fmul(&b->nb, src0->elems[j]->def,
+                                 nir_channel(&b->nb, src1->elems[i]->def, j)));
+         }
+      }
+   }
+
+   dest = unwrap_matrix(dest);
+
+   if (transpose_result)
+      dest = vtn_ssa_transpose(b, dest);
+
+   return dest;
+}
+
+static struct vtn_ssa_value *
+mat_times_scalar(struct vtn_builder *b,
+                 struct vtn_ssa_value *mat,
+                 nir_ssa_def *scalar)
+{
+   struct vtn_ssa_value *dest = vtn_create_ssa_value(b, mat->type);
+   for (unsigned i = 0; i < glsl_get_matrix_columns(mat->type); i++) {
+      if (glsl_get_base_type(mat->type) == GLSL_TYPE_FLOAT)
+         dest->elems[i]->def = nir_fmul(&b->nb, mat->elems[i]->def, scalar);
+      else
+         dest->elems[i]->def = nir_imul(&b->nb, mat->elems[i]->def, scalar);
+   }
+
+   return dest;
+}
+
+static void
+vtn_handle_matrix_alu(struct vtn_builder *b, SpvOp opcode,
+                      struct vtn_value *dest,
+                      struct vtn_ssa_value *src0, struct vtn_ssa_value *src1)
+{
+   switch (opcode) {
+   case SpvOpTranspose:
+      dest->ssa = vtn_ssa_transpose(b, src0);
+      break;
+
+   case SpvOpOuterProduct:
+      dest->ssa = matrix_multiply(b, src0, vtn_ssa_transpose(b, src1));
+      break;
+
+   case SpvOpMatrixTimesScalar:
+      if (src0->transposed) {
+         dest->ssa = vtn_ssa_transpose(b, mat_times_scalar(b, src0->transposed,
+                                                           src1->def));
+      } else {
+         dest->ssa = mat_times_scalar(b, src0, src1->def);
+      }
+      break;
+
+   case SpvOpVectorTimesMatrix:
+   case SpvOpMatrixTimesVector:
+   case SpvOpMatrixTimesMatrix:
+      if (opcode == SpvOpVectorTimesMatrix) {
+         dest->ssa = matrix_multiply(b, vtn_ssa_transpose(b, src1), src0);
+      } else {
+         dest->ssa = matrix_multiply(b, src0, src1);
+      }
+      break;
+
+   default: unreachable("unknown matrix opcode");
+   }
+}
+
+void
+vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
+               const uint32_t *w, unsigned count)
+{
+   struct vtn_value *val = vtn_push_value(b, w[2], vtn_value_type_ssa);
+   const struct glsl_type *type =
+      vtn_value(b, w[1], vtn_value_type_type)->type->type;
+
+   /* Collect the various SSA sources */
+   const unsigned num_inputs = count - 3;
+   struct vtn_ssa_value *vtn_src[4] = { NULL, };
+   for (unsigned i = 0; i < num_inputs; i++)
+      vtn_src[i] = vtn_ssa_value(b, w[i + 3]);
+
+   if (glsl_type_is_matrix(vtn_src[0]->type) ||
+       (num_inputs >= 2 && glsl_type_is_matrix(vtn_src[1]->type))) {
+      vtn_handle_matrix_alu(b, opcode, val, vtn_src[0], vtn_src[1]);
+      return;
+   }
+
+   val->ssa = vtn_create_ssa_value(b, type);
+   nir_ssa_def *src[4] = { NULL, };
+   for (unsigned i = 0; i < num_inputs; i++) {
+      assert(glsl_type_is_vector_or_scalar(vtn_src[i]->type));
+      src[i] = vtn_src[i]->def;
+   }
+
+   /* Indicates that the first two arguments should be swapped. This is
+    * used for implementing greater-than and less-than-or-equal.
+    */
+   bool swap = false;
+
+   nir_op op;
+   switch (opcode) {
+   /* Basic ALU operations */
+   case SpvOpSNegate: op = nir_op_ineg; break;
+   case SpvOpFNegate: op = nir_op_fneg; break;
+   case SpvOpNot: op = nir_op_inot; break;
+
+   case SpvOpAny:
+      if (src[0]->num_components == 1) {
+         op = nir_op_imov;
+      } else {
+         switch (src[0]->num_components) {
+         case 2: op = nir_op_bany_inequal2; break;
+         case 3: op = nir_op_bany_inequal3; break;
+         case 4: op = nir_op_bany_inequal4; break;
+         }
+         src[1] = nir_imm_int(&b->nb, NIR_FALSE);
+      }
+      break;
+
+   case SpvOpAll:
+      if (src[0]->num_components == 1) {
+         op = nir_op_imov;
+      } else {
+         switch (src[0]->num_components) {
+         case 2: op = nir_op_ball_iequal2; break;
+         case 3: op = nir_op_ball_iequal3; break;
+         case 4: op = nir_op_ball_iequal4; break;
+         }
+         src[1] = nir_imm_int(&b->nb, NIR_TRUE);
+      }
+      break;
+
+   case SpvOpIAdd: op = nir_op_iadd; break;
+   case SpvOpFAdd: op = nir_op_fadd; break;
+   case SpvOpISub: op = nir_op_isub; break;
+   case SpvOpFSub: op = nir_op_fsub; break;
+   case SpvOpIMul: op = nir_op_imul; break;
+   case SpvOpFMul: op = nir_op_fmul; break;
+   case SpvOpUDiv: op = nir_op_udiv; break;
+   case SpvOpSDiv: op = nir_op_idiv; break;
+   case SpvOpFDiv: op = nir_op_fdiv; break;
+   case SpvOpUMod: op = nir_op_umod; break;
+   case SpvOpSMod: op = nir_op_umod; break; /* FIXME? */
+   case SpvOpFMod: op = nir_op_fmod; break;
+
+   case SpvOpDot:
+      assert(src[0]->num_components == src[1]->num_components);
+      switch (src[0]->num_components) {
+      case 1: op = nir_op_fmul; break;
+      case 2: op = nir_op_fdot2; break;
+      case 3: op = nir_op_fdot3; break;
+      case 4: op = nir_op_fdot4; break;
+      }
+      break;
+
+   case SpvOpIAddCarry:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_iadd(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def =
+         nir_b2i(&b->nb, nir_uadd_carry(&b->nb, src[0], src[1]));
+      return;
+
+   case SpvOpISubBorrow:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_isub(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def =
+         nir_b2i(&b->nb, nir_usub_borrow(&b->nb, src[0], src[1]));
+      return;
+
+   case SpvOpUMulExtended:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def = nir_umul_high(&b->nb, src[0], src[1]);
+      return;
+
+   case SpvOpSMulExtended:
+      assert(glsl_type_is_struct(val->ssa->type));
+      val->ssa->elems[0]->def = nir_imul(&b->nb, src[0], src[1]);
+      val->ssa->elems[1]->def = nir_imul_high(&b->nb, src[0], src[1]);
+      return;
+
+   case SpvOpShiftRightLogical: op = nir_op_ushr; break;
+   case SpvOpShiftRightArithmetic: op = nir_op_ishr; break;
+   case SpvOpShiftLeftLogical: op = nir_op_ishl; break;
+   case SpvOpLogicalOr: op = nir_op_ior; break;
+   case SpvOpLogicalEqual: op = nir_op_ieq; break;
+   case SpvOpLogicalNotEqual: op = nir_op_ine; break;
+   case SpvOpLogicalAnd: op = nir_op_iand; break;
+   case SpvOpLogicalNot: op = nir_op_inot; break;
+   case SpvOpBitwiseOr: op = nir_op_ior; break;
+   case SpvOpBitwiseXor: op = nir_op_ixor; break;
+   case SpvOpBitwiseAnd: op = nir_op_iand; break;
+   case SpvOpSelect: op = nir_op_bcsel; break;
+   case SpvOpIEqual: op = nir_op_ieq; break;
+
+   case SpvOpBitFieldInsert: op = nir_op_bitfield_insert; break;
+   case SpvOpBitFieldSExtract: op = nir_op_ibitfield_extract; break;
+   case SpvOpBitFieldUExtract: op = nir_op_ubitfield_extract; break;
+   case SpvOpBitReverse: op = nir_op_bitfield_reverse; break;
+   case SpvOpBitCount: op = nir_op_bit_count; break;
+
+   /* Comparisons: (TODO: How do we want to handle ordered/unordered?) */
+   case SpvOpFOrdEqual: op = nir_op_feq; break;
+   case SpvOpFUnordEqual: op = nir_op_feq; break;
+   case SpvOpINotEqual: op = nir_op_ine; break;
+   case SpvOpFOrdNotEqual: op = nir_op_fne; break;
+   case SpvOpFUnordNotEqual: op = nir_op_fne; break;
+   case SpvOpULessThan: op = nir_op_ult; break;
+   case SpvOpSLessThan: op = nir_op_ilt; break;
+   case SpvOpFOrdLessThan: op = nir_op_flt; break;
+   case SpvOpFUnordLessThan: op = nir_op_flt; break;
+   case SpvOpUGreaterThan: op = nir_op_ult; swap = true; break;
+   case SpvOpSGreaterThan: op = nir_op_ilt; swap = true; break;
+   case SpvOpFOrdGreaterThan: op = nir_op_flt; swap = true; break;
+   case SpvOpFUnordGreaterThan: op = nir_op_flt; swap = true; break;
+   case SpvOpULessThanEqual: op = nir_op_uge; swap = true; break;
+   case SpvOpSLessThanEqual: op = nir_op_ige; swap = true; break;
+   case SpvOpFOrdLessThanEqual: op = nir_op_fge; swap = true; break;
+   case SpvOpFUnordLessThanEqual: op = nir_op_fge; swap = true; break;
+   case SpvOpUGreaterThanEqual: op = nir_op_uge; break;
+   case SpvOpSGreaterThanEqual: op = nir_op_ige; break;
+   case SpvOpFOrdGreaterThanEqual: op = nir_op_fge; break;
+   case SpvOpFUnordGreaterThanEqual:op = nir_op_fge; break;
+
+   /* Conversions: */
+   case SpvOpConvertFToU: op = nir_op_f2u; break;
+   case SpvOpConvertFToS: op = nir_op_f2i; break;
+   case SpvOpConvertSToF: op = nir_op_i2f; break;
+   case SpvOpConvertUToF: op = nir_op_u2f; break;
+   case SpvOpBitcast: op = nir_op_imov; break;
+   case SpvOpUConvert:
+   case SpvOpSConvert:
+      op = nir_op_imov; /* TODO: NIR is 32-bit only; these are no-ops. */
+      break;
+   case SpvOpFConvert:
+      op = nir_op_fmov;
+      break;
+
+   /* Derivatives: */
+   case SpvOpDPdx: op = nir_op_fddx; break;
+   case SpvOpDPdy: op = nir_op_fddy; break;
+   case SpvOpDPdxFine: op = nir_op_fddx_fine; break;
+   case SpvOpDPdyFine: op = nir_op_fddy_fine; break;
+   case SpvOpDPdxCoarse: op = nir_op_fddx_coarse; break;
+   case SpvOpDPdyCoarse: op = nir_op_fddy_coarse; break;
+   case SpvOpFwidth:
+      val->ssa->def = nir_fadd(&b->nb,
+                               nir_fabs(&b->nb, nir_fddx(&b->nb, src[0])),
+                               nir_fabs(&b->nb, nir_fddy(&b->nb, src[0])));
+      return;
+   case SpvOpFwidthFine:
+      val->ssa->def = nir_fadd(&b->nb,
+                               nir_fabs(&b->nb, nir_fddx_fine(&b->nb, src[0])),
+                               nir_fabs(&b->nb, nir_fddy_fine(&b->nb, src[0])));
+      return;
+   case SpvOpFwidthCoarse:
+      val->ssa->def = nir_fadd(&b->nb,
+                               nir_fabs(&b->nb, nir_fddx_coarse(&b->nb, src[0])),
+                               nir_fabs(&b->nb, nir_fddy_coarse(&b->nb, src[0])));
+      return;
+
+   case SpvOpVectorTimesScalar:
+      /* The builder will take care of splatting for us. */
+      val->ssa->def = nir_fmul(&b->nb, src[0], src[1]);
+      return;
+
+   case SpvOpSRem:
+   case SpvOpFRem:
+      unreachable("No NIR equivalent");
+
+   case SpvOpIsNan:
+      val->ssa->def = nir_fne(&b->nb, src[0], src[0]);
+      return;
+
+   case SpvOpIsInf:
+      val->ssa->def = nir_feq(&b->nb, nir_fabs(&b->nb, src[0]),
+                              nir_imm_float(&b->nb, INFINITY));
+      return;
+
+   case SpvOpIsFinite:
+   case SpvOpIsNormal:
+   case SpvOpSignBitSet:
+   case SpvOpLessOrGreater:
+   case SpvOpOrdered:
+   case SpvOpUnordered:
+   default:
+      unreachable("Unhandled opcode");
+   }
+
+   if (swap) {
+      nir_ssa_def *tmp = src[0];
+      src[0] = src[1];
+      src[1] = tmp;
+   }
+
+   val->ssa->def = nir_build_alu(&b->nb, op, src[0], src[1], src[2], src[3]);
+}
diff --git a/src/glsl/nir/spirv/vtn_private.h b/src/glsl/nir/spirv/vtn_private.h
index 0fa7dd4b041..14355c901f0 100644
--- a/src/glsl/nir/spirv/vtn_private.h
+++ b/src/glsl/nir/spirv/vtn_private.h
@@ -363,6 +363,12 @@ vtn_value(struct vtn_builder *b, uint32_t value_id,
 
 struct vtn_ssa_value *vtn_ssa_value(struct vtn_builder *b, uint32_t value_id);
 
+struct vtn_ssa_value *vtn_create_ssa_value(struct vtn_builder *b,
+                                           const struct glsl_type *type);
+
+struct vtn_ssa_value *vtn_ssa_transpose(struct vtn_builder *b,
+                                        struct vtn_ssa_value *src);
+
 void vtn_variable_store(struct vtn_builder *b, struct vtn_ssa_value *src,
                         nir_deref_var *dest, struct vtn_type *dest_type);
 
@@ -384,5 +390,8 @@ typedef void (*vtn_execution_mode_foreach_cb)(struct vtn_builder *,
 void vtn_foreach_execution_mode(struct vtn_builder *b, struct vtn_value *value,
                                 vtn_execution_mode_foreach_cb cb, void *data);
 
+void vtn_handle_alu(struct vtn_builder *b, SpvOp opcode,
+                    const uint32_t *w, unsigned count);
+
 bool vtn_handle_glsl450_instruction(struct vtn_builder *b, uint32_t ext_opcode,
                                     const uint32_t *words, unsigned count);
-- 
2.30.2
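
A few illustrative notes for readers of the patch follow; they are not part of the change itself.

First, the swap flag in vtn_handle_alu: the comparison table works because NIR only needs "less-than" and "greater-or-equal" style opcodes. Since a > b is exactly b < a, and a <= b is exactly b >= a, greater-than and less-than-or-equal are emitted as the mirrored opcode with the first two sources exchanged. A minimal standalone C sketch of that mapping (the enum and helper names below are invented for illustration, not taken from NIR or SPIR-V headers):

#include <assert.h>
#include <stdbool.h>
#include <stdio.h>

/* Hypothetical stand-ins for the SPIR-V and NIR opcodes used in the patch. */
enum cmp_op { CMP_ULT, CMP_UGT, CMP_ULE, CMP_UGE };
enum nir_like_op { NIR_LIKE_ULT, NIR_LIKE_UGE };

/* Mirror of the switch in vtn_handle_alu: greater-than and
 * less-than-or-equal become the opposite opcode plus a source swap. */
static enum nir_like_op
lower_cmp(enum cmp_op op, bool *swap)
{
   *swap = false;
   switch (op) {
   case CMP_ULT:               return NIR_LIKE_ULT;
   case CMP_UGT: *swap = true; return NIR_LIKE_ULT; /* a > b  ==  b < a  */
   case CMP_UGE:               return NIR_LIKE_UGE;
   case CMP_ULE: *swap = true; return NIR_LIKE_UGE; /* a <= b ==  b >= a */
   }
   return NIR_LIKE_ULT;
}

static bool
eval(enum nir_like_op op, unsigned a, unsigned b)
{
   return op == NIR_LIKE_ULT ? a < b : a >= b;
}

int main(void)
{
   for (unsigned a = 0; a < 4; a++) {
      for (unsigned b = 0; b < 4; b++) {
         bool swap;
         enum nir_like_op op = lower_cmp(CMP_UGT, &swap);
         unsigned s0 = swap ? b : a, s1 = swap ? a : b;
         assert(eval(op, s0, s1) == (a > b));

         op = lower_cmp(CMP_ULE, &swap);
         s0 = swap ? b : a;
         s1 = swap ? a : b;
         assert(eval(op, s0, s1) == (a <= b));
      }
   }
   printf("swap-based comparison lowering checks out\n");
   return 0;
}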
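
Second, matrix_multiply leans on two facts worth spelling out: a column vector can be treated as a one-column matrix (that is what wrap_matrix/unwrap_matrix do, so a single routine covers mat*mat, mat*vec, and, after a transpose, vec*mat), and transpose(A) * transpose(B) = transpose(B * A), which is what lets the function reuse already-transposed sources instead of emitting fresh transposes. A standalone column-major sketch that checks the identity numerically (illustrative only; none of these helpers come from the patch or from Mesa):

#include <assert.h>
#include <math.h>
#include <stdio.h>

/* Column-major storage, as in SPIR-V/GLSL: m[col * rows + row]. */
static void
mat_mul(const float *a, int a_rows, int a_cols,
        const float *b, int b_cols, float *out)
{
   /* out is a_rows x b_cols; b is a_cols x b_cols. */
   for (int c = 0; c < b_cols; c++) {
      for (int r = 0; r < a_rows; r++) {
         float sum = 0.0f;
         for (int k = 0; k < a_cols; k++)
            sum += a[k * a_rows + r] * b[c * a_cols + k];
         out[c * a_rows + r] = sum;
      }
   }
}

static void
mat_transpose(const float *m, int rows, int cols, float *out)
{
   for (int c = 0; c < cols; c++)
      for (int r = 0; r < rows; r++)
         out[r * cols + c] = m[c * rows + r];
}

int main(void)
{
   /* A is 2x3 and B is 3x2, so A*B and transpose(B^T * A^T) are both 2x2. */
   const float a[6] = { 1, 2,  3, 4,  5, 6 };   /* 2x3, column-major */
   const float b[6] = { 7, 8, 9,  10, 11, 12 }; /* 3x2, column-major */

   float ab[4], at[6], bt[6], btat[4], check[4];
   mat_mul(a, 2, 3, b, 2, ab);        /* A * B             (2x2) */
   mat_transpose(a, 2, 3, at);        /* A^T               (3x2) */
   mat_transpose(b, 3, 2, bt);        /* B^T               (2x3) */
   mat_mul(bt, 2, 3, at, 2, btat);    /* B^T * A^T         (2x2) */
   mat_transpose(btat, 2, 2, check);  /* (B^T * A^T)^T     (2x2) */

   for (int i = 0; i < 4; i++)
      assert(fabsf(ab[i] - check[i]) < 1e-6f);

   printf("transpose(B^T * A^T) == A * B, as matrix_multiply assumes\n");
   return 0;
}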
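
Finally, the OpAny/OpAll lowering: with booleans held as 32-bit 0/~0 integers at this stage, any(v) is emitted as a bany_inequal* comparison of v against an all-false constant and all(v) as a ball_iequal* comparison against an all-true constant, which is why src[1] is loaded with NIR_FALSE or NIR_TRUE. A plain-C sketch of the same reduction (names are made up; NIR_TRUE/NIR_FALSE are modeled here as ~0u/0u):

#include <assert.h>
#include <stdbool.h>
#include <stdio.h>

#define NIR_LIKE_TRUE  (~0u)
#define NIR_LIKE_FALSE (0u)

/* any(v): true if v differs from an all-false vector in any component. */
static bool
bany_inequal(const unsigned *v, const unsigned *w, int n)
{
   for (int i = 0; i < n; i++)
      if (v[i] != w[i])
         return true;
   return false;
}

/* all(v): true if v equals an all-true vector in every component. */
static bool
ball_iequal(const unsigned *v, const unsigned *w, int n)
{
   for (int i = 0; i < n; i++)
      if (v[i] != w[i])
         return false;
   return true;
}

int main(void)
{
   const unsigned some[4] = { NIR_LIKE_FALSE, NIR_LIKE_TRUE,
                              NIR_LIKE_FALSE, NIR_LIKE_FALSE };
   const unsigned none[4] = { NIR_LIKE_FALSE, NIR_LIKE_FALSE,
                              NIR_LIKE_FALSE, NIR_LIKE_FALSE };
   const unsigned full[4] = { NIR_LIKE_TRUE, NIR_LIKE_TRUE,
                              NIR_LIKE_TRUE, NIR_LIKE_TRUE };

   assert(bany_inequal(some, none, 4));   /* any(some)  */
   assert(!bany_inequal(none, none, 4));  /* !any(none) */
   assert(ball_iequal(full, full, 4));    /* all(full)  */
   assert(!ball_iequal(some, full, 4));   /* !all(some) */

   printf("any/all lowering via compare-with-constant checks out\n");
   return 0;
}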