X-Git-Url: https://git.libre-soc.org/?a=blobdiff_plain;f=src%2Fcompiler%2Fspirv%2Fvtn_glsl450.c;h=753e74cf73cb93ac660eecd81e6a791a2df0296d;hb=272f9cfe6a19212354c89dc443959473ac5d398e;hp=b54aeb9b21746ec7da0190a3cee44f11c2b10caa;hpb=dca6cd9ce65100896976e913bf72c2c68ea4e1a7;p=mesa.git

diff --git a/src/compiler/spirv/vtn_glsl450.c b/src/compiler/spirv/vtn_glsl450.c
index b54aeb9b217..753e74cf73c 100644
--- a/src/compiler/spirv/vtn_glsl450.c
+++ b/src/compiler/spirv/vtn_glsl450.c
@@ -40,7 +40,7 @@ static nir_ssa_def *
 build_mat2_det(nir_builder *b, nir_ssa_def *col[2])
 {
    unsigned swiz[2] = {1, 0 };
-   nir_ssa_def *p = nir_fmul(b, col[0], nir_swizzle(b, col[1], swiz, 2, true));
+   nir_ssa_def *p = nir_fmul(b, col[0], nir_swizzle(b, col[1], swiz, 2));
 
    return nir_fsub(b, nir_channel(b, p, 0), nir_channel(b, p, 1));
 }
@@ -52,12 +52,12 @@ build_mat3_det(nir_builder *b, nir_ssa_def *col[3])
 
    nir_ssa_def *prod0 =
       nir_fmul(b, col[0],
-               nir_fmul(b, nir_swizzle(b, col[1], yzx, 3, true),
-                        nir_swizzle(b, col[2], zxy, 3, true)));
+               nir_fmul(b, nir_swizzle(b, col[1], yzx, 3),
+                        nir_swizzle(b, col[2], zxy, 3)));
    nir_ssa_def *prod1 =
       nir_fmul(b, col[0],
-               nir_fmul(b, nir_swizzle(b, col[1], zxy, 3, true),
-                        nir_swizzle(b, col[2], yzx, 3, true)));
+               nir_fmul(b, nir_swizzle(b, col[1], zxy, 3),
+                        nir_swizzle(b, col[2], yzx, 3)));
 
    nir_ssa_def *diff = nir_fsub(b, prod0, prod1);
 
@@ -76,9 +76,9 @@ build_mat4_det(nir_builder *b, nir_ssa_def **col)
          swiz[j] = j + (j >= i);
 
       nir_ssa_def *subcol[3];
-      subcol[0] = nir_swizzle(b, col[1], swiz, 3, true);
-      subcol[1] = nir_swizzle(b, col[2], swiz, 3, true);
-      subcol[2] = nir_swizzle(b, col[3], swiz, 3, true);
+      subcol[0] = nir_swizzle(b, col[1], swiz, 3);
+      subcol[1] = nir_swizzle(b, col[2], swiz, 3);
+      subcol[2] = nir_swizzle(b, col[3], swiz, 3);
 
       subdet[i] = build_mat3_det(b, subcol);
    }
@@ -130,7 +130,7 @@ build_mat_subdet(struct nir_builder *b, struct vtn_ssa_value *src,
       for (unsigned j = 0; j < size; j++) {
         if (j != col) {
            subcol[j - (j > col)] = nir_swizzle(b, src->elems[j]->def,
-                                               swiz, size - 1, true);
+                                               swiz, size - 1);
         }
      }
 
@@ -177,7 +177,7 @@ matrix_inverse(struct vtn_builder *b, struct vtn_ssa_value *src)
 static nir_ssa_def *
 build_exp(nir_builder *b, nir_ssa_def *x)
 {
-   return nir_fexp2(b, nir_fmul(b, x, nir_imm_float(b, M_LOG2E)));
+   return nir_fexp2(b, nir_fmul_imm(b, x, M_LOG2E));
 }
 
 /**
@@ -186,7 +186,7 @@ build_exp(nir_builder *b, nir_ssa_def *x)
 static nir_ssa_def *
 build_log(nir_builder *b, nir_ssa_def *x)
 {
-   return nir_fmul(b, nir_flog2(b, x), nir_imm_float(b, 1.0 / M_LOG2E));
+   return nir_fmul_imm(b, nir_flog2(b, x), 1.0 / M_LOG2E);
 }
 
 /**
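The two hunks above only swap nir_fmul + nir_imm_float for nir_fmul_imm; the math is unchanged. NIR has no native e^x or ln(x) opcodes, only fexp2/flog2, so both builders lean on the base-2 identities e^x = 2^(x * log2(e)) and ln(x) = log2(x) / log2(e). A standalone plain-C sketch of those identities, not part of the patch (the *_ref names are illustrative, not Mesa API):

#include <math.h>
#include <stdio.h>

#ifndef M_LOG2E
#define M_LOG2E 1.4426950408889634074
#endif

/* e^x = 2^(x * log2(e)) */
static float build_exp_ref(float x)
{
   return exp2f(x * (float)M_LOG2E);
}

/* ln(x) = log2(x) / log2(e) */
static float build_log_ref(float x)
{
   return log2f(x) * (float)(1.0 / M_LOG2E);
}

int main(void)
{
   printf("%f vs %f\n", build_exp_ref(1.0f), expf(1.0f)); /* both ~2.718282 */
   printf("%f vs %f\n", build_log_ref(2.0f), logf(2.0f)); /* both ~0.693147 */
   return 0;
}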
@@ -202,17 +202,36 @@
 static nir_ssa_def *
 build_asin(nir_builder *b, nir_ssa_def *x, float p0, float p1)
 {
+   if (x->bit_size == 16) {
+      /* The polynomial approximation isn't precise enough to meet half-float
+       * precision requirements. Alternatively, we could implement this using
+       * the formula:
+       *
+       *    asin(x) = atan2(x, sqrt(1 - x*x))
+       *
+       * But that is very expensive, so instead we just do the polynomial
+       * approximation in 32-bit math and then we convert the result back to
+       * 16-bit.
+       */
+      return nir_f2f16(b, build_asin(b, nir_f2f32(b, x), p0, p1));
+   }
+
+   nir_ssa_def *one = nir_imm_floatN_t(b, 1.0f, x->bit_size);
    nir_ssa_def *abs_x = nir_fabs(b, x);
+
+   nir_ssa_def *p0_plus_xp1 = nir_fadd_imm(b, nir_fmul_imm(b, abs_x, p1), p0);
+
+   nir_ssa_def *expr_tail =
+      nir_fadd_imm(b, nir_fmul(b, abs_x,
+                               nir_fadd_imm(b, nir_fmul(b, abs_x,
+                                                        p0_plus_xp1),
+                                            M_PI_4f - 1.0f)),
+                   M_PI_2f);
+
    return nir_fmul(b, nir_fsign(b, x),
-                   nir_fsub(b, nir_imm_float(b, M_PI_2f),
-                            nir_fmul(b, nir_fsqrt(b, nir_fsub(b, nir_imm_float(b, 1.0f), abs_x)),
-                                     nir_fadd(b, nir_imm_float(b, M_PI_2f),
-                                              nir_fmul(b, abs_x,
-                                                       nir_fadd(b, nir_imm_float(b, M_PI_4f - 1.0f),
-                                                                nir_fmul(b, abs_x,
-                                                                         nir_fadd(b, nir_imm_float(b, p0),
-                                                                                  nir_fmul(b, abs_x,
-                                                                                           nir_imm_float(b, p1))))))))));
+                   nir_fsub(b, nir_imm_floatN_t(b, M_PI_2f, x->bit_size),
+                            nir_fmul(b, nir_fsqrt(b, nir_fsub(b, one, abs_x)),
+                                     expr_tail)));
 }
 
 /**
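The rewritten build_asin above evaluates the same polynomial approximation as before, just in Horner form via the nir_*_imm helpers and with a bit-size-aware pi/2 constant. A standalone plain-C sketch of that evaluation, not part of the patch (asin_approx is an illustrative name; it assumes the POSIX M_PI_2/M_PI_4 constants from math.h, and p0/p1 are the tuning constants the callers pass in, e.g. 0.08132463 and -0.02363318 for the Acos lowering later in this diff):

#include <math.h>

/* asin(x) ~= sign(x) * (pi/2 - sqrt(1 - |x|) *
 *            (pi/2 + |x| * (pi/4 - 1 + |x| * (p0 + |x| * p1))))        */
static float asin_approx(float x, float p0, float p1)
{
   const float abs_x = fabsf(x);
   const float p0_plus_xp1 = p0 + abs_x * p1;
   /* Same Horner evaluation the NIR builder emits with fadd_imm/fmul. */
   const float expr_tail =
      (float)M_PI_2 + abs_x * (((float)M_PI_4 - 1.0f) + abs_x * p0_plus_xp1);
   /* Mirror nir_fsign: -1, 0 or +1. */
   const float sign = x < 0.0f ? -1.0f : (x > 0.0f ? 1.0f : 0.0f);
   return sign * ((float)M_PI_2 - sqrtf(1.0f - abs_x) * expr_tail);
}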
@@ -232,8 +251,10 @@ build_fsum(nir_builder *b, nir_ssa_def **xs, int terms)
 static nir_ssa_def *
 build_atan(nir_builder *b, nir_ssa_def *y_over_x)
 {
+   const uint32_t bit_size = y_over_x->bit_size;
+
    nir_ssa_def *abs_y_over_x = nir_fabs(b, y_over_x);
-   nir_ssa_def *one = nir_imm_float(b, 1.0f);
+   nir_ssa_def *one = nir_imm_floatN_t(b, 1.0f, bit_size);
 
    /*
     * range-reduction, first step:
@@ -260,12 +281,12 @@ build_atan(nir_builder *b, nir_ssa_def *y_over_x)
    nir_ssa_def *x_11 = nir_fmul(b, x_9, x_2);
 
    nir_ssa_def *polynomial_terms[] = {
-      nir_fmul(b, x, nir_imm_float(b, 0.9999793128310355f)),
-      nir_fmul(b, x_3, nir_imm_float(b, -0.3326756418091246f)),
-      nir_fmul(b, x_5, nir_imm_float(b, 0.1938924977115610f)),
-      nir_fmul(b, x_7, nir_imm_float(b, -0.1173503194786851f)),
-      nir_fmul(b, x_9, nir_imm_float(b, 0.0536813784310406f)),
-      nir_fmul(b, x_11, nir_imm_float(b, -0.0121323213173444f)),
+      nir_fmul_imm(b, x, 0.9999793128310355f),
+      nir_fmul_imm(b, x_3, -0.3326756418091246f),
+      nir_fmul_imm(b, x_5, 0.1938924977115610f),
+      nir_fmul_imm(b, x_7, -0.1173503194786851f),
+      nir_fmul_imm(b, x_9, 0.0536813784310406f),
+      nir_fmul_imm(b, x_11, -0.0121323213173444f),
    };
 
    nir_ssa_def *tmp =
@@ -273,11 +294,8 @@ build_atan(nir_builder *b, nir_ssa_def *y_over_x)
 
    /* range-reduction fixup */
    tmp = nir_fadd(b, tmp,
-                  nir_fmul(b,
-                           nir_b2f32(b, nir_flt(b, one, abs_y_over_x)),
-                           nir_fadd(b, nir_fmul(b, tmp,
-                                                nir_imm_float(b, -2.0f)),
-                                    nir_imm_float(b, M_PI_2f))));
+                  nir_fmul(b, nir_b2f(b, nir_flt(b, one, abs_y_over_x), bit_size),
+                           nir_fadd_imm(b, nir_fmul_imm(b, tmp, -2.0f), M_PI_2f)));
 
    /* sign fixup */
    return nir_fmul(b, tmp, nir_fsign(b, y_over_x));
@@ -286,8 +304,11 @@ build_atan(nir_builder *b, nir_ssa_def *y_over_x)
 static nir_ssa_def *
 build_atan2(nir_builder *b, nir_ssa_def *y, nir_ssa_def *x)
 {
-   nir_ssa_def *zero = nir_imm_float(b, 0);
-   nir_ssa_def *one = nir_imm_float(b, 1);
+   assert(y->bit_size == x->bit_size);
+   const uint32_t bit_size = x->bit_size;
+
+   nir_ssa_def *zero = nir_imm_floatN_t(b, 0, bit_size);
+   nir_ssa_def *one = nir_imm_floatN_t(b, 1, bit_size);
 
    /* If we're on the left half-plane rotate the coordinates π/2 clock-wise
    * for the y=0 discontinuity to end up aligned with the vertical
@@ -317,9 +338,10 @@ build_atan2(nir_builder *b, nir_ssa_def *y, nir_ssa_def *x)
    * floating point representations with at least the dynamic range of ATI's
    * 24-bit representation.
    */
-   nir_ssa_def *huge = nir_imm_float(b, 1e18f);
+   const double huge_val = bit_size >= 32 ? 1e18 : 16384;
+   nir_ssa_def *huge = nir_imm_floatN_t(b, huge_val, bit_size);
    nir_ssa_def *scale = nir_bcsel(b, nir_fge(b, nir_fabs(b, t), huge),
-                                  nir_imm_float(b, 0.25), one);
+                                  nir_imm_floatN_t(b, 0.25, bit_size), one);
    nir_ssa_def *rcp_scaled_t = nir_frcp(b, nir_fmul(b, t, scale));
    nir_ssa_def *s_over_t = nir_fmul(b, nir_fmul(b, s, scale), rcp_scaled_t);
 
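The huge/scale dance in the hunk above guards the nir_frcp against overflow: when |t| is so large that 1/t could flush to zero, both operands are pre-scaled by 0.25, and the scale cancels out of the quotient. The patch just makes the threshold bit-size dependent, 1e18 for 32-bit and wider, 16384 for fp16. A standalone plain-C sketch of the idea, not part of the patch (safe_quotient is an illustrative name):

#include <math.h>

/* Compute s/t via a reciprocal, pre-scaling both operands by 0.25 when
 * |t| is at least `huge` so the reciprocal stays inside the format's
 * dynamic range; the scale factor cancels in the quotient. */
static float safe_quotient(float s, float t, float huge)
{
   const float scale = fabsf(t) >= huge ? 0.25f : 1.0f;
   const float rcp_scaled_t = 1.0f / (t * scale);
   return (s * scale) * rcp_scaled_t;
}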
@@ -346,9 +368,9 @@ build_atan2(nir_builder *b, nir_ssa_def *y, nir_ssa_def *x)
    /* Calculate the arctangent and fix up the result if we had flipped the
     * coordinate system.
     */
-   nir_ssa_def *arc = nir_fadd(b, nir_fmul(b, nir_b2f32(b, flip),
-                                           nir_imm_float(b, M_PI_2f)),
-                               build_atan(b, tan));
+   nir_ssa_def *arc =
+      nir_fadd(b, nir_fmul_imm(b, nir_b2f(b, flip, bit_size), M_PI_2f),
+               build_atan(b, tan));
 
    /* Rather convoluted calculation of the sign of the result. When x < 0 we
    * cannot use fsign because we need to be able to distinguish between
@@ -363,84 +385,6 @@ build_atan2(nir_builder *b, nir_ssa_def *y, nir_ssa_def *x)
       nir_fneg(b, arc), arc);
 }
 
-static nir_ssa_def *
-build_frexp32(nir_builder *b, nir_ssa_def *x, nir_ssa_def **exponent)
-{
-   nir_ssa_def *abs_x = nir_fabs(b, x);
-   nir_ssa_def *zero = nir_imm_float(b, 0.0f);
-
-   /* Single-precision floating-point values are stored as
-    *  1 sign bit;
-    *  8 exponent bits;
-    * 23 mantissa bits.
-    *
-    * An exponent shift of 23 will shift the mantissa out, leaving only the
-    * exponent and sign bit (which itself may be zero, if the absolute value
-    * was taken before the bitcast and shift.
-    */
-   nir_ssa_def *exponent_shift = nir_imm_int(b, 23);
-   nir_ssa_def *exponent_bias = nir_imm_int(b, -126);
-
-   nir_ssa_def *sign_mantissa_mask = nir_imm_int(b, 0x807fffffu);
-
-   /* Exponent of floating-point values in the range [0.5, 1.0). */
-   nir_ssa_def *exponent_value = nir_imm_int(b, 0x3f000000u);
-
-   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
-
-   *exponent =
-      nir_iadd(b, nir_ushr(b, abs_x, exponent_shift),
-                  nir_bcsel(b, is_not_zero, exponent_bias, zero));
-
-   return nir_ior(b, nir_iand(b, x, sign_mantissa_mask),
-                     nir_bcsel(b, is_not_zero, exponent_value, zero));
-}
-
-static nir_ssa_def *
-build_frexp64(nir_builder *b, nir_ssa_def *x, nir_ssa_def **exponent)
-{
-   nir_ssa_def *abs_x = nir_fabs(b, x);
-   nir_ssa_def *zero = nir_imm_double(b, 0.0);
-   nir_ssa_def *zero32 = nir_imm_float(b, 0.0f);
-
-   /* Double-precision floating-point values are stored as
-    *  1 sign bit;
-    * 11 exponent bits;
-    * 52 mantissa bits.
-    *
-    * We only need to deal with the exponent so first we extract the upper 32
-    * bits using nir_unpack_64_2x32_split_y.
-    */
-   nir_ssa_def *upper_x = nir_unpack_64_2x32_split_y(b, x);
-   nir_ssa_def *abs_upper_x = nir_unpack_64_2x32_split_y(b, abs_x);
-
-   /* An exponent shift of 20 will shift the remaining mantissa bits out,
-    * leaving only the exponent and sign bit (which itself may be zero, if the
-    * absolute value was taken before the bitcast and shift.
-    */
-   nir_ssa_def *exponent_shift = nir_imm_int(b, 20);
-   nir_ssa_def *exponent_bias = nir_imm_int(b, -1022);
-
-   nir_ssa_def *sign_mantissa_mask = nir_imm_int(b, 0x800fffffu);
-
-   /* Exponent of floating-point values in the range [0.5, 1.0). */
-   nir_ssa_def *exponent_value = nir_imm_int(b, 0x3fe00000u);
-
-   nir_ssa_def *is_not_zero = nir_fne(b, abs_x, zero);
-
-   *exponent =
-      nir_iadd(b, nir_ushr(b, abs_upper_x, exponent_shift),
-                  nir_bcsel(b, is_not_zero, exponent_bias, zero32));
-
-   nir_ssa_def *new_upper =
-      nir_ior(b, nir_iand(b, upper_x, sign_mantissa_mask),
-                 nir_bcsel(b, is_not_zero, exponent_value, zero32));
-
-   nir_ssa_def *lower_x = nir_unpack_64_2x32_split_x(b, x);
-
-   return nir_pack_64_2x32_split(b, lower_x, new_upper);
-}
-
 static nir_op
 vtn_nir_alu_op_for_spirv_glsl_opcode(struct vtn_builder *b,
                                      enum GLSLstd450 opcode)
@@ -545,7 +489,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
    case GLSLstd450ModfStruct: {
       nir_ssa_def *sign = nir_fsign(nb, src[0]);
       nir_ssa_def *abs = nir_fabs(nb, src[0]);
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
+      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
       val->ssa->elems[0]->def = nir_fmul(nb, sign, nir_ffract(nb, abs));
       val->ssa->elems[1]->def = nir_fmul(nb, sign, nir_ffloor(nb, abs));
       return;
@@ -585,7 +529,7 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
       return;
 
    case GLSLstd450Cross: {
-      val->ssa->def = nir_cross(nb, src[0], src[1]);
+      val->ssa->def = nir_cross3(nb, src[0], src[1]);
       return;
    }
 
@@ -646,17 +590,17 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
    case GLSLstd450Sinh:
       /* 0.5 * (e^x - e^(-x)) */
       val->ssa->def =
-         nir_fmul(nb, nir_imm_float(nb, 0.5f),
-                  nir_fsub(nb, build_exp(nb, src[0]),
-                           build_exp(nb, nir_fneg(nb, src[0]))));
+         nir_fmul_imm(nb, nir_fsub(nb, build_exp(nb, src[0]),
+                                   build_exp(nb, nir_fneg(nb, src[0]))),
+                      0.5f);
       return;
 
    case GLSLstd450Cosh:
       /* 0.5 * (e^x + e^(-x)) */
       val->ssa->def =
-         nir_fmul(nb, nir_imm_float(nb, 0.5f),
-                  nir_fadd(nb, build_exp(nb, src[0]),
-                           build_exp(nb, nir_fneg(nb, src[0]))));
+         nir_fmul_imm(nb, nir_fadd(nb, build_exp(nb, src[0]),
+                                   build_exp(nb, nir_fneg(nb, src[0]))),
+                      0.5f);
       return;
 
    case GLSLstd450Tanh: {
@@ -667,30 +611,38 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
        * We clamp x to (-inf, +10] to avoid precision problems. When x > 10,
        * e^2x is so much larger than 1.0 that 1.0 gets flushed to zero in the
        * computation e^2x +/- 1 so it can be ignored.
+       *
+       * For 16-bit precision we clamp x to (-inf, +4.2] since the maximum
+       * representable number is only 65,504 and e^(2*6) exceeds that. Also,
+       * if x > 4.2, tanh(x) will return 1.0 in fp16.
        */
-      nir_ssa_def *x = nir_fmin(nb, src[0], nir_imm_float(nb, 10));
-      nir_ssa_def *exp2x = build_exp(nb, nir_fmul(nb, x, nir_imm_float(nb, 2)));
-      val->ssa->def = nir_fdiv(nb, nir_fsub(nb, exp2x, nir_imm_float(nb, 1)),
-                               nir_fadd(nb, exp2x, nir_imm_float(nb, 1)));
+      const uint32_t bit_size = src[0]->bit_size;
+      const double clamped_x = bit_size > 16 ? 10.0 : 4.2;
+      nir_ssa_def *x = nir_fmin(nb, src[0],
+                                nir_imm_floatN_t(nb, clamped_x, bit_size));
+      nir_ssa_def *exp2x = build_exp(nb, nir_fmul_imm(nb, x, 2.0));
+      val->ssa->def = nir_fdiv(nb, nir_fadd_imm(nb, exp2x, -1.0),
+                               nir_fadd_imm(nb, exp2x, 1.0));
       return;
    }
 
    case GLSLstd450Asinh:
      val->ssa->def = nir_fmul(nb, nir_fsign(nb, src[0]),
        build_log(nb, nir_fadd(nb, nir_fabs(nb, src[0]),
-                 nir_fsqrt(nb, nir_fadd(nb, nir_fmul(nb, src[0], src[0]),
-                                        nir_imm_float(nb, 1.0f))))));
+                 nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
+                                            1.0f))))));
      return;
    case GLSLstd450Acosh:
      val->ssa->def = build_log(nb, nir_fadd(nb, src[0],
-       nir_fsqrt(nb, nir_fsub(nb, nir_fmul(nb, src[0], src[0]),
-                              nir_imm_float(nb, 1.0f)))));
+       nir_fsqrt(nb, nir_fadd_imm(nb, nir_fmul(nb, src[0], src[0]),
+                                  -1.0f))));
      return;
 
    case GLSLstd450Atanh: {
-      nir_ssa_def *one = nir_imm_float(nb, 1.0);
-      val->ssa->def = nir_fmul(nb, nir_imm_float(nb, 0.5f),
-                               build_log(nb, nir_fdiv(nb, nir_fadd(nb, one, src[0]),
-                                                      nir_fsub(nb, one, src[0]))));
+      nir_ssa_def *one = nir_imm_floatN_t(nb, 1.0, src[0]->bit_size);
+      val->ssa->def =
+         nir_fmul_imm(nb, build_log(nb, nir_fdiv(nb, nir_fadd(nb, src[0], one),
+                                                 nir_fsub(nb, one, src[0]))),
+                      0.5f);
      return;
    }
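The Tanh lowering above computes tanh(x) = (e^(2x) - 1)/(e^(2x) + 1), with the clamp now chosen per bit size: 10.0 for 32-bit floats and 4.2 for fp16, where e^(2x) would quickly overrun the half-float maximum of 65,504. A standalone plain-C sketch of the same construction, not part of the patch (tanh_via_exp is an illustrative name):

#include <math.h>

/* tanh(x) = (e^(2x) - 1) / (e^(2x) + 1), with x clamped so that e^(2x)
 * stays representable: (-inf, 10] for 32-bit, (-inf, 4.2] for fp16. */
static float tanh_via_exp(float x, unsigned bit_size)
{
   const float clamped_x = bit_size > 16 ? 10.0f : 4.2f;
   const float cx = fminf(x, clamped_x);
   const float exp2x = expf(2.0f * cx);
   return (exp2x - 1.0f) / (exp2x + 1.0f);
}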
@@ -699,8 +651,9 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
       return;
 
    case GLSLstd450Acos:
-      val->ssa->def = nir_fsub(nb, nir_imm_float(nb, M_PI_2f),
-                               build_asin(nb, src[0], 0.08132463, -0.02363318));
+      val->ssa->def =
+         nir_fsub(nb, nir_imm_floatN_t(nb, M_PI_2f, src[0]->bit_size),
+                  build_asin(nb, src[0], 0.08132463, -0.02363318));
       return;
 
    case GLSLstd450Atan:
@@ -712,23 +665,16 @@ handle_glsl450_alu(struct vtn_builder *b, enum GLSLstd450 entrypoint,
       return;
 
    case GLSLstd450Frexp: {
-      nir_ssa_def *exponent;
-      if (src[0]->bit_size == 64)
-         val->ssa->def = build_frexp64(nb, src[0], &exponent);
-      else
-         val->ssa->def = build_frexp32(nb, src[0], &exponent);
+      nir_ssa_def *exponent = nir_frexp_exp(nb, src[0]);
+      val->ssa->def = nir_frexp_sig(nb, src[0]);
       nir_store_deref(nb, vtn_nir_deref(b, w[6]), exponent, 0xf);
       return;
    }
 
    case GLSLstd450FrexpStruct: {
-      vtn_assert(glsl_type_is_struct(val->ssa->type));
-      if (src[0]->bit_size == 64)
-         val->ssa->elems[0]->def = build_frexp64(nb, src[0],
-                                                 &val->ssa->elems[1]->def);
-      else
-         val->ssa->elems[0]->def = build_frexp32(nb, src[0],
-                                                 &val->ssa->elems[1]->def);
+      vtn_assert(glsl_type_is_struct_or_ifc(val->ssa->type));
+      val->ssa->elems[0]->def = nir_frexp_sig(nb, src[0]);
+      val->ssa->elems[1]->def = nir_frexp_exp(nb, src[0]);
      return;
    }
 
@@ -807,10 +753,9 @@ handle_glsl450_interpolation(struct vtn_builder *b, enum GLSLstd450 opcode,
 
    if (vec_array_deref) {
       assert(vec_deref);
-      nir_const_value *const_index = nir_src_as_const_value(vec_deref->arr.index);
-      if (const_index) {
+      if (nir_src_is_const(vec_deref->arr.index)) {
         val->ssa->def = vtn_vector_extract(b, &intrin->dest.ssa,
-                                           const_index->u32[0]);
+                                           nir_src_as_uint(vec_deref->arr.index));
      } else {
        val->ssa->def = vtn_vector_extract_dynamic(b, &intrin->dest.ssa,
                                                   vec_deref->arr.index.ssa);
@@ -842,7 +787,7 @@ vtn_handle_glsl450_instruction(struct vtn_builder *b, SpvOp ext_opcode,
    case GLSLstd450InterpolateAtCentroid:
    case GLSLstd450InterpolateAtSample:
    case GLSLstd450InterpolateAtOffset:
-      handle_glsl450_interpolation(b, ext_opcode, w, count);
+      handle_glsl450_interpolation(b, (enum GLSLstd450)ext_opcode, w, count);
       break;
 
    default:
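The removed build_frexp32/build_frexp64 helpers open-coded the IEEE bit manipulation that the nir_frexp_sig/nir_frexp_exp builders used in the Frexp hunks above now provide: both split x into a significand in [0.5, 1.0) and an integer exponent such that x == sig * 2^exp. A standalone plain-C reminder of that contract, not part of the patch, using the libm frexp:

#include <math.h>
#include <stdio.h>

int main(void)
{
   int exp;
   const double sig = frexp(48.0, &exp);
   /* Prints "48.0 = 0.750000 * 2^6": sig lies in [0.5, 1.0). */
   printf("48.0 = %f * 2^%d\n", sig, exp);
   return 0;
}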