diff --git a/av1/common/blockd.h b/av1/common/blockd.h index 2a887258c..c9fcfb2bc 100644 --- a/av1/common/blockd.h +++ b/av1/common/blockd.h @@ -163,6 +163,12 @@ static INLINE int have_newmv_in_inter_mode(PREDICTION_MODE mode) { mode == NEAREST_NEWMV || mode == NEW_NEARESTMV || mode == NEAR_NEWMV || mode == NEW_NEARMV); } + +// TODO(sarahparker) this will eventually be extended when more +// masked compound types are added +static INLINE int is_masked_compound_type(COMPOUND_TYPE type) { + return (type == COMPOUND_WEDGE); +} #else static INLINE int have_newmv_in_inter_mode(PREDICTION_MODE mode) { @@ -232,6 +238,15 @@ typedef struct RD_STATS { #endif // CONFIG_RD_DEBUG } RD_STATS; +#if CONFIG_EXT_INTER +typedef struct { + COMPOUND_TYPE type; + int wedge_index; + int wedge_sign; + // TODO(sarahparker) add neccesary data for segmentation compound type +} INTERINTER_COMPOUND_DATA; +#endif // CONFIG_EXT_INTER + // This structure now relates to 8x8 block regions. typedef struct { // Common for both INTER and INTRA blocks @@ -282,9 +297,7 @@ typedef struct { int use_wedge_interintra; int interintra_wedge_index; int interintra_wedge_sign; - COMPOUND_TYPE interinter_compound; - int interinter_wedge_index; - int interinter_wedge_sign; + INTERINTER_COMPOUND_DATA interinter_compound_data; #endif // CONFIG_EXT_INTER MOTION_MODE motion_mode; int_mv mv[2]; diff --git a/av1/common/entropymode.c b/av1/common/entropymode.c index faa9abf61..d91d6b30e 100644 --- a/av1/common/entropymode.c +++ b/av1/common/entropymode.c @@ -1910,10 +1910,9 @@ void av1_adapt_inter_frame_probs(AV1_COMMON *cm) { } for (i = 0; i < BLOCK_SIZES; ++i) { - if (is_interinter_wedge_used(i)) - aom_tree_merge_probs( - av1_compound_type_tree, pre_fc->compound_type_prob[i], - counts->compound_interinter[i], fc->compound_type_prob[i]); + aom_tree_merge_probs(av1_compound_type_tree, pre_fc->compound_type_prob[i], + counts->compound_interinter[i], + fc->compound_type_prob[i]); } #endif // CONFIG_EXT_INTER diff --git a/av1/common/reconinter.c b/av1/common/reconinter.c index bc77687ec..1ae85b051 100644 --- a/av1/common/reconinter.c +++ b/av1/common/reconinter.c @@ -251,6 +251,24 @@ const uint8_t *av1_get_soft_mask(int wedge_index, int wedge_sign, return mask; } +// get a mask according to the compound type +// TODO(sarahparker) this needs to be extended for other experiments and +// is currently only intended for ext_inter alone +#if CONFIG_EXT_INTER +const uint8_t *av1_get_compound_type_mask( + const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, + int invert) { + assert(is_masked_compound_type(comp_data->type)); + switch (comp_data->type) { + case COMPOUND_WEDGE: + return av1_get_contiguous_soft_mask( + comp_data->wedge_index, + invert ? !comp_data->wedge_sign : comp_data->wedge_sign, sb_type); + default: assert(0); return NULL; + } +} +#endif // CONFIG_EXT_INTER + static void init_wedge_master_masks() { int i, j, s; const int w = MASK_MASTER_SIZE; @@ -378,17 +396,16 @@ static void build_masked_compound_wedge_extend_highbd( #endif // CONFIG_AOM_HIGHBITDEPTH #endif // CONFIG_SUPERTX -static void build_masked_compound_wedge(uint8_t *dst, int dst_stride, - const uint8_t *src0, int src0_stride, - const uint8_t *src1, int src1_stride, - int wedge_index, int wedge_sign, - BLOCK_SIZE sb_type, int h, int w) { +static void build_masked_compound( + uint8_t *dst, int dst_stride, const uint8_t *src0, int src0_stride, + const uint8_t *src1, int src1_stride, + const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, int h, + int w) { // Derive subsampling from h and w passed in. May be refactored to // pass in subsampling factors directly. const int subh = (2 << b_height_log2_lookup[sb_type]) == h; const int subw = (2 << b_width_log2_lookup[sb_type]) == w; - const uint8_t *mask = - av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type); + const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0); aom_blend_a64_mask(dst, dst_stride, src0, src0_stride, src1, src1_stride, mask, block_size_wide[sb_type], h, w, subh, subw); } @@ -402,8 +419,7 @@ static void build_masked_compound_wedge_highbd( // pass in subsampling factors directly. const int subh = (2 << b_height_log2_lookup[sb_type]) == h; const int subw = (2 << b_width_log2_lookup[sb_type]) == w; - const uint8_t *mask = - av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type); + const uint8_t *mask = av1_get_compound_type_mask(comp_data, sb_type, 0); aom_highbd_blend_a64_mask(dst_8, dst_stride, src0_8, src0_stride, src1_8, src1_stride, mask, block_size_wide[sb_type], h, w, subh, subw, bd); @@ -426,6 +442,8 @@ void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride, #endif // CONFIG_SUPERTX const MACROBLOCKD *xd) { const MODE_INFO *mi = xd->mi[0]; + const INTERINTER_COMPOUND_DATA *const comp_data = + &mi->mbmi.interinter_compound_data; // The prediction filter types used here should be those for // the second reference block. #if CONFIG_DUAL_FILTER @@ -446,39 +464,35 @@ void av1_make_masked_inter_predictor(const uint8_t *pre, int pre_stride, if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) build_masked_compound_wedge_extend_highbd( dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE, - mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign, - mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w, xd->bd); + comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, + wedge_offset_x, wedge_offset_y, h, w, xd->bd); else build_masked_compound_wedge_extend( dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE, - mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign, - mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w); + comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, + wedge_offset_x, wedge_offset_y, h, w); #else if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) build_masked_compound_wedge_highbd( dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE, - mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign, - mi->mbmi.sb_type, h, w, xd->bd); + comp_data->wedge_index, comp_data->wedge_sign, mi->mbmi.sb_type, h, w, + xd->bd); else - build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst, - MAX_SB_SIZE, mi->mbmi.interinter_wedge_index, - mi->mbmi.interinter_wedge_sign, - mi->mbmi.sb_type, h, w); + build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, + MAX_SB_SIZE, comp_data, mi->mbmi.sb_type, h, w); #endif // CONFIG_SUPERTX #else // CONFIG_AOM_HIGHBITDEPTH DECLARE_ALIGNED(16, uint8_t, tmp_dst[MAX_SB_SQUARE]); av1_make_inter_predictor(pre, pre_stride, tmp_dst, MAX_SB_SIZE, subpel_x, subpel_y, sf, w, h, 0, tmp_ipf, xs, ys, xd); #if CONFIG_SUPERTX - build_masked_compound_wedge_extend( - dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE, - mi->mbmi.interinter_wedge_index, mi->mbmi.interinter_wedge_sign, - mi->mbmi.sb_type, wedge_offset_x, wedge_offset_y, h, w); + build_masked_compound_wedge_extend(dst, dst_stride, dst, dst_stride, tmp_dst, + MAX_SB_SIZE, comp_data->wedge_index, + comp_data->wedge_sign, mi->mbmi.sb_type, + wedge_offset_x, wedge_offset_y, h, w); #else - build_masked_compound_wedge(dst, dst_stride, dst, dst_stride, tmp_dst, - MAX_SB_SIZE, mi->mbmi.interinter_wedge_index, - mi->mbmi.interinter_wedge_sign, mi->mbmi.sb_type, - h, w); + build_masked_compound(dst, dst_stride, dst, dst_stride, tmp_dst, MAX_SB_SIZE, + comp_data, mi->mbmi.sb_type, h, w); #endif // CONFIG_SUPERTX #endif // CONFIG_AOM_HIGHBITDEPTH } @@ -630,8 +644,8 @@ void build_inter_predictors(MACROBLOCKD *xd, int plane, (scaled_mv.col >> SUBPEL_BITS); #if CONFIG_EXT_INTER - if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) && - mi->mbmi.interinter_compound == COMPOUND_WEDGE) + if (ref && + is_masked_compound_type(mi->mbmi.interinter_compound_data.type)) av1_make_masked_inter_predictor( pre, pre_buf->stride, dst, dst_buf->stride, subpel_x, subpel_y, sf, w, h, mi->mbmi.interp_filter, xs, ys, @@ -696,8 +710,7 @@ void build_inter_predictors(MACROBLOCKD *xd, int plane, (scaled_mv.col >> SUBPEL_BITS); #if CONFIG_EXT_INTER - if (ref && is_interinter_wedge_used(mi->mbmi.sb_type) && - mi->mbmi.interinter_compound == COMPOUND_WEDGE) + if (ref && is_masked_compound_type(mi->mbmi.interinter_compound_data.type)) av1_make_masked_inter_predictor(pre, pre_buf->stride, dst, dst_buf->stride, subpel_x, subpel_y, sf, w, h, mi->mbmi.interp_filter, xs, ys, @@ -1280,9 +1293,9 @@ void av1_build_obmc_inter_prediction(const AV1_COMMON *cm, MACROBLOCKD *xd, void modify_neighbor_predictor_for_obmc(MB_MODE_INFO *mbmi) { if (is_interintra_pred(mbmi)) { mbmi->ref_frame[1] = NONE; - } else if (has_second_ref(mbmi) && is_interinter_wedge_used(mbmi->sb_type) && - mbmi->interinter_compound == COMPOUND_WEDGE) { - mbmi->interinter_compound = COMPOUND_AVERAGE; + } else if (has_second_ref(mbmi) && + is_masked_compound_type(mbmi->interinter_compound_data.type)) { + mbmi->interinter_compound_data.type = COMPOUND_AVERAGE; mbmi->ref_frame[1] = NONE; } return; @@ -2080,22 +2093,22 @@ static void build_wedge_inter_predictor_from_buf( MACROBLOCKD_PLANE *const pd = &xd->plane[plane]; struct buf_2d *const dst_buf = &pd->dst; uint8_t *const dst = dst_buf->buf + dst_buf->stride * y + x; + const INTERINTER_COMPOUND_DATA *const comp_data = + &mbmi->interinter_compound_data; - if (is_compound && is_interinter_wedge_used(mbmi->sb_type) && - mbmi->interinter_compound == COMPOUND_WEDGE) { + if (is_compound && + is_masked_compound_type(mbmi->interinter_compound_data.type)) { #if CONFIG_AOM_HIGHBITDEPTH if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) build_masked_compound_wedge_highbd( dst, dst_buf->stride, CONVERT_TO_BYTEPTR(ext_dst0), ext_dst_stride0, - CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, - mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign, - mbmi->sb_type, h, w, xd->bd); + CONVERT_TO_BYTEPTR(ext_dst1), ext_dst_stride1, comp_data->wedge_index, + comp_data->wedge_sign, mbmi->sb_type, h, w, xd->bd); else #endif // CONFIG_AOM_HIGHBITDEPTH - build_masked_compound_wedge( - dst, dst_buf->stride, ext_dst0, ext_dst_stride0, ext_dst1, - ext_dst_stride1, mbmi->interinter_wedge_index, - mbmi->interinter_wedge_sign, mbmi->sb_type, h, w); + build_masked_compound(dst, dst_buf->stride, ext_dst0, ext_dst_stride0, + ext_dst1, ext_dst_stride1, comp_data, mbmi->sb_type, + h, w); } else { #if CONFIG_AOM_HIGHBITDEPTH if (xd->cur_buf->flags & YV12_FLAG_HIGHBITDEPTH) diff --git a/av1/common/reconinter.h b/av1/common/reconinter.h index 13f581e18..62a196fdb 100644 --- a/av1/common/reconinter.h +++ b/av1/common/reconinter.h @@ -522,6 +522,10 @@ const uint8_t *av1_get_soft_mask(int wedge_index, int wedge_sign, BLOCK_SIZE sb_type, int wedge_offset_x, int wedge_offset_y); +const uint8_t *av1_get_compound_type_mask( + const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE sb_type, + int invert); + void av1_build_interintra_predictors(MACROBLOCKD *xd, uint8_t *ypred, uint8_t *upred, uint8_t *vpred, int ystride, int ustride, int vstride, diff --git a/av1/decoder/decodeframe.c b/av1/decoder/decodeframe.c index 2b00c5197..8fba4cbbf 100644 --- a/av1/decoder/decodeframe.c +++ b/av1/decoder/decodeframe.c @@ -4327,10 +4327,8 @@ static int read_compressed_header(AV1Decoder *pbi, const uint8_t *data, } if (cm->reference_mode != SINGLE_REFERENCE) { for (i = 0; i < BLOCK_SIZES; i++) { - if (is_interinter_wedge_used(i)) { - for (j = 0; j < COMPOUND_TYPES - 1; j++) { - av1_diff_update_prob(&r, &fc->compound_type_prob[i][j], ACCT_STR); - } + for (j = 0; j < COMPOUND_TYPES - 1; j++) { + av1_diff_update_prob(&r, &fc->compound_type_prob[i][j], ACCT_STR); } } } diff --git a/av1/decoder/decodemv.c b/av1/decoder/decodemv.c index 101ed3e87..05dc0fa6e 100644 --- a/av1/decoder/decodemv.c +++ b/av1/decoder/decodemv.c @@ -1816,21 +1816,22 @@ static void read_inter_block_mode_info(AV1Decoder *const pbi, #endif // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION #if CONFIG_EXT_INTER - mbmi->interinter_compound = COMPOUND_AVERAGE; + mbmi->interinter_compound_data.type = COMPOUND_AVERAGE; if (cm->reference_mode != SINGLE_REFERENCE && - is_inter_compound_mode(mbmi->mode) && + is_inter_compound_mode(mbmi->mode) #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION - mbmi->motion_mode == SIMPLE_TRANSLATION && + && mbmi->motion_mode == SIMPLE_TRANSLATION #endif // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION - is_interinter_wedge_used(bsize)) { - mbmi->interinter_compound = aom_read_tree( + ) { + mbmi->interinter_compound_data.type = aom_read_tree( r, av1_compound_type_tree, cm->fc->compound_type_prob[bsize], ACCT_STR); if (xd->counts) - xd->counts->compound_interinter[bsize][mbmi->interinter_compound]++; - if (mbmi->interinter_compound == COMPOUND_WEDGE) { - mbmi->interinter_wedge_index = + xd->counts->compound_interinter[bsize] + [mbmi->interinter_compound_data.type]++; + if (mbmi->interinter_compound_data.type == COMPOUND_WEDGE) { + mbmi->interinter_compound_data.wedge_index = aom_read_literal(r, get_wedge_bits_lookup(bsize), ACCT_STR); - mbmi->interinter_wedge_sign = aom_read_bit(r, ACCT_STR); + mbmi->interinter_compound_data.wedge_sign = aom_read_bit(r, ACCT_STR); } } #endif // CONFIG_EXT_INTER diff --git a/av1/encoder/bitstream.c b/av1/encoder/bitstream.c index acce68891..d7d3701c0 100644 --- a/av1/encoder/bitstream.c +++ b/av1/encoder/bitstream.c @@ -1599,18 +1599,18 @@ static void pack_inter_mode_mvs(AV1_COMP *cpi, const MODE_INFO *mi, #if CONFIG_EXT_INTER if (cpi->common.reference_mode != SINGLE_REFERENCE && - is_inter_compound_mode(mbmi->mode) && + is_inter_compound_mode(mbmi->mode) #if CONFIG_MOTION_VAR - mbmi->motion_mode == SIMPLE_TRANSLATION && + && mbmi->motion_mode == SIMPLE_TRANSLATION #endif // CONFIG_MOTION_VAR - is_interinter_wedge_used(bsize)) { - av1_write_token(w, av1_compound_type_tree, - cm->fc->compound_type_prob[bsize], - &compound_type_encodings[mbmi->interinter_compound]); - if (mbmi->interinter_compound == COMPOUND_WEDGE) { - aom_write_literal(w, mbmi->interinter_wedge_index, + ) { + av1_write_token( + w, av1_compound_type_tree, cm->fc->compound_type_prob[bsize], + &compound_type_encodings[mbmi->interinter_compound_data.type]); + if (mbmi->interinter_compound_data.type == COMPOUND_WEDGE) { + aom_write_literal(w, mbmi->interinter_compound_data.wedge_index, get_wedge_bits_lookup(bsize)); - aom_write_bit(w, mbmi->interinter_wedge_sign); + aom_write_bit(w, mbmi->interinter_compound_data.wedge_sign); } } #endif // CONFIG_EXT_INTER @@ -4232,10 +4232,9 @@ static uint32_t write_compressed_header(AV1_COMP *cpi, uint8_t *data) { } if (cm->reference_mode != SINGLE_REFERENCE) { for (i = 0; i < BLOCK_SIZES; i++) - if (is_interinter_wedge_used(i)) - prob_diff_update(av1_compound_type_tree, fc->compound_type_prob[i], - cm->counts.compound_interinter[i], COMPOUND_TYPES, - probwt, header_bc); + prob_diff_update(av1_compound_type_tree, fc->compound_type_prob[i], + cm->counts.compound_interinter[i], COMPOUND_TYPES, + probwt, header_bc); } #endif // CONFIG_EXT_INTER diff --git a/av1/encoder/encodeframe.c b/av1/encoder/encodeframe.c index a768c1c43..50199b9ee 100644 --- a/av1/encoder/encodeframe.c +++ b/av1/encoder/encodeframe.c @@ -1992,12 +1992,13 @@ static void update_stats(const AV1_COMMON *const cm, ThreadData *td, int mi_row, #if CONFIG_EXT_INTER if (cm->reference_mode != SINGLE_REFERENCE && - is_inter_compound_mode(mbmi->mode) && + is_inter_compound_mode(mbmi->mode) #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION - mbmi->motion_mode == SIMPLE_TRANSLATION && + && mbmi->motion_mode == SIMPLE_TRANSLATION #endif // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION - is_interinter_wedge_used(bsize)) { - counts->compound_interinter[bsize][mbmi->interinter_compound]++; + ) { + counts->compound_interinter[bsize] + [mbmi->interinter_compound_data.type]++; } #endif // CONFIG_EXT_INTER } diff --git a/av1/encoder/rdopt.c b/av1/encoder/rdopt.c index 3f45757a3..8ead4acb2 100644 --- a/av1/encoder/rdopt.c +++ b/av1/encoder/rdopt.c @@ -4284,6 +4284,17 @@ static int cost_mv_ref(const AV1_COMP *const cpi, PREDICTION_MODE mode, #endif } +#if CONFIG_EXT_INTER +static int get_interinter_compound_type_bits(BLOCK_SIZE bsize, + COMPOUND_TYPE comp_type) { + switch (comp_type) { + case COMPOUND_AVERAGE: return 0; + case COMPOUND_WEDGE: return get_interinter_wedge_bits(bsize); + default: assert(0); return 0; + } +} +#endif // CONFIG_EXT_INTER + #if CONFIG_GLOBAL_MOTION #define GLOBAL_MOTION_COST_AMORTIZATION_BLKS 8 @@ -6466,19 +6477,18 @@ static void do_masked_motion_search(const AV1_COMP *const cpi, MACROBLOCK *x, } } -static void do_masked_motion_search_indexed(const AV1_COMP *const cpi, - MACROBLOCK *x, int wedge_index, - int wedge_sign, BLOCK_SIZE bsize, - int mi_row, int mi_col, - int_mv *tmp_mv, int *rate_mv, - int mv_idx[2], int which) { +static void do_masked_motion_search_indexed( + const AV1_COMP *const cpi, MACROBLOCK *x, + const INTERINTER_COMPOUND_DATA *const comp_data, BLOCK_SIZE bsize, + int mi_row, int mi_col, int_mv *tmp_mv, int *rate_mv, int mv_idx[2], + int which) { // NOTE: which values: 0 - 0 only, 1 - 1 only, 2 - both MACROBLOCKD *xd = &x->e_mbd; MB_MODE_INFO *mbmi = &xd->mi[0]->mbmi; BLOCK_SIZE sb_type = mbmi->sb_type; const uint8_t *mask; const int mask_stride = block_size_wide[bsize]; - mask = av1_get_contiguous_soft_mask(wedge_index, wedge_sign, sb_type); + mask = av1_get_compound_type_mask(comp_data, sb_type, 0); if (which == 0 || which == 2) do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col, @@ -6486,7 +6496,7 @@ static void do_masked_motion_search_indexed(const AV1_COMP *const cpi, if (which == 1 || which == 2) { // get the negative mask - mask = av1_get_contiguous_soft_mask(wedge_index, !wedge_sign, sb_type); + mask = av1_get_compound_type_mask(comp_data, sb_type, 1); do_masked_motion_search(cpi, x, mask, mask_stride, bsize, mi_row, mi_col, &tmp_mv[1], &rate_mv[1], 1, mv_idx[1]); } @@ -6827,8 +6837,8 @@ static int64_t pick_interinter_wedge(const AV1_COMP *const cpi, rd = pick_wedge(cpi, x, bsize, p0, p1, &wedge_sign, &wedge_index); } - mbmi->interinter_wedge_sign = wedge_sign; - mbmi->interinter_wedge_index = wedge_index; + mbmi->interinter_compound_data.wedge_sign = wedge_sign; + mbmi->interinter_compound_data.wedge_index = wedge_index; return rd; } @@ -6851,6 +6861,94 @@ static int64_t pick_interintra_wedge(const AV1_COMP *const cpi, mbmi->interintra_wedge_index = wedge_index; return rd; } + +static int interinter_compound_motion_search(const AV1_COMP *const cpi, + MACROBLOCK *x, + const BLOCK_SIZE bsize, + const int this_mode, int mi_row, + int mi_col) { + const MACROBLOCKD *const xd = &x->e_mbd; + MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi; + int_mv tmp_mv[2]; + int rate_mvs[2], tmp_rate_mv = 0; + if (this_mode == NEW_NEWMV) { + int mv_idxs[2] = { 0, 0 }; + do_masked_motion_search_indexed(cpi, x, &mbmi->interinter_compound_data, + bsize, mi_row, mi_col, tmp_mv, rate_mvs, + mv_idxs, 2); + tmp_rate_mv = rate_mvs[0] + rate_mvs[1]; + mbmi->mv[0].as_int = tmp_mv[0].as_int; + mbmi->mv[1].as_int = tmp_mv[1].as_int; + } else if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV) { + int mv_idxs[2] = { 0, 0 }; + do_masked_motion_search_indexed(cpi, x, &mbmi->interinter_compound_data, + bsize, mi_row, mi_col, tmp_mv, rate_mvs, + mv_idxs, 0); + tmp_rate_mv = rate_mvs[0]; + mbmi->mv[0].as_int = tmp_mv[0].as_int; + } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) { + int mv_idxs[2] = { 0, 0 }; + do_masked_motion_search_indexed(cpi, x, &mbmi->interinter_compound_data, + bsize, mi_row, mi_col, tmp_mv, rate_mvs, + mv_idxs, 1); + tmp_rate_mv = rate_mvs[1]; + mbmi->mv[1].as_int = tmp_mv[1].as_int; + } + return tmp_rate_mv; +} + +static int64_t build_and_cost_compound_wedge( + const AV1_COMP *const cpi, MACROBLOCK *x, const int_mv *const cur_mv, + const BLOCK_SIZE bsize, const int this_mode, int rs2, int rate_mv, + int *out_rate_mv, uint8_t **preds0, uint8_t **preds1, int *strides, + int mi_row, int mi_col) { + MACROBLOCKD *xd = &x->e_mbd; + MB_MODE_INFO *const mbmi = &xd->mi[0]->mbmi; + int rate_sum; + int64_t dist_sum; + int64_t best_rd_cur = INT64_MAX; + int64_t rd = INT64_MAX; + int tmp_skip_txfm_sb; + int64_t tmp_skip_sse_sb; + + best_rd_cur = pick_interinter_wedge(cpi, x, bsize, *preds0, *preds1); + best_rd_cur += RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv, 0); + + if (have_newmv_in_inter_mode(this_mode)) { + *out_rate_mv = interinter_compound_motion_search(cpi, x, bsize, this_mode, + mi_row, mi_col); + av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize); + model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, + &tmp_skip_txfm_sb, &tmp_skip_sse_sb); + rd = RDCOST(x->rdmult, x->rddiv, rs2 + *out_rate_mv + rate_sum, dist_sum); + if (rd < best_rd_cur) { + best_rd_cur = rd; + } else { + mbmi->mv[0].as_int = cur_mv[0].as_int; + mbmi->mv[1].as_int = cur_mv[1].as_int; + *out_rate_mv = rate_mv; + av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides, + preds1, strides); + } + av1_subtract_plane(x, bsize, 0); + rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, + &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); + if (rd != INT64_MAX) + rd = RDCOST(x->rdmult, x->rddiv, rs2 + *out_rate_mv + rate_sum, dist_sum); + best_rd_cur = rd; + + } else { + av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, strides, + preds1, strides); + av1_subtract_plane(x, bsize, 0); + rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, + &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); + if (rd != INT64_MAX) + rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum); + best_rd_cur = rd; + } + return best_rd_cur; +} #endif // CONFIG_EXT_INTER static int64_t handle_inter_mode( @@ -6865,7 +6963,7 @@ static int64_t handle_inter_mode( #if CONFIG_EXT_INTER int_mv single_newmvs[2][TOTAL_REFS_PER_FRAME], int single_newmvs_rate[2][TOTAL_REFS_PER_FRAME], - int *compmode_interintra_cost, int *compmode_wedge_cost, + int *compmode_interintra_cost, int *compmode_interinter_cost, int64_t (*const modelled_rd)[TOTAL_REFS_PER_FRAME], #else int_mv single_newmv[TOTAL_REFS_PER_FRAME], @@ -6941,8 +7039,8 @@ static int64_t handle_inter_mode( #if CONFIG_EXT_INTER *compmode_interintra_cost = 0; mbmi->use_wedge_interintra = 0; - *compmode_wedge_cost = 0; - mbmi->interinter_compound = COMPOUND_AVERAGE; + *compmode_interinter_cost = 0; + mbmi->interinter_compound_data.type = COMPOUND_AVERAGE; // is_comp_interintra_pred implies !is_comp_pred assert(!is_comp_interintra_pred || (!is_comp_pred)); @@ -7351,141 +7449,107 @@ static int64_t handle_inter_mode( #endif // CONFIG_MOTION_VAR #endif // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION - if (is_comp_pred && is_interinter_wedge_used(bsize)) { + if (is_comp_pred) { int rate_sum, rs2; int64_t dist_sum; - int64_t best_rd_nowedge = INT64_MAX; - int64_t best_rd_wedge = INT64_MAX; + int64_t best_rd_compound = INT64_MAX, best_rd_cur = INT64_MAX; + INTERINTER_COMPOUND_DATA best_compound_data; + int_mv best_mv[2]; + int best_tmp_rate_mv = rate_mv; int tmp_skip_txfm_sb; int64_t tmp_skip_sse_sb; int compound_type_cost[COMPOUND_TYPES]; + uint8_t pred0[2 * MAX_SB_SQUARE]; + uint8_t pred1[2 * MAX_SB_SQUARE]; + uint8_t *preds0[1] = { pred0 }; + uint8_t *preds1[1] = { pred1 }; + int strides[1] = { bw }; + COMPOUND_TYPE cur_type; - mbmi->interinter_compound = COMPOUND_AVERAGE; + best_mv[0].as_int = cur_mv[0].as_int; + best_mv[1].as_int = cur_mv[1].as_int; + memset(&best_compound_data, 0, sizeof(INTERINTER_COMPOUND_DATA)); av1_cost_tokens(compound_type_cost, cm->fc->compound_type_prob[bsize], av1_compound_type_tree); - rs2 = compound_type_cost[mbmi->interinter_compound]; - av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize); - av1_subtract_plane(x, bsize, 0); - rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, - &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); - if (rd != INT64_MAX) - rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum); - best_rd_nowedge = rd; - - // Disbale wedge search if source variance is small - if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh && - best_rd_nowedge / 3 < ref_best_rd) { - uint8_t pred0[2 * MAX_SB_SQUARE]; - uint8_t pred1[2 * MAX_SB_SQUARE]; - uint8_t *preds0[1] = { pred0 }; - uint8_t *preds1[1] = { pred1 }; - int strides[1] = { bw }; - - mbmi->interinter_compound = COMPOUND_WEDGE; - rs2 = av1_cost_literal(get_interinter_wedge_bits(bsize)) + - compound_type_cost[mbmi->interinter_compound]; + if (is_interinter_wedge_used(bsize)) { + // get inter predictors to use for masked compound modes av1_build_inter_predictors_for_planes_single_buf( xd, bsize, 0, 0, mi_row, mi_col, 0, preds0, strides); av1_build_inter_predictors_for_planes_single_buf( xd, bsize, 0, 0, mi_row, mi_col, 1, preds1, strides); + } - // Choose the best wedge - best_rd_wedge = pick_interinter_wedge(cpi, x, bsize, pred0, pred1); - best_rd_wedge += RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv, 0); + for (cur_type = COMPOUND_AVERAGE; cur_type < COMPOUND_TYPES; cur_type++) { + best_rd_cur = INT64_MAX; + mbmi->interinter_compound_data.type = cur_type; + rs2 = av1_cost_literal(get_interinter_compound_type_bits( + bsize, mbmi->interinter_compound_data.type)) + + compound_type_cost[mbmi->interinter_compound_data.type]; - if (have_newmv_in_inter_mode(this_mode)) { - int_mv tmp_mv[2]; - int rate_mvs[2], tmp_rate_mv = 0; - if (this_mode == NEW_NEWMV) { - int mv_idxs[2] = { 0, 0 }; - do_masked_motion_search_indexed( - cpi, x, mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign, - bsize, mi_row, mi_col, tmp_mv, rate_mvs, mv_idxs, 2); - tmp_rate_mv = rate_mvs[0] + rate_mvs[1]; - mbmi->mv[0].as_int = tmp_mv[0].as_int; - mbmi->mv[1].as_int = tmp_mv[1].as_int; - } else if (this_mode == NEW_NEARESTMV || this_mode == NEW_NEARMV) { - int mv_idxs[2] = { 0, 0 }; - do_masked_motion_search_indexed( - cpi, x, mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign, - bsize, mi_row, mi_col, tmp_mv, rate_mvs, mv_idxs, 0); - tmp_rate_mv = rate_mvs[0]; - mbmi->mv[0].as_int = tmp_mv[0].as_int; - } else if (this_mode == NEAREST_NEWMV || this_mode == NEAR_NEWMV) { - int mv_idxs[2] = { 0, 0 }; - do_masked_motion_search_indexed( - cpi, x, mbmi->interinter_wedge_index, mbmi->interinter_wedge_sign, - bsize, mi_row, mi_col, tmp_mv, rate_mvs, mv_idxs, 1); - tmp_rate_mv = rate_mvs[1]; - mbmi->mv[1].as_int = tmp_mv[1].as_int; - } - av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize); - model_rd_for_sb(cpi, bsize, x, xd, 0, 0, &rate_sum, &dist_sum, - &tmp_skip_txfm_sb, &tmp_skip_sse_sb); - rd = - RDCOST(x->rdmult, x->rddiv, rs2 + tmp_rate_mv + rate_sum, dist_sum); - if (rd < best_rd_wedge) { - best_rd_wedge = rd; - } else { - mbmi->mv[0].as_int = cur_mv[0].as_int; - mbmi->mv[1].as_int = cur_mv[1].as_int; - tmp_rate_mv = rate_mv; - av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, - strides, preds1, strides); - } - av1_subtract_plane(x, bsize, 0); - rd = - estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, - &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); - if (rd != INT64_MAX) - rd = RDCOST(x->rdmult, x->rddiv, rs2 + tmp_rate_mv + rate_sum, - dist_sum); - best_rd_wedge = rd; + switch (cur_type) { + case COMPOUND_AVERAGE: + av1_build_inter_predictors_sby(xd, mi_row, mi_col, bsize); + av1_subtract_plane(x, bsize, 0); + rd = estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, + &tmp_skip_txfm_sb, &tmp_skip_sse_sb, + INT64_MAX); + if (rd != INT64_MAX) + rd = + RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum); + best_rd_compound = rd; + break; + case COMPOUND_WEDGE: + if (!is_interinter_wedge_used(bsize)) break; + if (x->source_variance > cpi->sf.disable_wedge_search_var_thresh && + best_rd_compound / 3 < ref_best_rd) { + int tmp_rate_mv = 0; + best_rd_cur = build_and_cost_compound_wedge( + cpi, x, cur_mv, bsize, this_mode, rs2, rate_mv, &tmp_rate_mv, + preds0, preds1, strides, mi_row, mi_col); - if (best_rd_wedge < best_rd_nowedge) { - mbmi->interinter_compound = COMPOUND_WEDGE; - xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int; - xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int; - rd_stats->rate += tmp_rate_mv - rate_mv; - rate_mv = tmp_rate_mv; - } else { - mbmi->interinter_compound = COMPOUND_AVERAGE; - mbmi->mv[0].as_int = cur_mv[0].as_int; - mbmi->mv[1].as_int = cur_mv[1].as_int; - xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int; - xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int; - } - } else { - av1_build_wedge_inter_predictor_from_buf(xd, bsize, 0, 0, preds0, - strides, preds1, strides); - av1_subtract_plane(x, bsize, 0); - rd = - estimate_yrd_for_sb(cpi, bsize, x, &rate_sum, &dist_sum, - &tmp_skip_txfm_sb, &tmp_skip_sse_sb, INT64_MAX); - if (rd != INT64_MAX) - rd = RDCOST(x->rdmult, x->rddiv, rs2 + rate_mv + rate_sum, dist_sum); - best_rd_wedge = rd; - if (best_rd_wedge < best_rd_nowedge) { - mbmi->interinter_compound = COMPOUND_WEDGE; - } else { - mbmi->interinter_compound = COMPOUND_AVERAGE; - } + if (best_rd_cur < best_rd_compound) { + best_rd_compound = best_rd_cur; + memcpy(&best_compound_data, &mbmi->interinter_compound_data, + sizeof(best_compound_data)); + if (have_newmv_in_inter_mode(this_mode)) { + best_tmp_rate_mv = tmp_rate_mv; + best_mv[0].as_int = mbmi->mv[0].as_int; + best_mv[1].as_int = mbmi->mv[1].as_int; + // reset to original mvs for next iteration + mbmi->mv[0].as_int = cur_mv[0].as_int; + mbmi->mv[1].as_int = cur_mv[1].as_int; + } + } + } + break; + default: assert(0); return 0; } } - if (ref_best_rd < INT64_MAX && - AOMMIN(best_rd_wedge, best_rd_nowedge) / 3 > ref_best_rd) { + memcpy(&mbmi->interinter_compound_data, &best_compound_data, + sizeof(INTERINTER_COMPOUND_DATA)); + if (have_newmv_in_inter_mode(this_mode)) { + mbmi->mv[0].as_int = best_mv[0].as_int; + mbmi->mv[1].as_int = best_mv[1].as_int; + xd->mi[0]->bmi[0].as_mv[0].as_int = mbmi->mv[0].as_int; + xd->mi[0]->bmi[0].as_mv[1].as_int = mbmi->mv[1].as_int; + if (mbmi->interinter_compound_data.type) { + rd_stats->rate += best_tmp_rate_mv - rate_mv; + rate_mv = best_tmp_rate_mv; + } + } + + if (ref_best_rd < INT64_MAX && best_rd_compound / 3 > ref_best_rd) { restore_dst_buf(xd, orig_dst, orig_dst_stride); return INT64_MAX; } pred_exists = 0; - *compmode_wedge_cost = compound_type_cost[mbmi->interinter_compound]; - - if (mbmi->interinter_compound == COMPOUND_WEDGE) - *compmode_wedge_cost += - av1_cost_literal(get_interinter_wedge_bits(bsize)); + *compmode_interinter_cost = + compound_type_cost[mbmi->interinter_compound_data.type] + + av1_cost_literal(get_interinter_compound_type_bits( + bsize, mbmi->interinter_compound_data.type)); } if (is_comp_interintra_pred) { @@ -8782,7 +8846,7 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, int compmode_cost = 0; #if CONFIG_EXT_INTER int compmode_interintra_cost = 0; - int compmode_wedge_cost = 0; + int compmode_interinter_cost = 0; #endif // CONFIG_EXT_INTER int rate2 = 0, rate_y = 0, rate_uv = 0; int64_t distortion2 = 0, distortion_y = 0, distortion_uv = 0; @@ -9184,7 +9248,7 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, #endif // CONFIG_MOTION_VAR #if CONFIG_EXT_INTER single_newmvs, single_newmvs_rate, &compmode_interintra_cost, - &compmode_wedge_cost, modelled_rd, + &compmode_interinter_cost, modelled_rd, #else single_newmv, #endif // CONFIG_EXT_INTER @@ -9280,7 +9344,7 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, int dummy_single_newmvs_rate[2][TOTAL_REFS_PER_FRAME] = { { 0 }, { 0 } }; int dummy_compmode_interintra_cost = 0; - int dummy_compmode_wedge_cost = 0; + int dummy_compmode_interinter_cost = 0; #else int_mv dummy_single_newmv[TOTAL_REFS_PER_FRAME] = { { 0 } }; #endif @@ -9295,8 +9359,8 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, #endif // CONFIG_MOTION_VAR #if CONFIG_EXT_INTER dummy_single_newmvs, dummy_single_newmvs_rate, - &dummy_compmode_interintra_cost, &dummy_compmode_wedge_cost, - NULL, + &dummy_compmode_interintra_cost, + &dummy_compmode_interinter_cost, NULL, #else dummy_single_newmv, #endif @@ -9396,7 +9460,7 @@ void av1_rd_pick_inter_mode_sb(const AV1_COMP *cpi, TileDataEnc *tile_data, #if CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION if (mbmi->motion_mode == SIMPLE_TRANSLATION) #endif // CONFIG_MOTION_VAR || CONFIG_WARPED_MOTION - rate2 += compmode_wedge_cost; + rate2 += compmode_interinter_cost; #endif // CONFIG_EXT_INTER // Estimate the reference frame signaling cost and add it @@ -10278,7 +10342,7 @@ void av1_rd_pick_inter_mode_sub8x8(const struct AV1_COMP *cpi, #endif // CONFIG_FILTER_INTRA mbmi->motion_mode = SIMPLE_TRANSLATION; #if CONFIG_EXT_INTER - mbmi->interinter_compound = COMPOUND_AVERAGE; + mbmi->interinter_compound_data.type = COMPOUND_AVERAGE; mbmi->use_wedge_interintra = 0; #endif // CONFIG_EXT_INTER #if CONFIG_WARPED_MOTION