* @param dady          shader input dady
  * @param color         color buffer
  * @param depth         depth buffer
- * @param mask          mask of visible pixels in block
+ * @param mask          mask of visible pixels in block (16-bits per sample)
  * @param thread_data   task thread data
  * @param stride        color buffer row stride in bytes
  * @param depth_stride  depth buffer row stride in bytes
                     const void *dady,
                     uint8_t **color,
                     uint8_t *depth,
-                    uint32_t mask,
+                    uint64_t mask,
                     struct lp_jit_thread_data *thread_data,
                     unsigned *stride,
                     unsigned depth_stride,
 
 generate_quad_mask(struct gallivm_state *gallivm,
                    struct lp_type fs_type,
                    unsigned first_quad,
-                   LLVMValueRef mask_input) /* int32 */
+                   unsigned sample,
+                   LLVMValueRef mask_input) /* int64 */
 {
    LLVMBuilderRef builder = gallivm->builder;
    struct lp_type mask_type;
       shift = 0;
    }
 
+   mask_input = LLVMBuildLShr(builder, mask_input, lp_build_const_int64(gallivm, 16 * sample), "");
+   mask_input = LLVMBuildTrunc(builder, mask_input,
+                               i32t, "");
+   mask_input = LLVMBuildAnd(builder, mask_input, lp_build_const_int32(gallivm, 0xffff), "");
+
    mask_input = LLVMBuildLShr(builder,
                               mask_input,
                               LLVMConstInt(i32t, shift, 0),
    arg_types[6] = LLVMPointerType(fs_elem_type, 0);    /* dady */
    arg_types[7] = LLVMPointerType(LLVMPointerType(blend_vec_type, 0), 0);  /* color */
    arg_types[8] = LLVMPointerType(int8_type, 0);       /* depth */
-   arg_types[9] = int32_type;                          /* mask_input */
+   arg_types[9] = LLVMInt64TypeInContext(gallivm->context);  /* mask_input */
    arg_types[10] = variant->jit_thread_data_ptr_type;  /* per thread data */
    arg_types[11] = LLVMPointerType(int32_type, 0);     /* stride */
    arg_types[12] = int32_type;                         /* depth_stride */
 
          if (partial_mask) {
             mask = generate_quad_mask(gallivm, fs_type,
-                                      i*fs_type.length/4, mask_input);
+                                      i*fs_type.length/4, 0, mask_input);
          }
          else {
             mask = lp_build_const_int_vec(gallivm, fs_type, ~0);