diff --git a/vp9/common/vp9_pred_common.h b/vp9/common/vp9_pred_common.h index 19032bf62..919093080 100644 --- a/vp9/common/vp9_pred_common.h +++ b/vp9/common/vp9_pred_common.h @@ -109,32 +109,40 @@ static INLINE vp9_prob vp9_get_pred_prob_single_ref_p2(const VP9_COMMON *cm, unsigned char vp9_get_pred_context_tx_size(const MACROBLOCKD *xd); -static const vp9_prob *get_tx_probs(BLOCK_SIZE bsize, uint8_t context, +static const vp9_prob *get_tx_probs(TX_SIZE max_tx_size, int ctx, const struct tx_probs *tx_probs) { - if (bsize < BLOCK_16X16) - return tx_probs->p8x8[context]; - else if (bsize < BLOCK_32X32) - return tx_probs->p16x16[context]; - else - return tx_probs->p32x32[context]; + switch (max_tx_size) { + case TX_8X8: + return tx_probs->p8x8[ctx]; + case TX_16X16: + return tx_probs->p16x16[ctx]; + case TX_32X32: + return tx_probs->p32x32[ctx]; + default: + assert(!"Invalid max_tx_size."); + return NULL; + } } -static const vp9_prob *get_tx_probs2(const MACROBLOCKD *xd, - const struct tx_probs *tx_probs, - const MODE_INFO *m) { - const BLOCK_SIZE bsize = m->mbmi.sb_type; - const int context = vp9_get_pred_context_tx_size(xd); - return get_tx_probs(bsize, context, tx_probs); +static const vp9_prob *get_tx_probs2(TX_SIZE max_tx_size, const MACROBLOCKD *xd, + const struct tx_probs *tx_probs) { + const int ctx = vp9_get_pred_context_tx_size(xd); + return get_tx_probs(max_tx_size, ctx, tx_probs); } -static unsigned int *get_tx_counts(BLOCK_SIZE bsize, uint8_t context, +static unsigned int *get_tx_counts(TX_SIZE max_tx_size, int ctx, struct tx_counts *tx_counts) { - if (bsize < BLOCK_16X16) - return tx_counts->p8x8[context]; - else if (bsize < BLOCK_32X32) - return tx_counts->p16x16[context]; - else - return tx_counts->p32x32[context]; + switch (max_tx_size) { + case TX_8X8: + return tx_counts->p8x8[ctx]; + case TX_16X16: + return tx_counts->p16x16[ctx]; + case TX_32X32: + return tx_counts->p32x32[ctx]; + default: + assert(!"Invalid max_tx_size."); + return NULL; + } } #endif // VP9_COMMON_VP9_PRED_COMMON_H_ diff --git a/vp9/decoder/vp9_decodemv.c b/vp9/decoder/vp9_decodemv.c index 1ca578621..7a73afecf 100644 --- a/vp9/decoder/vp9_decodemv.c +++ b/vp9/decoder/vp9_decodemv.c @@ -61,31 +61,28 @@ static int read_segment_id(vp9_reader *r, const struct segmentation *seg) { } static TX_SIZE read_selected_tx_size(VP9_COMMON *cm, MACROBLOCKD *xd, - BLOCK_SIZE bsize, vp9_reader *r) { - const uint8_t context = vp9_get_pred_context_tx_size(xd); - const vp9_prob *tx_probs = get_tx_probs(bsize, context, &cm->fc.tx_probs); + TX_SIZE max_tx_size, vp9_reader *r) { + const int ctx = vp9_get_pred_context_tx_size(xd); + const vp9_prob *tx_probs = get_tx_probs(max_tx_size, ctx, &cm->fc.tx_probs); TX_SIZE tx_size = vp9_read(r, tx_probs[0]); - if (tx_size != TX_4X4 && bsize >= BLOCK_16X16) { + if (tx_size != TX_4X4 && max_tx_size >= TX_16X16) { tx_size += vp9_read(r, tx_probs[1]); - if (tx_size != TX_8X8 && bsize >= BLOCK_32X32) + if (tx_size != TX_8X8 && max_tx_size >= TX_32X32) tx_size += vp9_read(r, tx_probs[2]); } if (!cm->frame_parallel_decoding_mode) - ++get_tx_counts(bsize, context, &cm->counts.tx)[tx_size]; + ++get_tx_counts(max_tx_size, ctx, &cm->counts.tx)[tx_size]; return tx_size; } -static TX_SIZE read_tx_size(VP9_COMMON *const cm, MACROBLOCKD *const xd, - TX_MODE tx_mode, BLOCK_SIZE bsize, int allow_select, - vp9_reader *r) { - if (allow_select && tx_mode == TX_MODE_SELECT && bsize >= BLOCK_8X8) { - return read_selected_tx_size(cm, xd, bsize, r); - } else { - const TX_SIZE max_tx_size_block = max_txsize_lookup[bsize]; - const TX_SIZE max_tx_size_txmode = tx_mode_to_biggest_tx_size[tx_mode]; - return MIN(max_tx_size_block, max_tx_size_txmode); - } +static TX_SIZE read_tx_size(VP9_COMMON *cm, MACROBLOCKD *xd, TX_MODE tx_mode, + BLOCK_SIZE bsize, int allow_select, vp9_reader *r) { + const TX_SIZE max_tx_size = max_txsize_lookup[bsize]; + if (allow_select && tx_mode == TX_MODE_SELECT && bsize >= BLOCK_8X8) + return read_selected_tx_size(cm, xd, max_tx_size, r); + else + return MIN(max_tx_size, tx_mode_to_biggest_tx_size[tx_mode]); } static void set_segment_id(VP9_COMMON *cm, BLOCK_SIZE bsize, diff --git a/vp9/encoder/vp9_bitstream.c b/vp9/encoder/vp9_bitstream.c index 87bd36c2b..4d80b71e3 100644 --- a/vp9/encoder/vp9_bitstream.c +++ b/vp9/encoder/vp9_bitstream.c @@ -191,12 +191,14 @@ static void update_mbintra_mode_probs(VP9_COMP* const cpi, static void write_selected_tx_size(const VP9_COMP *cpi, MODE_INFO *m, TX_SIZE tx_size, BLOCK_SIZE bsize, vp9_writer *w) { + const TX_SIZE max_tx_size = max_txsize_lookup[bsize]; const MACROBLOCKD *const xd = &cpi->mb.e_mbd; - const vp9_prob *tx_probs = get_tx_probs2(xd, &cpi->common.fc.tx_probs, m); + const vp9_prob *const tx_probs = get_tx_probs2(max_tx_size, xd, + &cpi->common.fc.tx_probs); vp9_write(w, tx_size != TX_4X4, tx_probs[0]); - if (bsize >= BLOCK_16X16 && tx_size != TX_4X4) { + if (tx_size != TX_4X4 && max_tx_size >= TX_16X16) { vp9_write(w, tx_size != TX_8X8, tx_probs[1]); - if (bsize >= BLOCK_32X32 && tx_size != TX_8X8) + if (tx_size != TX_8X8 && max_tx_size >= TX_32X32) vp9_write(w, tx_size != TX_16X16, tx_probs[2]); } } diff --git a/vp9/encoder/vp9_encodeframe.c b/vp9/encoder/vp9_encodeframe.c index 86332bcf9..3583e54be 100644 --- a/vp9/encoder/vp9_encodeframe.c +++ b/vp9/encoder/vp9_encodeframe.c @@ -2486,7 +2486,8 @@ static void encode_superblock(VP9_COMP *cpi, TOKENEXTRA **t, int output_enabled, (mbmi->skip_coeff || vp9_segfeature_active(&cm->seg, segment_id, SEG_LVL_SKIP)))) { const uint8_t context = vp9_get_pred_context_tx_size(xd); - ++get_tx_counts(bsize, context, &cm->counts.tx)[mbmi->tx_size]; + ++get_tx_counts(max_txsize_lookup[bsize], + context, &cm->counts.tx)[mbmi->tx_size]; } else { int x, y; TX_SIZE sz = tx_mode_to_biggest_tx_size[cm->tx_mode]; diff --git a/vp9/encoder/vp9_rdopt.c b/vp9/encoder/vp9_rdopt.c index 993919e5b..f1ef9e503 100644 --- a/vp9/encoder/vp9_rdopt.c +++ b/vp9/encoder/vp9_rdopt.c @@ -740,7 +740,7 @@ static void choose_txfm_size_from_rd(VP9_COMP *cpi, MACROBLOCK *x, int n, m; int s0, s1; - const vp9_prob *tx_probs = get_tx_probs2(xd, &cm->fc.tx_probs, xd->mi_8x8[0]); + const vp9_prob *tx_probs = get_tx_probs2(max_tx_size, xd, &cm->fc.tx_probs); for (n = TX_4X4; n <= max_tx_size; n++) { r[n][1] = r[n][0]; @@ -845,7 +845,7 @@ static void choose_txfm_size_from_modelrd(VP9_COMP *cpi, MACROBLOCK *x, double scale_rd[TX_SIZES] = {1.73, 1.44, 1.20, 1.00}; // double scale_r[TX_SIZES] = {2.82, 2.00, 1.41, 1.00}; - const vp9_prob *tx_probs = get_tx_probs2(xd, &cm->fc.tx_probs, xd->mi_8x8[0]); + const vp9_prob *tx_probs = get_tx_probs2(max_tx_size, xd, &cm->fc.tx_probs); // for (n = TX_4X4; n <= max_txfm_size; n++) // r[n][0] = (r[n][0] * scale_r[n]);