diff --git a/av1/common/blockd.h b/av1/common/blockd.h
index 0e5a59900..9f735bce5 100644
--- a/av1/common/blockd.h
+++ b/av1/common/blockd.h
@@ -41,19 +41,35 @@ extern "C" {
 #if CONFIG_EXT_INTER
 // Should we try rectangular interintra predictions?
 #define USE_RECT_INTERINTRA 1
 
-#if CONFIG_COMPOUND_SEGMENT
-#define MAX_SEG_MASK_BITS 3
+#if CONFIG_COMPOUND_SEGMENT
+
+// Set COMPOUND_SEGMENT_TYPE to one of the following:
+// 0: Uniform
+// 1: Difference weighted
+#define COMPOUND_SEGMENT_TYPE 1
+
+#if COMPOUND_SEGMENT_TYPE == 0
+#define MAX_SEG_MASK_BITS 1
 // SEG_MASK_TYPES should not surpass 1 << MAX_SEG_MASK_BITS
 typedef enum {
   UNIFORM_45 = 0,
   UNIFORM_45_INV,
-  UNIFORM_55,
-  UNIFORM_55_INV,
   SEG_MASK_TYPES,
 } SEG_MASK_TYPE;
+
+#elif COMPOUND_SEGMENT_TYPE == 1
+#define MAX_SEG_MASK_BITS 1
+// SEG_MASK_TYPES should not surpass 1 << MAX_SEG_MASK_BITS
+typedef enum {
+  DIFFWTD_45 = 0,
+  DIFFWTD_45_INV,
+  SEG_MASK_TYPES,
+} SEG_MASK_TYPE;
+
+#endif  // COMPOUND_SEGMENT_TYPE
 #endif  // CONFIG_COMPOUND_SEGMENT
-#endif
+#endif  // CONFIG_EXT_INTER
 
 typedef enum {
   KEY_FRAME = 0,
diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c
index dc22483d5..4bf4aa4ee 100644
--- a/av1/common/reconinter.c
+++ b/av1/common/reconinter.c
@@ -299,8 +299,9 @@ const uint8_t *av1_get_compound_type_mask(
 }
 
 #if CONFIG_COMPOUND_SEGMENT
-void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type, int h,
-                  int w, int mask_val) {
+#if COMPOUND_SEGMENT_TYPE == 0
+static void uniform_mask(uint8_t *mask, int which_inverse, BLOCK_SIZE sb_type,
+                         int h, int w, int mask_val) {
   int i, j;
   int block_stride = block_size_wide[sb_type];
   for (i = 0; i < h; ++i)
@@ -321,11 +322,103 @@ void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
   switch (mask_type) {
     case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
     case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
-    case UNIFORM_55: uniform_mask(mask, 0, sb_type, h, w, 55); break;
-    case UNIFORM_55_INV: uniform_mask(mask, 1, sb_type, h, w, 55); break;
     default: assert(0);
   }
 }
+
+#if CONFIG_AOM_HIGHBITDEPTH
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                                    const uint8_t *src0, int src0_stride,
+                                    const uint8_t *src1, int src1_stride,
+                                    BLOCK_SIZE sb_type, int h, int w, int bd) {
+  (void)src0;
+  (void)src1;
+  (void)src0_stride;
+  (void)src1_stride;
+  (void)bd;
+  switch (mask_type) {
+    case UNIFORM_45: uniform_mask(mask, 0, sb_type, h, w, 45); break;
+    case UNIFORM_45_INV: uniform_mask(mask, 1, sb_type, h, w, 45); break;
+    default: assert(0);
+  }
+}
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+
+#elif COMPOUND_SEGMENT_TYPE == 1
+#define DIFF_FACTOR 16
+static void diffwtd_mask(uint8_t *mask, int which_inverse, int mask_base,
+                         const uint8_t *src0, int src0_stride,
+                         const uint8_t *src1, int src1_stride,
+                         BLOCK_SIZE sb_type, int h, int w) {
+  int i, j, m, diff;
+  int block_stride = block_size_wide[sb_type];
+  for (i = 0; i < h; ++i) {
+    for (j = 0; j < w; ++j) {
+      diff =
+          abs((int)src0[i * src0_stride + j] - (int)src1[i * src1_stride + j]);
+      m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
+      mask[i * block_stride + j] =
+          which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+    }
+  }
+}
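For intuition, here is a minimal standalone sketch of the difference-weighted mapping that diffwtd_mask implements above. It assumes AOM_BLEND_A64_MAX_ALPHA is 64 (as defined in aom_dsp/blend.h); the helper names are illustrative, not part of the patch:

#include <stdio.h>
#include <stdlib.h>

#define MAX_ALPHA 64   /* stands in for AOM_BLEND_A64_MAX_ALPHA */
#define DIFF_FACTOR 16 /* same constant as in the patch */

static int clamp_int(int v, int lo, int hi) {
  return v < lo ? lo : (v > hi ? hi : v);
}

/* Per-pixel weight: similar predictors stay near mask_base (a near-even
 * blend); a large |p0 - p1| pushes the weight toward MAX_ALPHA, i.e.
 * toward one predictor, or toward the other for the inverse type. */
static int diffwtd_weight(int p0, int p1, int mask_base, int invert) {
  const int m =
      clamp_int(mask_base + abs(p0 - p1) / DIFF_FACTOR, 0, MAX_ALPHA);
  return invert ? MAX_ALPHA - m : m;
}

int main(void) {
  /* mask_base 47 matches the DIFFWTD_45 cases in the patch. */
  printf("%d\n", diffwtd_weight(128, 128, 47, 0)); /* 47: near-even blend */
  printf("%d\n", diffwtd_weight(255, 0, 47, 0));   /* 62: mostly pred 0   */
  printf("%d\n", diffwtd_weight(255, 0, 47, 1));   /*  2: mostly pred 1   */
  return 0;
}

So where the two predictors agree the blend stays close to even, and where they disagree strongly the mask commits to one of them.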
+
+void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                             const uint8_t *src0, int src0_stride,
+                             const uint8_t *src1, int src1_stride,
+                             BLOCK_SIZE sb_type, int h, int w) {
+  switch (mask_type) {
+    case DIFFWTD_45:
+      diffwtd_mask(mask, 0, 47, src0, src0_stride, src1, src1_stride, sb_type,
+                   h, w);
+      break;
+    case DIFFWTD_45_INV:
+      diffwtd_mask(mask, 1, 47, src0, src0_stride, src1, src1_stride, sb_type,
+                   h, w);
+      break;
+    default: assert(0);
+  }
+}
+
+#if CONFIG_AOM_HIGHBITDEPTH
+static void diffwtd_mask_highbd(uint8_t *mask, int which_inverse, int mask_base,
+                                const uint16_t *src0, int src0_stride,
+                                const uint16_t *src1, int src1_stride,
+                                BLOCK_SIZE sb_type, int h, int w, int bd) {
+  int i, j, m, diff;
+  int block_stride = block_size_wide[sb_type];
+  for (i = 0; i < h; ++i) {
+    for (j = 0; j < w; ++j) {
+      diff = abs((int)src0[i * src0_stride + j] -
+                 (int)src1[i * src1_stride + j]) >>
+             (bd - 8);
+      m = clamp(mask_base + (diff / DIFF_FACTOR), 0, AOM_BLEND_A64_MAX_ALPHA);
+      mask[i * block_stride + j] =
+          which_inverse ? AOM_BLEND_A64_MAX_ALPHA - m : m;
+    }
+  }
+}
+
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                                    const uint8_t *src0, int src0_stride,
+                                    const uint8_t *src1, int src1_stride,
+                                    BLOCK_SIZE sb_type, int h, int w, int bd) {
+  switch (mask_type) {
+    case DIFFWTD_45:
+      diffwtd_mask_highbd(mask, 0, 47, CONVERT_TO_SHORTPTR(src0), src0_stride,
+                          CONVERT_TO_SHORTPTR(src1), src1_stride, sb_type, h, w,
+                          bd);
+      break;
+    case DIFFWTD_45_INV:
+      diffwtd_mask_highbd(mask, 1, 47, CONVERT_TO_SHORTPTR(src0), src0_stride,
+                          CONVERT_TO_SHORTPTR(src1), src1_stride, sb_type, h, w,
+                          bd);
+      break;
+    default: assert(0);
+  }
+}
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+#endif  // COMPOUND_SEGMENT_TYPE
 #endif  // CONFIG_COMPOUND_SEGMENT
 
 static void init_wedge_master_masks() {
@@ -470,16 +563,16 @@ static void build_masked_compound(
 }
 
 #if CONFIG_AOM_HIGHBITDEPTH
-static void build_masked_compound_wedge_highbd(
+static void build_masked_compound_highbd(
     uint8_t *dst_8, int dst_stride, const uint8_t *src0_8, int src0_stride,
-    const uint8_t *src1_8, int src1_stride, int wedge_index, int wedge_sign,
-    BLOCK_SIZE sb_type, int h, int w, int bd) {
+    const uint8_t *src1_8, int src1_stride,
+    const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h,
+    int w, int bd) {
   // Derive subsampling from h and w passed in. May be refactored to
   // pass in subsampling factors directly.
   const int subh = (2 << b_height_log2_lookup[sb_type]) == h;
   const int subw = (2 << b_width_log2_lookup[sb_type]) == w;
-  const uint8_t *mask =
-      av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type);
+  const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type);
   aom_highbd_blend_a64_mask(dst_8, dst_stride, src0_8, src0_stride, src1_8,
                             src1_stride, mask, block_size_wide[sb_type], h, w,
                             subh, subw, bd);
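The comment in build_masked_compound_highbd notes that subsampling is derived from the h and w passed in. A small sketch of that test, under the assumption that block_size_high[sb_type] == 4 << b_height_log2_lookup[sb_type] (so 2 << b_height_log2_lookup[sb_type] is half the luma block height, and the equality holds exactly for a subsampled plane):

#include <assert.h>
#include <stdio.h>

/* Mirrors `(2 << b_height_log2_lookup[sb_type]) == h`: half_luma is
 * half the luma block dimension, plane_size is the dimension actually
 * being predicted for this plane. */
static int derive_sub(int half_luma, int plane_size) {
  return half_luma == plane_size;
}

int main(void) {
  /* 16x16 luma block: 4:2:0 chroma planes are 8x8. */
  assert(derive_sub(16 / 2, 16) == 0); /* luma plane: not subsampled */
  assert(derive_sub(16 / 2, 8) == 1);  /* chroma plane: subsampled   */
  printf("ok\n");
  return 0;
}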
@@ -534,11 +629,22 @@ void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride,
         comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type,
         wedge_offset_x, wedge_offset_y, h, w);
 #else
+#if CONFIG_COMPOUND_SEGMENT
+  if (!plane && comp_data->type == COMPOUND_SEG) {
+    if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+      build_compound_seg_mask_highbd(comp_data->seg_mask, comp_data->mask_type,
+                                     dst, dst_stride, tmp_dst, MAX_SB_SIZE,
+                                     mi->mbmi.sb_type, h, w, xd->bd);
+    else
+      build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, dst,
+                              dst_stride, tmp_dst, MAX_SB_SIZE,
+                              mi->mbmi.sb_type, h, w);
+  }
+#endif  // CONFIG_COMPOUND_SEGMENT
   if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-    build_masked_compound_wedge_highbd(
-        dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE,
-        comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, h, w,
-        xd->bd);
+    build_masked_compound_highbd(dst, dst_stride, dst, dst_stride, tmp_dst,
+                                 MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w,
+                                 xd->bd);
   else
     build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst,
                           MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w);
@@ -2707,17 +2813,32 @@ static void build_wedge_inter_predictor_from_buf(
   if (is_compound &&
       is_masked_compound_type(mbmi->interinter_compound_data.type)) {
 #if CONFIG_COMPOUND_SEGMENT
+#if CONFIG_AOM_HIGHBITDEPTH
+    if (!plane && comp_data->type == COMPOUND_SEG) {
+      if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
+        build_compound_seg_mask_highbd(
+            comp_data->seg_mask, comp_data->mask_type,
+            CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
+            CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, mbmi->sb_type, h, w,
+            xd->bd);
+      else
+        build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
+                                ext_dst0, ext_dst_stride0, ext_dst1,
+                                ext_dst_stride1, mbmi->sb_type, h, w);
+    }
+#else
     if (!plane && comp_data->type == COMPOUND_SEG)
       build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type,
                               ext_dst0, ext_dst_stride0, ext_dst1,
                               ext_dst_stride1, mbmi->sb_type, h, w);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_COMPOUND_SEGMENT
 #if CONFIG_AOM_HIGHBITDEPTH
     if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH)
-      build_masked_compound_wedge_highbd(
+      build_masked_compound_highbd(
           dst, dst_buf->stride, CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0,
-          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data->wedge_index,
-          comp_data->wedge_sign, mbmi->sb_type, h, w, xd->bd);
+          CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data,
+          mbmi->sb_type, h, w, xd->bd);
     else
 #endif  // CONFIG_AOM_HIGHBITDEPTH
     build_masked_compound(dst, dst_buf->stride, ext_dst0, ext_dst_stride0,
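For reference, the seg mask built above is consumed by aom_blend_a64_mask / aom_highbd_blend_a64_mask via av1_get_compound_type_mask. A scalar sketch of that blend, assuming the per-pixel rule dst = (m * src0 + (64 - m) * src1 + 32) >> 6 with m in [0, 64]:

#include <stdint.h>
#include <stdio.h>

/* One pixel of the A64 blend the compound mask drives. */
static uint8_t blend_a64_pixel(int m, uint8_t src0, uint8_t src1) {
  return (uint8_t)((m * src0 + (64 - m) * src1 + 32) >> 6);
}

int main(void) {
  printf("%u\n", blend_a64_pixel(47, 200, 100)); /* 173: leans to src0 */
  printf("%u\n", blend_a64_pixel(64, 200, 100)); /* 200: src0 only     */
  printf("%u\n", blend_a64_pixel(0, 200, 100));  /* 100: src1 only     */
  return 0;
}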
diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h
index 581a977f1..7bea9edd6 100644
--- a/av1/common/reconinter.h
+++ b/av1/common/reconinter.h
@@ -196,6 +196,12 @@ void build_compound_seg_mask(uint8_t *mask, SEG_MASK_TYPE mask_type,
                              const uint8_t *src0, int src0_stride,
                              const uint8_t *src1, int src1_stride,
                              BLOCK_SIZE sb_type, int h, int w);
+#if CONFIG_AOM_HIGHBITDEPTH
+void build_compound_seg_mask_highbd(uint8_t *mask, SEG_MASK_TYPE mask_type,
+                                    const uint8_t *src0, int src0_stride,
+                                    const uint8_t *src1, int src1_stride,
+                                    BLOCK_SIZE sb_type, int h, int w, int bd);
+#endif  // CONFIG_AOM_HIGHBITDEPTH
 #endif  // CONFIG_COMPOUND_SEGMENT
 #endif  // CONFIG_EXT_INTER
 
diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c
index 304541295..ee6cc7f2f 100644
--- a/av1/encoder/rdopt.c
+++ b/av1/encoder/rdopt.c
@@ -6852,9 +6852,16 @@ static int64_t pick_interinter_seg_mask(const AV1_COMP *const cpi,
 
   // try each mask type and its inverse
   for (cur_mask_type = 0; cur_mask_type < SEG_MASK_TYPES; cur_mask_type++) {
-    // build mask and inverse
-    build_compound_seg_mask(comp_data->seg_mask, cur_mask_type, p0, bw, p1, bw,
-                            bsize, bh, bw);
+// build mask and inverse
+#if CONFIG_AOM_HIGHBITDEPTH
+    if (hbd)
+      build_compound_seg_mask_highbd(
+          comp_data->seg_mask, cur_mask_type, CONVERT_TO_BYTEPTR(p0), bw,
+          CONVERT_TO_BYTEPTR(p1), bw, bsize, bh, bw, xd->bd);
+    else
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+      build_compound_seg_mask(comp_data->seg_mask, cur_mask_type, p0, bw, p1,
+                              bw, bsize, bh, bw);
 
     // compute rd for mask
     sse = av1_wedge_sse_from_residuals(r1, d10, comp_data->seg_mask, N);
@@ -6871,8 +6878,15 @@ static int64_t pick_interinter_seg_mask(const AV1_COMP *const cpi,
 
   // make final mask
   comp_data->mask_type = best_mask_type;
-  build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, p0, bw, p1,
-                          bw, bsize, bh, bw);
+#if CONFIG_AOM_HIGHBITDEPTH
+  if (hbd)
+    build_compound_seg_mask_highbd(
+        comp_data->seg_mask, comp_data->mask_type, CONVERT_TO_BYTEPTR(p0), bw,
+        CONVERT_TO_BYTEPTR(p1), bw, bsize, bh, bw, xd->bd);
+  else
+#endif  // CONFIG_AOM_HIGHBITDEPTH
+    build_compound_seg_mask(comp_data->seg_mask, comp_data->mask_type, p0, bw,
+                            p1, bw, bsize, bh, bw);
 
   return best_rd;
 }
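One detail worth calling out in the high-bit-depth path: diffwtd_mask_highbd shifts the pixel difference down by (bd - 8) before dividing by DIFF_FACTOR, so the same mask_base and DIFF_FACTOR constants produce the same mask at 8, 10, and 12 bits. A quick sketch (sample values are illustrative):

#include <stdio.h>
#include <stdlib.h>

#define DIFF_FACTOR 16

/* Mask increment contributed by one pixel pair at bit depth bd:
 * the (bd - 8) shift normalizes differences to an 8-bit range. */
static int mask_step(int p0, int p1, int bd) {
  return (abs(p0 - p1) >> (bd - 8)) / DIFF_FACTOR;
}

int main(void) {
  /* The same relative edge yields the same step at every bit depth: */
  printf("%d\n", mask_step(160, 96, 8));     /* diff   64 at  8-bit -> 4 */
  printf("%d\n", mask_step(640, 384, 10));   /* diff  256 at 10-bit -> 4 */
  printf("%d\n", mask_step(2560, 1536, 12)); /* diff 1024 at 12-bit -> 4 */
  return 0;
}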